Compare commits

...

73 Commits

Author SHA1 Message Date
takatost
5d48406d64 feat: bump version to 0.3.29 (#1462) 2023-11-06 06:55:17 -06:00
takatost
2b2dbabc11 fix: prompt variables validate when using external data tools (#1465) 2023-11-06 06:31:41 -06:00
zxhlyh
13b64bc55a fix: refresh api-based-extension (#1464) 2023-11-06 20:29:41 +08:00
zxhlyh
279f099ba0 fix: chat style (#1463) 2023-11-06 20:11:55 +08:00
zxhlyh
32747641e4 feat: add api-based extension & external data tool & moderation (#1459) 2023-11-06 19:36:32 +08:00
Garfield Dai
db43ed6f41 feat: add api-based extension & external data tool & moderation backend (#1403)
Co-authored-by: takatost <takatost@gmail.com>
2023-11-06 19:36:16 +08:00
YiLi
7699621983 fix: Use correct typehint for return values (#1454)
Co-authored-by: lethe <lethe>
2023-11-06 04:50:51 -06:00
takatost
4dfbcd0b4e feat: support chatglm_turbo model #1443 (#1460) 2023-11-06 04:33:05 -06:00
crazywoola
a9ee18300e fix: service suggested api (#1452) 2023-11-04 19:59:14 +08:00
dependabot[bot]
b4861d2b5c chore(deps): bump word-wrap from 1.2.3 to 1.2.5 in /web (#1440)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:26:25 +08:00
dependabot[bot]
913f2b84a6 chore(deps-dev): bump postcss from 8.4.24 to 8.4.31 in /web (#1439)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:43 +08:00
dependabot[bot]
cc89933d8f chore(deps): bump crypto-js from 4.1.1 to 4.2.0 in /web (#1437)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:33 +08:00
dependabot[bot]
a14ea6582d chore(deps): bump semver from 5.7.1 to 5.7.2 in /web (#1436)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2023-11-01 11:24:24 +08:00
takatost
076f3289d2 feat: add spark v3.0 llm support (#1434) 2023-10-31 03:13:11 -05:00
crazywoola
518083dfe0 fix: metadata not saved (#1429) 2023-10-30 14:39:15 +08:00
crazywoola
2b366bb321 fix: delete app and setting modal is not wokring in firefox (#1427) 2023-10-29 14:22:05 +08:00
Hickays
292d4c077a fix: Add icons for apps in "Related apps list" (#1425) 2023-10-27 17:55:38 +08:00
zxhlyh
fc4c03640d fix: provider delete api key modal z-index (#1416) 2023-10-26 10:35:03 +08:00
Charlie.Wei
985253197f mermaid front-end rendering initialization exception handling logic o… (#1407) 2023-10-26 10:19:04 +08:00
Hickays
48b4249790 fix: workspace app avatar is abnormal (#1411) 2023-10-26 10:18:38 +08:00
takatost
fb64fcb271 feat: upgrade xinference-client to 0.5.4 (#1402) 2023-10-23 05:49:32 -05:00
takatost
41e452dcc5 fix: hex problem (#1395) 2023-10-22 04:15:54 -05:00
yangbo.zhou
d218c66e25 Added diagram picture file for docker-compose yaml file visualization. (#1374) 2023-10-22 09:55:31 +08:00
Panmuse
e173b1cb2a Update README_CN.md (#1390) 2023-10-21 20:41:26 -05:00
Panmuse
9b598db559 Update README.md (#1389) 2023-10-21 20:41:15 -05:00
takatost
e122d677ad fix: return wrong when init 0 quota in trial provider (#1394) 2023-10-21 14:02:38 -05:00
takatost
4c63cbf5b1 feat: adjust anthropic (#1387) 2023-10-20 02:27:46 -05:00
Charlie.Wei
288705fefd Chrome Dify Chatbot Plug-in (#1378)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-10-19 07:54:43 -05:00
Joel
8c4ae98f3d feat: add advanced prompt doc link (#1363) 2023-10-19 17:52:30 +08:00
Joel
08aa367892 feat: add context missing warning (#1384)
Co-authored-by: StyleZhang <jasonapring2015@outlook.com>
2023-10-19 17:52:14 +08:00
Joel
ff527a0190 fix: not load dataset config (#1381) 2023-10-19 13:55:25 +08:00
zxhlyh
6e05f8ca93 fix: npm run start (#1380) 2023-10-19 11:38:03 +08:00
Joel
6309d070d1 feat: enchance prompt mode copywriting (#1379) 2023-10-19 11:19:34 +08:00
Garfield Dai
fe14130b3c refactor advanced prompt core. (#1350) 2023-10-18 20:02:52 +08:00
wayne.wang
52ebffa857 fix: app config zhipu chatglm_std model, but it still use chatglm_lit… (#1377)
Co-authored-by: wayne.wang <wayne.wang@beibei.com>
2023-10-18 05:07:36 -05:00
zxhlyh
d14f15863d fix: i18n runtime error (#1376) 2023-10-18 16:00:56 +08:00
takatost
7c9b585a47 feat: support weixin ernie-bot-4 and chat mode (#1375) 2023-10-18 02:35:24 -05:00
takatost
c039f4af83 fix: app model config detached in completion thread (#1366) 2023-10-17 08:18:08 -05:00
takatost
07285e5f8b feat: optimize completion model agent (#1364) 2023-10-17 06:54:59 -05:00
Chenglong.li
16d80ebab3 Fix milvus configuration error (#1362)
Signed-off-by: JackLCL <chenglong.li@zilliz.com>
2023-10-17 17:40:40 +08:00
zxhlyh
61e816f24c feat: logo (#1356) 2023-10-16 15:26:25 +08:00
takatost
2feb16d957 feat: bump version to 0.3.28 (#1349) 2023-10-14 11:49:56 -05:00
crazywoola
3043fbe73b remove the suggested api for completion app (#1347) 2023-10-14 10:05:33 -05:00
Hickays
9f99c3f55b fix: modal z-index (#1343) 2023-10-13 05:55:03 -05:00
Joel
a07a6d8c26 feat: switch to generation model set default stop word (#1341) 2023-10-13 16:47:22 +08:00
Garfield Dai
695841a3cf Feat/advanced prompt enhancement (#1340) 2023-10-13 16:47:01 +08:00
takatost
3efaa713da feat: use xinference client instead of xinference (#1339) 2023-10-13 02:46:09 -05:00
takatost
9822f687f7 fix: max tokens of OpenAI gpt-3.5-turbo-instruct to 4097 (#1338) 2023-10-13 02:07:07 -05:00
crazywoola
b9d83c04bc fix: modal z-index (#1337) 2023-10-13 14:58:53 +08:00
Charlie.Wei
298ad6782d Add Message Suggested Api (#1326)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-13 14:07:32 +08:00
takatost
f4be2b8bcd fix: raise error in minimax stream generate (#1336) 2023-10-12 23:48:28 -05:00
crazywoola
e83e239faf fix: value.join is not a function in log list (#1332) 2023-10-13 11:34:24 +08:00
taokuizu
62bf7f0fc2 fix: new app with template display (#1322) 2023-10-13 10:18:33 +08:00
takatost
7dea485d57 feat: bump version to 0.3.27 (#1331) 2023-10-12 10:37:48 -05:00
zxhlyh
5b9858a8a3 feat: advanced prompt (#1330)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Gillian97 <jinling.sunshine@gmail.com>
2023-10-12 23:14:28 +08:00
Garfield Dai
42a5b3ec17 feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
2023-10-12 10:13:10 -05:00
takatost
2d1cb076c6 fix: dataset segment not exist return agent response (#1329) 2023-10-12 04:40:20 -05:00
Jyong
289c93d081 Feat/improve document delete logic (#1325)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-12 13:30:44 +08:00
takatost
c0fe706597 feat: adjust to only build the latest image when pushing a tag. (#1324) 2023-10-11 23:38:07 -05:00
takatost
9cba1c8bf4 fix: retriever_resource missing (#1317) 2023-10-11 14:37:11 -05:00
takatost
cbf095465c feat: remove llm client use (#1316) 2023-10-11 14:02:53 -05:00
KVOJJJin
c007dbdc13 Feat: add document of authorization (#1311) 2023-10-11 08:03:36 -05:00
takatost
ff493d017b fix: minimax tests (#1313) 2023-10-11 07:49:26 -05:00
Jyong
7f6ad9653e Fix/grpc gevent compatible (#1314)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 20:48:35 +08:00
takatost
2851a9f04e feat: optimize minimax llm call (#1312) 2023-10-11 07:17:41 -05:00
takatost
c536f85b2e fix: compatibility issues with the tongyi model. (#1310) 2023-10-11 05:16:26 -05:00
takatost
b1352ff8b7 feat: using random sampling to check if it violates the review mechan… (#1308) 2023-10-11 04:11:20 -05:00
Jyong
cc63c8499f bump version to 0.3.26 (#1307)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:11:24 +08:00
Jyong
f191b8b8d1 milvus docker compose env (#1306)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:05:37 +08:00
Jyong
5003db987d milvus secure check fix (#1305)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 13:11:06 +08:00
Jyong
07aab5e868 Feat/add milvus vector db (#1302)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-10 21:56:24 +08:00
takatost
875dfbbf0e fix: openllm completion start with prompt, remove it (#1303) 2023-10-10 04:44:19 -05:00
Charlie.Wei
9e7efa45d4 document segmentApi Add get&update&delete operate (#1285)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-10 13:27:06 +08:00
418 changed files with 16123 additions and 2852 deletions

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-api
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-web
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -1,37 +0,0 @@
import os
import re
from zhon.hanzi import punctuation
def has_chinese_characters(text):
for char in text:
if '\u4e00' <= char <= '\u9fff' or char in punctuation:
return True
return False
def check_file_for_chinese_comments(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
for line_number, line in enumerate(file, start=1):
if has_chinese_characters(line):
print(f"Found Chinese characters in {file_path} on line {line_number}:")
print(line.strip())
return True
return False
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
'prompts.py']
for root, _, files in os.walk("."):
for file in files:
if file.endswith(".py") and file not in excluded_files:
file_path = os.path.join(root, file)
if check_file_for_chinese_comments(file_path):
has_chinese = True
if has_chinese:
raise Exception("Found Chinese characters in Python files. Please remove them.")
if __name__ == "__main__":
main()

View File

@@ -1,31 +0,0 @@
name: Check for Chinese comments
on:
push:
branches:
- 'main'
pull_request:
branches:
- main
jobs:
check-chinese-comments:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install zhon
- name: Run script to check for Chinese comments
run: |
python .github/workflows/check_no_chinese_comments.py

View File

@@ -37,7 +37,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
We provide the following free resources for registered Dify cloud users (sign up at [dify.ai](https://dify.ai)):
* 600,000 free Claude model tokens to build Claude-powered apps
* 200 free OpenAI queries to build OpenAI-based apps

View File

@@ -36,7 +36,6 @@ https://github.com/langgenius/dify/assets/100913391/f6e658d5-31b3-4c16-a0af-9e19
我们为所有注册云端版的用户免费提供以下资源(登录 [dify.ai](https://cloud.dify.ai) 即可使用):
* 60 万 Tokens Claude 模型的消息调用额度,用于创建基于 Claude 模型的 AI 应用
* 200 次 OpenAI 模型的消息调用额度,用于创建基于 OpenAI 模型的 AI 应用
* 300 万 讯飞星火大模型 Token 的调用额度,用于创建基于讯飞星火大模型的 AI 应用
* 100 万 MiniMax Token 的调用额度,用于创建基于 MiniMax 模型的 AI 应用

View File

@@ -50,7 +50,7 @@ S3_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant
# Vector database configuration, support: weaviate, qdrant, milvus
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -63,6 +63,13 @@ WEAVIATE_BATCH_SIZE=100
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
# Milvus configuration
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# Mail configuration, support: resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

View File

@@ -10,7 +10,7 @@
"request": "launch",
"module": "flask",
"env": {
"FLASK_APP": "api/app.py",
"FLASK_APP": "app.py",
"FLASK_DEBUG": "1",
"GEVENT_SUPPORT": "True"
},

View File

@@ -6,6 +6,9 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import logging
import json
@@ -16,7 +19,7 @@ from flask_cors import CORS
from core.model_providers.providers import hosted
from extensions import ext_celery, ext_sentry, ext_redis, ext_login, ext_migrate, \
ext_database, ext_storage, ext_mail, ext_stripe
ext_database, ext_storage, ext_mail, ext_stripe, ext_code_based_extension
from extensions.ext_database import db
from extensions.ext_login import login_manager
@@ -76,6 +79,7 @@ def create_app(test_config=None) -> Flask:
def initialize_extensions(app):
# Since the application instance is now created, pass it to each Flask
# extension instance to bind it to the Flask application instance (app)
ext_code_based_extension.init()
ext_database.init_app(app)
ext_migrate.init(app, db)
ext_redis.init_app(app)

View File

@@ -57,6 +57,7 @@ DEFAULTS = {
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
'UPLOAD_FILE_BATCH_LIMIT': 5,
'OUTPUT_MODERATION_BUFFER_SIZE': 300
}
@@ -92,7 +93,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.25"
self.CURRENT_VERSION = "0.3.29"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -135,6 +136,14 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
@@ -220,6 +229,9 @@ class Config:
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
# moderation settings
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
class CloudEditionConfig(Config):

View File

@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -137,10 +139,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n",
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian',
]
}
},{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
@@ -255,10 +266,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语",
]
}
},{
"paragraph": {
"label": "文本内容",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,

View File

@@ -6,10 +6,10 @@ bp = Blueprint('console', __name__, url_prefix='/console/api')
api = ExternalApi(bp)
# Import other controllers
from . import setup, version, apikey, admin
from . import extension, setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate

View File

@@ -0,0 +1,25 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')

View File

@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
class IntroductionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('prompt_template', type=str, required=True, location='json')
args = parser.parse_args()
account = current_user
try:
answer = LLMGenerator.generate_introduction(
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
return {'introduction': answer}
class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')

View File

@@ -295,8 +295,8 @@ class MessageSuggestedQuestionApi(Resource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id,
user=current_user,
check_enabled=False
)
except MessageNotExistsError:
@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
app_model = _get_app(app_id, 'chat')
app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,

View File

@@ -27,6 +27,7 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(InstalledAppResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -0,0 +1,114 @@
from flask_restful import Resource, reqparse, marshal_with
from flask_login import current_user
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from models.api_based_extension import APIBasedExtension
from fields.api_based_extension_fields import api_based_extension_fields
from services.code_based_extension_service import CodeBasedExtensionService
from services.api_based_extension_service import APIBasedExtensionService
class CodeBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('module', type=str, required=True, location='args')
args = parser.parse_args()
return {
'module': args['module'],
'data': CodeBasedExtensionService.get_code_based_extension(args['module'])
}
class APIBasedExtensionAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self):
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_all_by_tenant_id(tenant_id)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data = APIBasedExtension(
tenant_id=current_user.current_tenant_id,
name=args['name'],
api_endpoint=args['api_endpoint'],
api_key=args['api_key']
)
return APIBasedExtensionService.save(extension_data)
class APIBasedExtensionDetailAPI(Resource):
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def get(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
return APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
@setup_required
@login_required
@account_initialization_required
@marshal_with(api_based_extension_fields)
def post(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('api_endpoint', type=str, required=True, location='json')
parser.add_argument('api_key', type=str, required=True, location='json')
args = parser.parse_args()
extension_data_from_db.name = args['name']
extension_data_from_db.api_endpoint = args['api_endpoint']
if args['api_key'] != '[__HIDDEN__]':
extension_data_from_db.api_key = args['api_key']
return APIBasedExtensionService.save(extension_data_from_db)
@setup_required
@login_required
@account_initialization_required
def delete(self, id):
api_based_extension_id = str(id)
tenant_id = current_user.current_tenant_id
extension_data_from_db = APIBasedExtensionService.get_with_tenant_id(tenant_id, api_based_extension_id)
APIBasedExtensionService.delete(extension_data_from_db)
return {'result': 'success'}
api.add_resource(CodeBasedExtensionAPI, '/code-based-extension')
api.add_resource(APIBasedExtensionAPI, '/api-based-extension')
api.add_resource(APIBasedExtensionDetailAPI, '/api-based-extension/<uuid:id>')

View File

@@ -28,6 +28,7 @@ class AppParameterApi(AppApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -42,7 +43,8 @@ class AppParameterApi(AppApiResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -183,4 +183,3 @@ api.add_resource(CompletionApi, '/completion-messages')
api.add_resource(CompletionStopApi, '/completion-messages/<string:task_id>/stop')
api.add_resource(ChatApi, '/chat-messages')
api.add_resource(ChatStopApi, '/chat-messages/<string:task_id>/stop')

View File

@@ -54,6 +54,7 @@ class ConversationDetailApi(AppApiResource):
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
@marshal_with(simple_conversation_fields)

View File

@@ -10,6 +10,8 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Message, EndUser
class MessageListApi(AppApiResource):
@@ -96,5 +98,38 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_end_user_id is not None:
user = db.session.query(EndUser) \
.filter(
EndUser.tenant_id == app_model.tenant_id,
EndUser.id == message.from_end_user_id,
EndUser.type == 'service_api'
).first()
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
message_id=message_id,
check_enabled=False
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success', 'data': questions}
api.add_resource(MessageListApi, '/messages')
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')

View File

@@ -1,7 +1,6 @@
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
@@ -9,8 +8,8 @@ from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset
from services.dataset_service import DocumentService, SegmentService
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
@@ -24,6 +23,8 @@ class SegmentApi(DatasetApiResource):
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
@@ -55,5 +56,146 @@ class SegmentApi(DatasetApiResource):
'doc_form': document.doc_form
}, 200
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
parser = reqparse.RequestParser()
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
args = parser.parse_args()
status_list = args['status']
keyword = args['keyword']
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
if keyword:
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
total = query.count()
segments = query.order_by(DocumentSegment.position).all()
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'total': total
}, 200
class DatasetSegmentApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')

View File

@@ -27,6 +27,7 @@ class AppParameterApi(WebApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
}
@marshal_with(parameters_fields)
@@ -41,7 +42,8 @@ class AppParameterApi(WebApiResource):
'speech_to_text': app_model_config.speech_to_text_dict,
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
}

View File

@@ -139,7 +139,7 @@ class ChatStopApi(WebApiResource):
return {'result': 'success'}, 200
def compact_response(response: Union[dict | Generator]) -> Response:
def compact_response(response: Union[dict, Generator]) -> Response:
if isinstance(response, dict):
return Response(response=json.dumps(response), status=200, mimetype='application/json')
else:

View File

@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -0,0 +1 @@
import core.moderation.base

View File

@@ -2,14 +2,18 @@ import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
@@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
if isinstance(tool_inputs, dict) and 'query' in tool_inputs and 'chat_history' not in kwargs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 40
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
function_call = result.function_call
self.llm.max_tokens = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens
return True if function_call else False
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
agent_decision = _parse_ai_message(predicted_message)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,140 +0,0 @@
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, model_instance.client)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,107 +0,0 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")

View File

@@ -1,10 +1,9 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,8 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -49,7 +50,6 @@ Action:
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
@@ -93,12 +93,16 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
rst = tool.run(tool_input={'query': kwargs['input']})
return AgentFinish(return_values={"output": rst}, log=rst)
if intermediate_steps:
_, observation = intermediate_steps[-1]
return AgentFinish(return_values={"output": observation}, log=observation)
full_inputs = self.get_full_inputs(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -108,6 +112,10 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
elif isinstance(tool_inputs, str):
agent_decision.tool_input = kwargs['input']
else:
agent_decision.return_values['output'] = ''
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
@@ -142,10 +150,65 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +220,36 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)

View File

@@ -1,19 +1,21 @@
import re
from typing import List, Tuple, Any, Union, Sequence, Optional
from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain import BasePromptTemplate, PromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +54,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
summary_model_instance: BaseLLM = None
class Config:
"""Configuration for this pydantic object."""
@@ -95,14 +96,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -118,7 +119,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_llm:
if len(intermediate_steps) >= 2 and self.summary_model_instance:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
@@ -130,11 +131,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
@@ -144,6 +144,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
@@ -173,10 +185,65 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
]
return ChatPromptTemplate(input_variables=input_variables, messages=messages)
@classmethod
def create_completion_prompt(
cls,
tools: Sequence[BaseTool],
prefix: str = PREFIX,
format_instructions: str = FORMAT_INSTRUCTIONS,
input_variables: Optional[List[str]] = None,
) -> PromptTemplate:
"""Create prompt in the style of the zero shot agent.
Args:
tools: List of tools the agent will have access to, used to format the
prompt.
prefix: String to put before the list of tools.
input_variables: List of input variables the final prompt will expect.
Returns:
A PromptTemplate with the template assembled from the pieces here.
"""
suffix = """Begin! Reminder to ALWAYS respond with a valid json blob of a single action. Use tools if necessary. Respond directly if appropriate. Format is Action:```$JSON_BLOB```then Observation:.
Question: {input}
Thought: {agent_scratchpad}
"""
tool_strings = "\n".join([f"{tool.name}: {tool.description}" for tool in tools])
tool_names = ", ".join([tool.name for tool in tools])
format_instructions = format_instructions.format(tool_names=tool_names)
template = "\n\n".join([prefix, tool_strings, format_instructions, suffix])
if input_variables is None:
input_variables = ["input", "agent_scratchpad"]
return PromptTemplate(template=template, input_variables=input_variables)
def _construct_scratchpad(
self, intermediate_steps: List[Tuple[AgentAction, str]]
) -> str:
agent_scratchpad = ""
for action, observation in intermediate_steps:
agent_scratchpad += action.log
agent_scratchpad += f"\n{self.observation_prefix}{observation}\n{self.llm_prefix}"
if not isinstance(agent_scratchpad, str):
raise ValueError("agent_scratchpad should be of type string.")
if agent_scratchpad:
llm_chain = cast(LLMChain, self.llm_chain)
if llm_chain.model_instance.model_mode == ModelMode.CHAT:
return (
f"This was your previous work "
f"(but I haven't seen any of it! I only see what "
f"you return as final answer):\n{agent_scratchpad}"
)
else:
return agent_scratchpad
else:
return agent_scratchpad
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +255,35 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
if model_instance.model_mode == ModelMode.CHAT:
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
else:
prompt = cls.create_completion_prompt(
tools,
prefix=prefix,
format_instructions=format_instructions,
input_variables=input_variables,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
output_parser=output_parser,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
@@ -95,7 +81,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True

View File

@@ -1,13 +1,25 @@
import logging
from typing import Any, Dict, List, Union
import threading
import time
from typing import Any, Dict, List, Union, Optional
from flask import Flask, current_app
from langchain.callbacks.base import BaseCallbackHandler
from langchain.schema import LLMResult, BaseMessage
from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
class ModerationRule(BaseModel):
type: str
config: Dict[str, Any]
class LLMCallbackHandler(BaseCallbackHandler):
@@ -20,6 +32,24 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.start_at = None
self.conversation_message_task = conversation_message_task
self.output_moderation_handler = None
self.init_output_moderation()
def init_output_moderation(self):
app_model_config = self.conversation_message_task.app_model_config
sensitive_word_avoidance_dict = app_model_config.sensitive_word_avoidance_dict
if sensitive_word_avoidance_dict and sensitive_word_avoidance_dict.get("enabled"):
self.output_moderation_handler = OutputModerationHandler(
tenant_id=self.conversation_message_task.tenant_id,
app_id=self.conversation_message_task.app.id,
rule=ModerationRule(
type=sensitive_word_avoidance_dict.get("type"),
config=sensitive_word_avoidance_dict.get("config")
),
on_message_replace_func=self.conversation_message_task.on_message_replace
)
@property
def always_verbose(self) -> bool:
"""Whether to call verbose callbacks even if verbose is False."""
@@ -59,10 +89,19 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.llm_message.prompt_tokens = self.model_instance.get_num_tokens([PromptMessage(content=prompts[0])])
def on_llm_end(self, response: LLMResult, **kwargs: Any) -> None:
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(response.generations[0][0].text)
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
self.llm_message.completion = self.output_moderation_handler.moderation_completion(
completion=response.generations[0][0].text,
public_event=True if self.conversation_message_task.streaming else False
)
else:
self.llm_message.completion = response.generations[0][0].text
if not self.conversation_message_task.streaming:
self.conversation_message_task.append_message_text(self.llm_message.completion)
if response.llm_output and 'token_usage' in response.llm_output:
if 'prompt_tokens' in response.llm_output['token_usage']:
self.llm_message.prompt_tokens = response.llm_output['token_usage']['prompt_tokens']
@@ -79,23 +118,161 @@ class LLMCallbackHandler(BaseCallbackHandler):
self.conversation_message_task.save_message(self.llm_message)
def on_llm_new_token(self, token: str, **kwargs: Any) -> None:
try:
self.conversation_message_task.append_message_text(token)
except ConversationTaskStoppedException as ex:
if self.output_moderation_handler and self.output_moderation_handler.should_direct_output():
# stop subscribe new token when output moderation should direct output
ex = ConversationTaskInterruptException()
self.on_llm_error(error=ex)
raise ex
self.llm_message.completion += token
try:
self.conversation_message_task.append_message_text(token)
self.llm_message.completion += token
if self.output_moderation_handler:
self.output_moderation_handler.append_new_token(token)
except ConversationTaskStoppedException as ex:
self.on_llm_error(error=ex)
raise ex
def on_llm_error(
self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> None:
"""Do nothing."""
if self.output_moderation_handler:
self.output_moderation_handler.stop_thread()
if isinstance(error, ConversationTaskStoppedException):
if self.conversation_message_task.streaming:
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message, by_stopped=True)
if isinstance(error, ConversationTaskInterruptException):
self.llm_message.completion = self.output_moderation_handler.get_final_output()
self.llm_message.completion_tokens = self.model_instance.get_num_tokens(
[PromptMessage(content=self.llm_message.completion)]
)
self.conversation_message_task.save_message(llm_message=self.llm_message)
else:
logging.debug("on_llm_error: %s", error)
class OutputModerationHandler(BaseModel):
DEFAULT_BUFFER_SIZE: int = 300
tenant_id: str
app_id: str
rule: ModerationRule
on_message_replace_func: Any
thread: Optional[threading.Thread] = None
thread_running: bool = True
buffer: str = ''
is_final_chunk: bool = False
final_output: Optional[str] = None
class Config:
arbitrary_types_allowed = True
def should_direct_output(self):
return self.final_output is not None
def get_final_output(self):
return self.final_output
def append_new_token(self, token: str):
self.buffer += token
if not self.thread:
self.thread = self.start_thread()
def moderation_completion(self, completion: str, public_event: bool = False) -> str:
self.buffer = completion
self.is_final_chunk = True
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=completion
)
if not result or not result.flagged:
return completion
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
else:
final_output = result.text
if public_event:
self.on_message_replace_func(final_output)
return final_output
def start_thread(self) -> threading.Thread:
buffer_size = int(current_app.config.get('MODERATION_BUFFER_SIZE', self.DEFAULT_BUFFER_SIZE))
thread = threading.Thread(target=self.worker, kwargs={
'flask_app': current_app._get_current_object(),
'buffer_size': buffer_size if buffer_size > 0 else self.DEFAULT_BUFFER_SIZE
})
thread.start()
return thread
def stop_thread(self):
if self.thread and self.thread.is_alive():
self.thread_running = False
def worker(self, flask_app: Flask, buffer_size: int):
with flask_app.app_context():
current_length = 0
while self.thread_running:
moderation_buffer = self.buffer
buffer_length = len(moderation_buffer)
if not self.is_final_chunk:
chunk_length = buffer_length - current_length
if 0 <= chunk_length < buffer_size:
time.sleep(1)
continue
current_length = buffer_length
result = self.moderation(
tenant_id=self.tenant_id,
app_id=self.app_id,
moderation_buffer=moderation_buffer
)
if not result or not result.flagged:
continue
if result.action == ModerationAction.DIRECT_OUTPUT:
final_output = result.preset_response
self.final_output = final_output
else:
final_output = result.text + self.buffer[len(moderation_buffer):]
# trigger replace event
if self.thread_running:
self.on_message_replace_func(final_output)
if result.action == ModerationAction.DIRECT_OUTPUT:
break
def moderation(self, tenant_id: str, app_id: str, moderation_buffer: str) -> Optional[ModerationOutputsResult]:
try:
moderation_factory = ModerationFactory(
name=self.rule.type,
app_id=app_id,
tenant_id=tenant_id,
config=self.rule.config
)
result: ModerationOutputsResult = moderation_factory.moderation_for_outputs(moderation_buffer)
return result
except Exception as e:
logging.error("Moderation Output error: %s", e)
return None

View File

@@ -0,0 +1,36 @@
from typing import List, Dict, Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
model_instance: BaseLLM
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
stop=stop
)
generations = [
[Generation(text=result.content)]
]
return LLMResult(generations=generations)

View File

@@ -1,92 +0,0 @@
import enum
import logging
from typing import List, Dict, Optional, Any
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from pydantic import BaseModel
from core.model_providers.error import LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.moderation import openai_moderation
class SensitiveWordAvoidanceRule(BaseModel):
class Type(enum.Enum):
MODERATION = "moderation"
KEYWORDS = "keywords"
type: Type
canned_response: str = 'Your content violates our usage policy. Please revise and try again.'
extra_params: dict = {}
class SensitiveWordAvoidanceChain(Chain):
input_key: str = "input" #: :meta private:
output_key: str = "output" #: :meta private:
model_instance: BaseLLM
sensitive_word_avoidance_rule: SensitiveWordAvoidanceRule
@property
def _chain_type(self) -> str:
return "sensitive_word_avoidance_chain"
@property
def input_keys(self) -> List[str]:
"""Expect input key.
:meta private:
"""
return [self.input_key]
@property
def output_keys(self) -> List[str]:
"""Return output key.
:meta private:
"""
return [self.output_key]
def _check_sensitive_word(self, text: str) -> bool:
for word in self.sensitive_word_avoidance_rule.extra_params.get('sensitive_words', []):
if word in text:
return False
return True
def _check_moderation(self, text: str) -> bool:
moderation_model_instance = ModelFactory.get_moderation_model(
tenant_id=self.model_instance.model_provider.provider.tenant_id,
model_provider_name='openai',
model_name=openai_moderation.DEFAULT_MODEL
)
try:
return moderation_model_instance.run(text=text)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
def _call(
self,
inputs: Dict[str, Any],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> Dict[str, Any]:
text = inputs[self.input_key]
if self.sensitive_word_avoidance_rule.type == SensitiveWordAvoidanceRule.Type.KEYWORDS:
result = self._check_sensitive_word(text)
else:
result = self._check_moderation(text)
if not result:
raise SensitiveWordAvoidanceError(self.sensitive_word_avoidance_rule.canned_response)
return {self.output_key: text}
class SensitiveWordAvoidanceError(Exception):
def __init__(self, message):
super().__init__(message)
self.message = message

View File

@@ -1,14 +1,18 @@
import concurrent
import json
import logging
from typing import Optional, List, Union
from concurrent.futures import ThreadPoolExecutor
from typing import Optional, List, Union, Tuple
from flask import current_app, Flask
from requests.exceptions import ChunkedEncodingError
from core.agent.agent_executor import AgentExecuteResult, PlanningStrategy
from core.callback_handler.main_chain_gather_callback_handler import MainChainGatherCallbackHandler
from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.chain.sensitive_word_avoidance_chain import SensitiveWordAvoidanceError
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
@@ -16,10 +20,11 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompt_transform import PromptTransform
from models.model import App, AppModelConfig, Account, Conversation, EndUser
from core.moderation.base import ModerationException, ModerationAction
from core.moderation.factory import ModerationFactory
class Completion:
@@ -30,7 +35,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -78,26 +83,35 @@ class Completion:
)
try:
# parse sensitive_word_avoidance_chain
chain_callback = MainChainGatherCallbackHandler(conversation_message_task)
sensitive_word_avoidance_chain = orchestrator_rule_parser.to_sensitive_word_avoidance_chain(
final_model_instance, [chain_callback])
if sensitive_word_avoidance_chain:
try:
query = sensitive_word_avoidance_chain.run(query)
except SensitiveWordAvoidanceError as ex:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=ex.message
)
return
try:
# process sensitive_word_avoidance
inputs, query = cls.moderation_for_inputs(app.id, app.tenant_id, app_model_config, inputs, query)
except ModerationException as e:
cls.run_final_llm(
model_instance=final_model_instance,
mode=app.mode,
app_model_config=app_model_config,
query=query,
inputs=inputs,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
fake_response=str(e)
)
return
# fill in variable inputs from external data tools if exists
external_data_tools = app_model_config.external_data_tools_list
if external_data_tools:
inputs = cls.fill_in_inputs_from_external_data_tools(
tenant_id=app.tenant_id,
app_id=app.id,
external_data_tools=external_data_tools,
inputs=inputs,
query=query
)
# get agent executor
agent_executor = orchestrator_rule_parser.to_agent_executor(
@@ -137,19 +151,110 @@ class Completion:
memory=memory,
fake_response=fake_response
)
except ConversationTaskStoppedException:
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
return
except ChunkedEncodingError as e:
# Interrupt by LLM (like OpenAI), handle it.
logging.warning(f'ChunkedEncodingError: {e}')
conversation_message_task.end()
return
@classmethod
def moderation_for_inputs(cls, app_id: str, tenant_id: str, app_model_config: AppModelConfig, inputs: dict, query: str):
if not app_model_config.sensitive_word_avoidance_dict['enabled']:
return inputs, query
type = app_model_config.sensitive_word_avoidance_dict['type']
moderation = ModerationFactory(type, app_id, tenant_id, app_model_config.sensitive_word_avoidance_dict['config'])
moderation_result = moderation.moderation_for_inputs(inputs, query)
if not moderation_result.flagged:
return inputs, query
if moderation_result.action == ModerationAction.DIRECT_OUTPUT:
raise ModerationException(moderation_result.preset_response)
elif moderation_result.action == ModerationAction.OVERRIDED:
inputs = moderation_result.inputs
query = moderation_result.query
return inputs, query
@classmethod
def fill_in_inputs_from_external_data_tools(cls, tenant_id: str, app_id: str, external_data_tools: list[dict],
inputs: dict, query: str) -> dict:
"""
Fill in variable inputs from external data tools if exists.
:param tenant_id: workspace id
:param app_id: app id
:param external_data_tools: external data tools configs
:param inputs: the inputs
:param query: the query
:return: the filled inputs
"""
# Group tools by type and config
grouped_tools = {}
for tool in external_data_tools:
if not tool.get("enabled"):
continue
tool_key = (tool.get("type"), json.dumps(tool.get("config"), sort_keys=True))
grouped_tools.setdefault(tool_key, []).append(tool)
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tools in grouped_tools.values():
# Only query the first tool in each group
first_tool = tools[0]
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
inputs, query
)
for tool in tools:
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_key, result = future.result()
if tool_key in grouped_tools:
for tool in grouped_tools[tool_key]:
results[tool['variable']] = result
inputs.update(results)
return inputs
@classmethod
def query_external_data_tool(cls, flask_app: Flask, tenant_id: str, app_id: str, external_data_tool: dict,
inputs: dict, query: str) -> Tuple[Optional[str], Optional[str]]:
with flask_app.app_context():
tool_variable = external_data_tool.get("variable")
tool_type = external_data_tool.get("type")
tool_config = external_data_tool.get("config")
external_data_tool_factory = ExternalDataToolFactory(
name=tool_type,
tenant_id=tenant_id,
app_id=app_id,
variable=tool_variable,
config=tool_config
)
# query external data tool
result = external_data_tool_factory.query(
inputs=inputs,
query=query
)
tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
return tool_key, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
if app.mode != 'completion':
return query
return inputs.get(app_model_config.dataset_query_variable, "")
@classmethod
@@ -159,15 +264,33 @@ class Completion:
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
prompt_transform = PromptTransform()
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +299,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -227,15 +350,30 @@ class Completion:
if max_tokens is None:
max_tokens = 0
prompt_transform = PromptTransform()
prompt_messages = []
# get prompt without memory and context
prompt_messages, _ = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=None,
memory=None
)
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=None,
memory=None,
model_instance=model_instance
)
else:
prompt_messages = prompt_transform.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=None,
memory=None,
model_instance=model_instance
)
prompt_tokens = model_instance.get_num_tokens(prompt_messages)
rest_tokens = model_limited_tokens - max_tokens - prompt_tokens
@@ -266,52 +404,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)
cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)

View File

@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price
@@ -290,6 +290,10 @@ class ConversationMessageTask:
db.session.commit()
self.retriever_resource = resource
def on_message_replace(self, text: str):
if text is not None:
self._pub_handler.pub_message_replace(text)
def message_end(self):
self._pub_handler.pub_message_end(self.retriever_resource)
@@ -342,6 +346,24 @@ class PubHandler:
self.pub_end()
raise ConversationTaskStoppedException()
def pub_message_replace(self, text: str):
content = {
'event': 'message_replace',
'data': {
'task_id': self._task_id,
'message_id': str(self._message.id),
'text': text,
'mode': self._conversation.mode,
'conversation_id': str(self._conversation.id)
}
}
redis_client.publish(self._channel, json.dumps(content))
if self._is_stopped():
self.pub_end()
raise ConversationTaskStoppedException()
def pub_chain(self, message_chain: MessageChain):
if self._chain_pub:
content = {
@@ -443,3 +465,7 @@ class PubHandler:
class ConversationTaskStoppedException(Exception):
pass
class ConversationTaskInterruptException(Exception):
pass

View File

View File

@@ -0,0 +1,62 @@
import os
import requests
from models.api_based_extension import APIBasedExtensionPoint
class APIBasedExtensionRequestor:
timeout: (int, int) = (5, 60)
"""timeout for request connect and read"""
def __init__(self, api_endpoint: str, api_key: str) -> None:
self.api_endpoint = api_endpoint
self.api_key = api_key
def request(self, point: APIBasedExtensionPoint, params: dict) -> dict:
"""
Request the api.
:param point: the api point
:param params: the request params
:return: the response json
"""
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer {}".format(self.api_key)
}
url = self.api_endpoint
try:
# proxy support for security
proxies = None
if os.environ.get("API_BASED_EXTENSION_HTTP_PROXY") and os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"):
proxies = {
'http': os.environ.get("API_BASED_EXTENSION_HTTP_PROXY"),
'https': os.environ.get("API_BASED_EXTENSION_HTTPS_PROXY"),
}
response = requests.request(
method='POST',
url=url,
json={
'point': point.value,
'params': params
},
headers=headers,
timeout=self.timeout,
proxies=proxies
)
except requests.exceptions.Timeout:
raise ValueError("request timeout")
except requests.exceptions.ConnectionError:
raise ValueError("request connection error")
if response.status_code != 200:
raise ValueError("request error, status_code: {}, content: {}".format(
response.status_code,
response.text[:100]
))
return response.json()

View File

@@ -0,0 +1,111 @@
import enum
import importlib.util
import json
import logging
import os
from collections import OrderedDict
from typing import Any, Optional
from pydantic import BaseModel
class ExtensionModule(enum.Enum):
MODERATION = 'moderation'
EXTERNAL_DATA_TOOL = 'external_data_tool'
class ModuleExtension(BaseModel):
extension_class: Any
name: str
label: Optional[dict] = None
form_schema: Optional[list] = None
builtin: bool = True
position: Optional[int] = None
class Extensible:
module: ExtensionModule
name: str
tenant_id: str
config: Optional[dict] = None
def __init__(self, tenant_id: str, config: Optional[dict] = None) -> None:
self.tenant_id = tenant_id
self.config = config
@classmethod
def scan_extensions(cls):
extensions = {}
# get the path of the current class
current_path = os.path.abspath(cls.__module__.replace(".", os.path.sep) + '.py')
current_dir_path = os.path.dirname(current_path)
# traverse subdirectories
for subdir_name in os.listdir(current_dir_path):
if subdir_name.startswith('__'):
continue
subdir_path = os.path.join(current_dir_path, subdir_name)
extension_name = subdir_name
if os.path.isdir(subdir_path):
file_names = os.listdir(subdir_path)
# is builtin extension, builtin extension
# in the front-end page and business logic, there are special treatments.
builtin = False
position = None
if '__builtin__' in file_names:
builtin = True
builtin_file_path = os.path.join(subdir_path, '__builtin__')
if os.path.exists(builtin_file_path):
with open(builtin_file_path, 'r') as f:
position = int(f.read().strip())
if (extension_name + '.py') not in file_names:
logging.warning(f"Missing {extension_name}.py file in {subdir_path}, Skip.")
continue
# Dynamic loading {subdir_name}.py file and find the subclass of Extensible
py_path = os.path.join(subdir_path, extension_name + '.py')
spec = importlib.util.spec_from_file_location(extension_name, py_path)
mod = importlib.util.module_from_spec(spec)
spec.loader.exec_module(mod)
extension_class = None
for name, obj in vars(mod).items():
if isinstance(obj, type) and issubclass(obj, cls) and obj != cls:
extension_class = obj
break
if not extension_class:
logging.warning(f"Missing subclass of {cls.__name__} in {py_path}, Skip.")
continue
json_data = {}
if not builtin:
if 'schema.json' not in file_names:
logging.warning(f"Missing schema.json file in {subdir_path}, Skip.")
continue
json_path = os.path.join(subdir_path, 'schema.json')
json_data = {}
if os.path.exists(json_path):
with open(json_path, 'r') as f:
json_data = json.load(f)
extensions[extension_name] = ModuleExtension(
extension_class=extension_class,
name=extension_name,
label=json_data.get('label'),
form_schema=json_data.get('form_schema'),
builtin=builtin,
position=position
)
sorted_items = sorted(extensions.items(), key=lambda x: (x[1].position is None, x[1].position))
sorted_extensions = OrderedDict(sorted_items)
return sorted_extensions

View File

@@ -0,0 +1,47 @@
from core.extension.extensible import ModuleExtension, ExtensionModule
from core.external_data_tool.base import ExternalDataTool
from core.moderation.base import Moderation
class Extension:
__module_extensions: dict[str, dict[str, ModuleExtension]] = {}
module_classes = {
ExtensionModule.MODERATION: Moderation,
ExtensionModule.EXTERNAL_DATA_TOOL: ExternalDataTool
}
def init(self):
for module, module_class in self.module_classes.items():
self.__module_extensions[module.value] = module_class.scan_extensions()
def module_extensions(self, module: str) -> list[ModuleExtension]:
module_extensions = self.__module_extensions.get(module)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
return list(module_extensions.values())
def module_extension(self, module: ExtensionModule, extension_name: str) -> ModuleExtension:
module_extensions = self.__module_extensions.get(module.value)
if not module_extensions:
raise ValueError(f"Extension Module {module} not found")
module_extension = module_extensions.get(extension_name)
if not module_extension:
raise ValueError(f"Extension {extension_name} not found")
return module_extension
def extension_class(self, module: ExtensionModule, extension_name: str) -> type:
module_extension = self.module_extension(module, extension_name)
return module_extension.extension_class
def validate_form_schema(self, module: ExtensionModule, extension_name: str, config: dict) -> None:
module_extension = self.module_extension(module, extension_name)
form_schema = module_extension.form_schema
# TODO validate form_schema

View File

View File

@@ -0,0 +1 @@
1

View File

@@ -0,0 +1,92 @@
from typing import Optional
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor
from core.external_data_tool.base import ExternalDataTool
from core.helper import encrypter
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension, APIBasedExtensionPoint
class ApiExternalDataTool(ExternalDataTool):
"""
The api external data tool.
"""
name: str = "api"
"""the unique name of external data tool"""
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
# own validation logic
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("api_based_extension_id is invalid")
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
# get params from config
api_based_extension_id = self.config.get("api_based_extension_id")
# get api_based_extension
api_based_extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == self.tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
if not api_based_extension:
raise ValueError("[External data tool] API query failed, variable: {}, "
"error: api_based_extension_id is invalid"
.format(self.config.get('variable')))
# decrypt api_key
api_key = encrypter.decrypt_token(
tenant_id=self.tenant_id,
token=api_based_extension.api_key
)
try:
# request api
requestor = APIBasedExtensionRequestor(
api_endpoint=api_based_extension.api_endpoint,
api_key=api_key
)
except Exception as e:
raise ValueError("[External data tool] API query failed, variable: {}, error: {}".format(
self.config.get('variable'),
e
))
response_json = requestor.request(point=APIBasedExtensionPoint.APP_EXTERNAL_DATA_TOOL_QUERY, params={
'app_id': self.app_id,
'tool_variable': self.variable,
'inputs': inputs,
'query': query
})
if 'result' not in response_json:
raise ValueError("[External data tool] API query failed, variable: {}, error: result not found in response"
.format(self.config.get('variable')))
return response_json['result']

View File

@@ -0,0 +1,45 @@
from abc import abstractmethod, ABC
from typing import Optional
from core.extension.extensible import Extensible, ExtensionModule
class ExternalDataTool(Extensible, ABC):
"""
The base class of external data tool.
"""
module: ExtensionModule = ExtensionModule.EXTERNAL_DATA_TOOL
app_id: str
"""the id of app"""
variable: str
"""the tool variable name of app tool"""
def __init__(self, tenant_id: str, app_id: str, variable: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
self.variable = variable
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
raise NotImplementedError

View File

@@ -0,0 +1,40 @@
from typing import Optional
from core.extension.extensible import ExtensionModule
from extensions.ext_code_based_extension import code_based_extension
class ExternalDataToolFactory:
def __init__(self, name: str, tenant_id: str, app_id: str, variable: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
self.__extension_instance = extension_class(
tenant_id=tenant_id,
app_id=app_id,
variable=variable,
config=config
)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of external data tool
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.EXTERNAL_DATA_TOOL, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.EXTERNAL_DATA_TOOL, name)
extension_class.validate_config(tenant_id, config)
def query(self, inputs: dict, query: Optional[str] = None) -> str:
"""
Query the external data tool.
:param inputs: user inputs
:param query: the query of chat app
:return: the tool query result
"""
return self.__extension_instance.query(inputs, query)

View File

@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})
try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
)
)
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e

View File

@@ -1,4 +1,5 @@
import logging
import random
import openai
@@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
if len(text_chunks) == 0:
return True
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
text_chunk = random.choice(text_chunks)
for result in moderation_result.results:
if result['flagged'] is True:
return False
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@@ -0,0 +1,858 @@
"""Wrapper around the Milvus vector database."""
from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
DEFAULT_MILVUS_CONNECTION = {
"host": "localhost",
"port": "19530",
"user": "",
"password": "",
"secure": False,
}
class Milvus(VectorStore):
"""Initialize wrapper around the milvus vector database.
In order to use this you need to have `pymilvus` installed and a
running Milvus
See the following documentation for how to run a Milvus instance:
https://milvus.io/docs/install_standalone-docker.md
If looking for a hosted Milvus, take a look at this documentation:
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
this project,
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
Args:
embedding_function (Embeddings): Function used to embed the text.
collection_name (str): Which Milvus collection to use. Defaults to
"LangChainCollection".
connection_args (Optional[dict[str, any]]): The connection args used for
this class comes in the form of a dict.
consistency_level (str): The consistency level to use for a collection.
Defaults to "Session".
index_params (Optional[dict]): Which index params to use. Defaults to
HNSW/AUTOINDEX depending on service.
search_params (Optional[dict]): Which search params to use. Defaults to
default of index.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
The connection args used for this class comes in the form of a dict,
here are a few of the options:
address (str): The actual address of Milvus
instance. Example address: "localhost:19530"
uri (str): The uri of Milvus instance. Example uri:
"http://randomwebsite:19530",
"tcp:foobarsite:19530",
"https://ok.s3.south.com:19530".
host (str): The host of Milvus instance. Default at "localhost",
PyMilvus will fill in the default host if only port is provided.
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
will fill in the default port if only host is provided.
user (str): Use which user to connect to Milvus instance. If user and
password are provided, we will add related header in every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
secure (bool): Default is false. If set to true, tls will be enabled.
client_key_path (str): If use tls two-way authentication, need to
write the client.key path.
client_pem_path (str): If use tls two-way authentication, need to
write the client.pem path.
ca_pem_path (str): If use tls two-way authentication, need to write
the ca.pem path.
server_pem_path (str): If use tls one-way authentication, need to
write the server.pem path.
server_name (str): If use tls, need to write the common name.
Example:
.. code-block:: python
from langchain import Milvus
from langchain.embeddings import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
# Connect to a milvus instance on localhost
milvus_store = Milvus(
embedding_function = Embeddings,
collection_name = "LangChainCollection",
drop_old = True,
)
Raises:
ValueError: If the pymilvus python package is not installed.
"""
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
):
"""Initialize the Milvus vector store."""
try:
from pymilvus import Collection, utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
# Default search params when one is not provided.
self.default_search_params = {
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
"AUTOINDEX": {"metric_type": "L2", "params": {}},
}
self.embedding_func = embedding_function
self.collection_name = collection_name
self.index_params = index_params
self.search_params = search_params
self.consistency_level = consistency_level
# In order for a collection to be compatible, pk needs to be auto'id and int
self._primary_field = "id"
# In order for compatibility, the text field will need to be called "text"
self._text_field = "page_content"
# In order for compatibility, the vector field needs to be called "vector"
self._vector_field = "vectors"
# In order for compatibility, the metadata field will need to be called "metadata"
self._metadata_field = "metadata"
self.fields: list[str] = []
# Create the connection to the server
if connection_args is None:
connection_args = DEFAULT_MILVUS_CONNECTION
self.alias = self._create_connection_alias(connection_args)
self.col: Optional[Collection] = None
# Grab the existing collection if it exists
if utility.has_collection(self.collection_name, using=self.alias):
self.col = Collection(
self.collection_name,
using=self.alias,
)
# If need to drop old, drop it
if drop_old and isinstance(self.col, Collection):
self.col.drop()
self.col = None
# Initialize the vector store
self._init()
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections
# Grab the connection arguments that are used for checking existing connection
host: str = connection_args.get("host", None)
port: Union[str, int] = connection_args.get("port", None)
address: str = connection_args.get("address", None)
uri: str = connection_args.get("uri", None)
user = connection_args.get("user", None)
# Order of use is host/port, uri, address
if host is not None and port is not None:
given_address = str(host) + ":" + str(port)
elif uri is not None:
given_address = uri.split("https://")[1]
elif address is not None:
given_address = address
else:
given_address = None
logger.debug("Missing standard address type for reuse atttempt")
# User defaults to empty string when getting connection info
if user is not None:
tmp_user = user
else:
tmp_user = ""
# If a valid address was given, then check if a connection exists
if given_address is not None:
for con in connections.list_connections():
addr = connections.get_connection_addr(con[0])
if (
con[1]
and ("address" in addr)
and (addr["address"] == given_address)
and ("user" in addr)
and (addr["user"] == tmp_user)
):
logger.debug("Using previous connection: %s", con[0])
return con[0]
# Generate a new connection if one doesn't exist
alias = uuid4().hex
try:
connections.connect(alias=alias, **connection_args)
logger.debug("Created new connection using: %s", alias)
return alias
except MilvusException as e:
logger.error("Failed to create new connection using: %s", alias)
raise e
def _init(
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
) -> None:
if embeddings is not None:
self._create_collection(embeddings, metadatas)
self._extract_fields()
self._create_index()
self._create_search_params()
self._load()
def _create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusException,
)
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
# Determine metadata schema
# if metadatas:
# # Create FieldSchema for each entry in metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# # Datatype isn't compatible
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
# logger.error(
# "Failure to create collection, unrecognized dtype for key: %s",
# key,
# )
# raise ValueError(f"Unrecognized datatype for {key}.")
# # Dataype is a string/varchar equivalent
# elif dtype == DataType.VARCHAR:
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
# else:
# fields.append(FieldSchema(key, dtype))
if metadatas:
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create the collection
try:
self.col = Collection(
name=self.collection_name,
schema=schema,
consistency_level=self.consistency_level,
using=self.alias,
)
except MilvusException as e:
logger.error(
"Failed to create collection: %s error: %s", self.collection_name, e
)
raise e
def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from pymilvus import Collection
if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
# Since primary field is auto-id, no need to track it
self.fields.remove(self._primary_field)
def _get_index(self) -> Optional[dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection
if isinstance(self.col, Collection):
for x in self.col.indexes:
if x.field_name == self._vector_field:
return x.to_dict()
return None
def _create_index(self) -> None:
"""Create a index on the collection"""
from pymilvus import Collection, MilvusException
if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default HNSW based one
if self.index_params is None:
self.index_params = {
"metric_type": "IP",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
try:
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
# If default did not work, most likely on Zilliz Cloud
except MilvusException:
# Use AUTOINDEX based index
self.index_params = {
"metric_type": "L2",
"index_type": "AUTOINDEX",
"params": {},
}
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
logger.debug(
"Successfully created an index on collection: %s",
self.collection_name,
)
except MilvusException as e:
logger.error(
"Failed to create an index on collection: %s", self.collection_name
)
raise e
def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
from pymilvus import Collection
if isinstance(self.col, Collection) and self.search_params is None:
index = self._get_index()
if index is not None:
index_type: str = index["index_param"]["index_type"]
metric_type: str = index["index_param"]["metric_type"]
self.search_params = self.default_search_params[index_type]
self.search_params["metric_type"] = metric_type
def _load(self) -> None:
"""Load the collection if available."""
from pymilvus import Collection
if isinstance(self.col, Collection) and self._get_index() is not None:
self.col.load()
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
timeout: Optional[int] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Insert text data into Milvus.
Inserting data when the collection has not be made yet will result
in creating a new Collection. The data of the first entity decides
the schema of the new collection, the dim is extracted from the first
embedding and the columns are decided by the first metadata dict.
Metada keys will need to be present for all inserted values. At
the moment there is no None equivalent in Milvus.
Args:
texts (Iterable[str]): The texts to embed, it is assumed
that they all fit in memory.
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
the texts. Defaults to None.
timeout (Optional[int]): Timeout for each batch insert. Defaults
to None.
batch_size (int, optional): Batch size to use for insertion.
Defaults to 1000.
Raises:
MilvusException: Failure to add texts
Returns:
List[str]: The resulting keys for each inserted element.
"""
from pymilvus import Collection, MilvusException
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
# If the collection hasn't been initialized yet, perform all steps to do so
if not isinstance(self.col, Collection):
self._init(embeddings, metadatas)
# Dict to hold all insert columns
insert_dict: dict[str, list] = {
self._text_field: texts,
self._vector_field: embeddings,
}
# Collect the metadata into the insert dict.
# if metadatas is not None:
# for d in metadatas:
# for key, value in d.items():
# if key in self.fields:
# insert_dict.setdefault(key, []).append(value)
if metadatas is not None:
for d in metadatas:
insert_dict.setdefault(self._metadata_field, []).append(d)
# Total insert count
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)
pks: list[str] = []
assert isinstance(self.col, Collection)
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
# Convert dict to list of lists batch for insertion
insert_list = [insert_dict[x][i:end] for x in self.fields]
# Insert into the collection.
try:
res: Collection
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def similarity_search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
Args:
query (str): The text to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
Args:
embedding (List[float]): The embedding vector to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
query (str): The text being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[float], List[Tuple[Document, any, any]]:
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return res
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ret = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
pair = (doc, result.score)
ret.append(pair)
return ret
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
query (str): The text being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
embedding = self.embedding_func.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding=embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
param=param,
expr=expr,
timeout=timeout,
**kwargs,
)
def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
embedding (str): The embedding vector being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=fetch_k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ids = []
documents = []
scores = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
documents.append(doc)
scores.append(result.score)
ids.append(result.id)
vectors = self.col.query(
expr=f"{self._primary_field} in {ids}",
output_fields=[self._primary_field, self._vector_field],
timeout=timeout,
)
# Reorganize the results from query to match search order.
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
ordered_result_embeddings = [vectors[x] for x in ids]
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
else:
ret.append(documents[x])
return ret
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: bool = False,
batch_size: int = 100,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Milvus:
"""Create a Milvus collection, indexes it with HNSW, and insert data.
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
"LangChainCollection".
connection_args (dict[str, Any], optional): Connection args to use. Defaults
to DEFAULT_MILVUS_CONNECTION.
consistency_level (str, optional): Which consistency level to use. Defaults
to "Session".
index_params (Optional[dict], optional): Which index_params to use. Defaults
to None.
search_params (Optional[dict], optional): Which search params to use.
Defaults to None.
drop_old (Optional[bool], optional): Whether to drop the collection with
that name if it exists. Defaults to False.
batch_size:
How many vectors upload per-request.
Default: 100
ids: Optional[Sequence[str]] = None,
Returns:
Milvus: Milvus Vector Store
"""
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
connection_args=connection_args,
consistency_level=consistency_level,
index_params=index_params,
search_params=search_params,
drop_old=drop_old,
**kwargs,
)
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
return vector_db

View File

@@ -9,30 +9,44 @@ from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class MilvusConfig(BaseModel):
endpoint: str
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config MILVUS_ENDPOINT is required")
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._client_config = config
def get_type(self) -> str:
return 'milvus'
@@ -49,7 +63,6 @@ class MilvusVectorIndex(BaseVectorIndex):
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
@@ -58,26 +71,29 @@ class MilvusVectorIndex(BaseVectorIndex):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
collection_name=self.get_index_name(self.dataset),
connection_args=self._client_config.to_milvus_params(),
index_params=index_params
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content'
)
return self
@@ -86,42 +102,53 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
embedding_function=self._embeddings,
connection_args=self._client_config.to_milvus_params()
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({
'filter': f' id in {ids}'
})
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return False
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))

View File

@@ -47,6 +47,20 @@ class VectorIndex:
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

View File

@@ -11,6 +11,7 @@ from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
@@ -79,6 +80,8 @@ class IndexingRunner:
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except ObjectDeletedError:
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
@@ -276,13 +279,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -372,13 +376,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -582,7 +587,6 @@ class IndexingRunner:
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
@@ -734,6 +738,9 @@ class IndexingRunner:
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
update_params = {
DatasetDocument.indexing_status: after_indexing_status
@@ -781,3 +788,7 @@ class IndexingRunner:
class DocumentIsPausedException(Exception):
pass
class DocumentIsDeletedPausedException(Exception):
pass

View File

@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -211,6 +211,9 @@ class ModelProviderFactory:
Provider.quota_type == ProviderQuotaType.TRIAL.value
).first()
if provider.quota_limit == 0:
return None
return provider
no_system_provider = True

View File

@@ -1,8 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
class XinferenceEmbedding(BaseEmbedding):

View File

@@ -1,6 +1,6 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -9,26 +9,31 @@ class LLMRunResult(BaseModel):
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
@@ -39,11 +44,21 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}
if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']
prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages

View File

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
@property
def base_model_name(self) -> str:

View File

@@ -37,12 +37,6 @@ class BaichuanModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.

View File

@@ -1,27 +1,18 @@
import json
import os
import re
import time
from abc import abstractmethod
from typing import List, Optional, Any, Union, Tuple
from typing import List, Optional, Any, Union
import decimal
import logging
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage, BaseMessage, ChatGeneration
from langchain.schema import LLMResult, BaseMessage, ChatGeneration
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.third_party.langchain.llms.fake import FakeLLM
import logging
from extensions.ext_database import db
logger = logging.getLogger(__name__)
@@ -157,8 +148,11 @@ class BaseLLM(BaseProviderModel):
except Exception as ex:
raise self.handle_exceptions(ex)
function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
if 'function_call' in result.generations[0][0].message.additional_kwargs:
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
else:
completion_content = result.generations[0][0].text
@@ -191,7 +185,8 @@ class BaseLLM(BaseProviderModel):
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
completion_tokens=completion_tokens,
function_call=function_call
)
@abstractmethod
@@ -227,7 +222,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -245,7 +240,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -260,7 +255,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
@@ -315,121 +310,8 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self):
return False
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self.prompt_file_name(mode))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
else:
return 'common_chat'
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content
query_prompt = prompt_rules['query_prompt'] if 'query_prompt' in prompt_rules else '{{query}}'
if memory and 'histories_prompt' in prompt_rules:
# append chat histories
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=prompt + query_prompt,
inputs={
'query': query
}
)
if self.model_rules.max_tokens.max:
curr_message_tokens = self.get_num_tokens(to_prompt_messages([tmp_human_message]))
max_tokens = self.model_kwargs.max_tokens
rest_tokens = self.model_rules.max_tokens.max - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
else:
rest_tokens = 2000
memory.human_prefix = prompt_rules['human_prefix'] if 'human_prefix' in prompt_rules else 'Human'
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += (pre_prompt_content + '\n') if pre_prompt_content else ''
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt += query_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return prompt, stops
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.realpath(__file__))))),
'prompt/generate_prompts')
json_file_path = os.path.join(prompt_path, f'{prompt_name}.json')
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> str:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str | List[BaseMessage]]:
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:
model_mode = self.model_mode
@@ -442,16 +324,7 @@ class BaseLLM(BaseProviderModel):
if len(messages) == 0:
return []
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
return to_lc_messages(messages)
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""

View File

@@ -66,15 +66,6 @@ class HuggingfaceHubModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return self._client.get_num_tokens(prompts)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
self.client.model_kwargs = provider_model_kwargs

View File

@@ -1,26 +1,23 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
return MinimaxChatLLM(
model=self.name,
model_kwargs={
'stream': False
},
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def get_currency(self):
return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex
@property
def support_streaming(self):
return True

View File

@@ -33,7 +33,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -106,7 +106,21 @@ class OpenAIModel(BaseLLM):
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""

View File

@@ -49,15 +49,6 @@ class OpenLLMModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
del provider_model_kwargs['max_tokens']
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
@@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
del provider_model_kwargs['max_tokens']
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)

View File

@@ -6,17 +6,16 @@ from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.wenxin import Wenxin
class WenxinModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
# TODO load price_config from configs(db)
return Wenxin(
model=self.name,
streaming=self.streaming,
@@ -38,7 +37,13 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {'stop': stop, 'callbacks': callbacks, 'messages': [prompts]}
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
@@ -48,7 +53,7 @@ class WenxinModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
@@ -58,3 +63,7 @@ class WenxinModel(BaseLLM):
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Wenxin: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -59,15 +59,6 @@ class XinferenceModel(BaseLLM):
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
def prompt_file_name(self, mode: str) -> str:
if 'baichuan' in self.name.lower():
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
else:
return super().prompt_file_name(mode)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
pass

View File

@@ -16,6 +16,7 @@ class ZhipuAIModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return ZhipuAIChatLLM(
model=self.name,
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,

View File

@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -167,7 +172,7 @@ class AnthropicProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.anthropic and \
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > 0:
hosted_model_providers.anthropic.quota_limit and hosted_model_providers.anthropic.quota_limit > -1:
return True
return False

View File

@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -314,7 +329,7 @@ class AzureOpenAIProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.azure_openai \
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > 0:
and hosted_model_providers.azure_openai.quota_limit and hosted_model_providers.azure_openai.quota_limit > -1:
return True
return False

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
@@ -21,6 +21,9 @@ class BaichuanProvider(BaseModelProvider):
Returns the name of a provider.
"""
return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
@@ -28,6 +31,7 @@ class BaichuanProvider(BaseModelProvider):
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:

View File

@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""

View File

@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -11,7 +11,7 @@ class HostedOpenAI(BaseModel):
api_organization: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the openai hosted model. 0 means unlimited."""
"""Quota limit for the openai hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1
@@ -21,14 +21,14 @@ class HostedAzureOpenAI(BaseModel):
api_base: str
api_key: str
quota_limit: int = 0
"""Quota limit for the azure openai hosted model. 0 means unlimited."""
"""Quota limit for the azure openai hosted model. -1 means unlimited."""
class HostedAnthropic(BaseModel):
api_base: str = None
api_key: str
quota_limit: int = 0
"""Quota limit for the anthropic hosted model. 0 means unlimited."""
"""Quota limit for the anthropic hosted model. -1 means unlimited."""
paid_enabled: bool = False
paid_stripe_price_id: str = None
paid_increase_quota: int = 1000000

View File

@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -2,14 +2,15 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import Minimax
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
from models.provider import ProviderType, ProviderQuotaType
@@ -28,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -44,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -98,14 +104,14 @@ class MinimaxProvider(BaseModelProvider):
'minimax_api_key': credentials['minimax_api_key'],
}
llm = Minimax(
llm = MinimaxChatLLM(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -132,7 +144,7 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -238,7 +250,7 @@ class OpenAIProvider(BaseModelProvider):
def should_deduct_quota(self):
if hosted_model_providers.openai \
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > 0:
and hosted_model_providers.openai.quota_limit and hosted_model_providers.openai.quota_limit > -1:
return True
return False

View File

@@ -3,7 +3,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -24,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -28,17 +28,27 @@ class SparkProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'spark',
'name': 'Spark V1.5',
'id': 'spark-v3',
'name': 'Spark V3.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -91,7 +101,7 @@ class SparkProvider(BaseModelProvider):
try:
chat_llm = ChatSpark(
model_name='spark-v2',
model_name='spark-v3',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -105,10 +115,10 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
# try spark v1.5 if v2.1 failed
# try spark v2.1 if v3.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
model_name='spark-v2',
max_tokens=10,
temperature=0.01,
**credential_kwargs
@@ -122,10 +132,27 @@ class SparkProvider(BaseModelProvider):
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
# try spark v1.5 if v2.1 failed
try:
chat_llm = ChatSpark(
model_name='spark',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
messages = [
HumanMessage(
content="ping"
)
]
chat_llm(messages)
except SparkError as ex:
raise CredentialsValidateFailedError(str(ex))
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex
except Exception as ex:
logging.exception('Spark config validation failed')
raise ex

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -24,17 +24,22 @@ class TongyiProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'qwen-v1',
'name': 'qwen-v1',
'id': 'qwen-turbo',
'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1',
'id': 'qwen-plus',
'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -58,16 +63,16 @@ class TongyiProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'qwen-v1': 1500,
'qwen-plus-v1': 6500
'qwen-turbo': 6000,
'qwen-plus': 6000
}
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
temperature=KwargRule[float](min=0.01, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.5, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
)
@classmethod
@@ -84,7 +89,7 @@ class TongyiProvider(BaseModelProvider):
}
llm = EnhanceTongyi(
model_name='qwen-v1',
model_name='qwen-turbo',
max_retries=1,
**credential_kwargs
)

View File

@@ -2,9 +2,11 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
@@ -23,22 +25,33 @@ class WenxinProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'ernie-bot-4',
'name': 'ERNIE-Bot-4',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.CHAT.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.CHAT.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -62,11 +75,12 @@ class WenxinProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'ernie-bot-4': 4800,
'ernie-bot': 4800,
'ernie-bot-turbo': 11200,
}
if model_name in ['ernie-bot', 'ernie-bot-turbo']:
if model_name in ['ernie-bot-4', 'ernie-bot', 'ernie-bot-turbo']:
return ModelKwargsRules(
temperature=KwargRule[float](min=0.01, max=1, default=0.95, precision=2),
top_p=KwargRule[float](min=0.01, max=1, default=0.8, precision=2),
@@ -105,7 +119,7 @@ class WenxinProvider(BaseModelProvider):
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -2,15 +2,15 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -26,21 +26,30 @@ class ZhipuAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'chatglm_turbo',
'name': 'chatglm_turbo',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +62,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -9,7 +9,7 @@
"trial"
],
"quota_unit": "tokens",
"quota_limit": 600000
"quota_limit": 0
},
"model_flexibility": "fixed",
"price_config": {

View File

@@ -22,6 +22,12 @@
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
},
"spark-v3": {
"prompt": "0.36",
"completion": "0.36",
"unit": "0.0001",
"currency": "RMB"
}
}
}

View File

@@ -3,5 +3,19 @@
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"qwen-turbo": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"qwen-plus": {
"prompt": "0.14",
"completion": "0.14",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -5,6 +5,12 @@
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"ernie-bot-4": {
"prompt": "0",
"completion": "0",
"unit": "0.001",
"currency": "RMB"
},
"ernie-bot": {
"prompt": "0.012",
"completion": "0.012",

View File

@@ -11,6 +11,12 @@
},
"model_flexibility": "fixed",
"price_config": {
"chatglm_turbo": {
"prompt": "0.005",
"completion": "0.005",
"unit": "0.001",
"currency": "RMB"
},
"chatglm_pro": {
"prompt": "0.01",
"completion": "0.01",

View File

View File

@@ -0,0 +1 @@
3

View File

View File

@@ -0,0 +1,88 @@
from pydantic import BaseModel
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
from core.extension.api_based_extension_requestor import APIBasedExtensionRequestor, APIBasedExtensionPoint
from core.helper.encrypter import decrypt_token
from extensions.ext_database import db
from models.api_based_extension import APIBasedExtension
class ModerationInputParams(BaseModel):
app_id: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputParams(BaseModel):
app_id: str = ""
text: str
class ApiModeration(Moderation):
name: str = "api"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, False)
api_based_extension_id = config.get("api_based_extension_id")
if not api_based_extension_id:
raise ValueError("api_based_extension_id is required")
extension = cls._get_api_based_extension(tenant_id, api_based_extension_id)
if not extension:
raise ValueError("API-based Extension not found. Please check it again.")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
params = ModerationInputParams(
app_id=self.app_id,
inputs=inputs,
query=query
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_INPUT, params.dict())
return ModerationInputsResult(**result)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
params = ModerationOutputParams(
app_id=self.app_id,
text=text
)
result = self._get_config_by_requestor(APIBasedExtensionPoint.APP_MODERATION_OUTPUT, params.dict())
return ModerationOutputsResult(**result)
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _get_config_by_requestor(self, extension_point: APIBasedExtensionPoint, params: dict) -> dict:
extension = self._get_api_based_extension(self.tenant_id, self.config.get("api_based_extension_id"))
requestor = APIBasedExtensionRequestor(extension.api_endpoint, decrypt_token(self.tenant_id, extension.api_key))
result = requestor.request(extension_point, params)
return result
@staticmethod
def _get_api_based_extension(tenant_id: str, api_based_extension_id: str) -> APIBasedExtension:
extension = db.session.query(APIBasedExtension).filter(
APIBasedExtension.tenant_id == tenant_id,
APIBasedExtension.id == api_based_extension_id
).first()
return extension

113
api/core/moderation/base.py Normal file
View File

@@ -0,0 +1,113 @@
from abc import ABC, abstractmethod
from typing import Optional
from pydantic import BaseModel
from enum import Enum
from core.extension.extensible import Extensible, ExtensionModule
class ModerationAction(Enum):
DIRECT_OUTPUT = 'direct_output'
OVERRIDED = 'overrided'
class ModerationInputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
inputs: dict = {}
query: str = ""
class ModerationOutputsResult(BaseModel):
flagged: bool = False
action: ModerationAction
preset_response: str = ""
text: str = ""
class Moderation(Extensible, ABC):
"""
The base class of moderation.
"""
module: ExtensionModule = ExtensionModule.MODERATION
def __init__(self, app_id: str, tenant_id: str, config: Optional[dict] = None) -> None:
super().__init__(tenant_id, config)
self.app_id = app_id
@classmethod
@abstractmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
raise NotImplementedError
@abstractmethod
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
raise NotImplementedError
@classmethod
def _validate_inputs_and_outputs_config(self, config: dict, is_preset_response_required: bool) -> None:
# inputs_config
inputs_config = config.get("inputs_config")
if not isinstance(inputs_config, dict):
raise ValueError("inputs_config must be a dict")
# outputs_config
outputs_config = config.get("outputs_config")
if not isinstance(outputs_config, dict):
raise ValueError("outputs_config must be a dict")
inputs_config_enabled = inputs_config.get("enabled")
outputs_config_enabled = outputs_config.get("enabled")
if not inputs_config_enabled and not outputs_config_enabled:
raise ValueError("At least one of inputs_config or outputs_config must be enabled")
# preset_response
if not is_preset_response_required:
return
if inputs_config_enabled:
if not inputs_config.get("preset_response"):
raise ValueError("inputs_config.preset_response is required")
if len(inputs_config.get("preset_response")) > 100:
raise ValueError("inputs_config.preset_response must be less than 100 characters")
if outputs_config_enabled:
if not outputs_config.get("preset_response"):
raise ValueError("outputs_config.preset_response is required")
if len(outputs_config.get("preset_response")) > 100:
raise ValueError("outputs_config.preset_response must be less than 100 characters")
class ModerationException(Exception):
pass

View File

@@ -0,0 +1,48 @@
from core.extension.extensible import ExtensionModule
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult
from extensions.ext_code_based_extension import code_based_extension
class ModerationFactory:
__extension_instance: Moderation
def __init__(self, name: str, app_id: str, tenant_id: str, config: dict) -> None:
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
self.__extension_instance = extension_class(app_id, tenant_id, config)
@classmethod
def validate_config(cls, name: str, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param name: the name of extension
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
code_based_extension.validate_form_schema(ExtensionModule.MODERATION, name, config)
extension_class = code_based_extension.extension_class(ExtensionModule.MODERATION, name)
extension_class.validate_config(tenant_id, config)
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
"""
Moderation for inputs.
After the user inputs, this method will be called to perform sensitive content review
on the user inputs and return the processed results.
:param inputs: user inputs
:param query: query string (required in chat app)
:return:
"""
return self.__extension_instance.moderation_for_inputs(inputs, query)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
"""
Moderation for outputs.
When LLM outputs content, the front end will pass the output content (may be segmented)
to this method for sensitive content review, and the output content will be shielded if the review fails.
:param text: LLM output content
:return:
"""
return self.__extension_instance.moderation_for_outputs(text)

View File

@@ -0,0 +1 @@
2

View File

View File

@@ -0,0 +1,60 @@
from core.moderation.base import Moderation, ModerationInputsResult, ModerationOutputsResult, ModerationAction
class KeywordsModeration(Moderation):
name: str = "keywords"
@classmethod
def validate_config(cls, tenant_id: str, config: dict) -> None:
"""
Validate the incoming form config data.
:param tenant_id: the id of workspace
:param config: the form config data
:return:
"""
cls._validate_inputs_and_outputs_config(config, True)
if not config.get("keywords"):
raise ValueError("keywords is required")
if len(config.get("keywords")) > 1000:
raise ValueError("keywords length must be less than 1000")
def moderation_for_inputs(self, inputs: dict, query: str = "") -> ModerationInputsResult:
flagged = False
preset_response = ""
if self.config['inputs_config']['enabled']:
preset_response = self.config['inputs_config']['preset_response']
if query:
inputs['query__'] = query
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated(inputs, keywords_list)
return ModerationInputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def moderation_for_outputs(self, text: str) -> ModerationOutputsResult:
flagged = False
preset_response = ""
if self.config['outputs_config']['enabled']:
keywords_list = self.config['keywords'].split('\n')
flagged = self._is_violated({'text': text}, keywords_list)
preset_response = self.config['outputs_config']['preset_response']
return ModerationOutputsResult(flagged=flagged, action=ModerationAction.DIRECT_OUTPUT, preset_response=preset_response)
def _is_violated(self, inputs: dict, keywords_list: list) -> bool:
for value in inputs.values():
if self._check_keywords_in_value(keywords_list, value):
return True
return False
def _check_keywords_in_value(self, keywords_list, value):
for keyword in keywords_list:
if keyword.lower() in value.lower():
return True
return False

Some files were not shown because too many files have changed in this diff Show More