Compare commits

...

47 Commits

Author SHA1 Message Date
takatost
2feb16d957 feat: bump version to 0.3.28 (#1349) 2023-10-14 11:49:56 -05:00
crazywoola
3043fbe73b remove the suggested api for completion app (#1347) 2023-10-14 10:05:33 -05:00
Hickays
9f99c3f55b fix: modal z-index (#1343) 2023-10-13 05:55:03 -05:00
Joel
a07a6d8c26 feat: switch to generation model set default stop word (#1341) 2023-10-13 16:47:22 +08:00
Garfield Dai
695841a3cf Feat/advanced prompt enhancement (#1340) 2023-10-13 16:47:01 +08:00
takatost
3efaa713da feat: use xinference client instead of xinference (#1339) 2023-10-13 02:46:09 -05:00
takatost
9822f687f7 fix: max tokens of OpenAI gpt-3.5-turbo-instruct to 4097 (#1338) 2023-10-13 02:07:07 -05:00
crazywoola
b9d83c04bc fix: modal z-index (#1337) 2023-10-13 14:58:53 +08:00
Charlie.Wei
298ad6782d Add Message Suggested Api (#1326)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-13 14:07:32 +08:00
takatost
f4be2b8bcd fix: raise error in minimax stream generate (#1336) 2023-10-12 23:48:28 -05:00
crazywoola
e83e239faf fix: value.join is not a function in log list (#1332) 2023-10-13 11:34:24 +08:00
taokuizu
62bf7f0fc2 fix: new app with template display (#1322) 2023-10-13 10:18:33 +08:00
takatost
7dea485d57 feat: bump version to 0.3.27 (#1331) 2023-10-12 10:37:48 -05:00
zxhlyh
5b9858a8a3 feat: advanced prompt (#1330)
Co-authored-by: Joel <iamjoel007@gmail.com>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: Gillian97 <jinling.sunshine@gmail.com>
2023-10-12 23:14:28 +08:00
Garfield Dai
42a5b3ec17 feat: advanced prompt backend (#1301)
Co-authored-by: takatost <takatost@gmail.com>
2023-10-12 10:13:10 -05:00
takatost
2d1cb076c6 fix: dataset segment not exist return agent response (#1329) 2023-10-12 04:40:20 -05:00
Jyong
289c93d081 Feat/improve document delete logic (#1325)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-12 13:30:44 +08:00
takatost
c0fe706597 feat: adjust to only build the latest image when pushing a tag. (#1324) 2023-10-11 23:38:07 -05:00
takatost
9cba1c8bf4 fix: retriever_resource missing (#1317) 2023-10-11 14:37:11 -05:00
takatost
cbf095465c feat: remove llm client use (#1316) 2023-10-11 14:02:53 -05:00
KVOJJJin
c007dbdc13 Feat: add document of authorization (#1311) 2023-10-11 08:03:36 -05:00
takatost
ff493d017b fix: minimax tests (#1313) 2023-10-11 07:49:26 -05:00
Jyong
7f6ad9653e Fix/grpc gevent compatible (#1314)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 20:48:35 +08:00
takatost
2851a9f04e feat: optimize minimax llm call (#1312) 2023-10-11 07:17:41 -05:00
takatost
c536f85b2e fix: compatibility issues with the tongyi model. (#1310) 2023-10-11 05:16:26 -05:00
takatost
b1352ff8b7 feat: using random sampling to check if it violates the review mechan… (#1308) 2023-10-11 04:11:20 -05:00
Jyong
cc63c8499f bump version to 0.3.26 (#1307)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:11:24 +08:00
Jyong
f191b8b8d1 milvus docker compose env (#1306)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 16:05:37 +08:00
Jyong
5003db987d milvus secure check fix (#1305)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-11 13:11:06 +08:00
Jyong
07aab5e868 Feat/add milvus vector db (#1302)
Co-authored-by: jyong <jyong@dify.ai>
2023-10-10 21:56:24 +08:00
takatost
875dfbbf0e fix: openllm completion start with prompt, remove it (#1303) 2023-10-10 04:44:19 -05:00
Charlie.Wei
9e7efa45d4 document segmentApi Add get&update&delete operate (#1285)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
2023-10-10 13:27:06 +08:00
takatost
8bf892b306 feat: bump version to 0.3.25 (#1300) 2023-10-10 13:03:49 +08:00
takatost
8480b0197b fix: prompt for baichuan text generation models (#1299) 2023-10-10 13:01:18 +08:00
zxhlyh
df07fb5951 feat: provider add baichuan (#1298) 2023-10-09 23:10:43 -05:00
takatost
4ab4bcc074 feat: support openllm embedding (#1293) 2023-10-09 23:09:35 -05:00
takatost
1d4f019de4 feat: add baichuan llm support (#1294)
Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
2023-10-09 23:09:26 -05:00
takatost
677aacc8e3 feat: upgrade xinference client to 0.5.2 (#1292) 2023-10-09 08:12:58 -05:00
takatost
fda937175d feat: qdrant support in docker compose (#1286) 2023-10-08 12:04:04 -05:00
takatost
024250803a feat: move login_required wrapper outside (#1281) 2023-10-08 05:21:32 -05:00
Charlie.Wei
b711ce33b7 Application share qrcode (#1277)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-10-08 09:34:49 +08:00
Rhon Joe
52bec63275 chore(web): strong type (#1259) 2023-10-07 04:42:16 -05:00
qiuqiua
657fa80f4d fix devcontainer issue (#1273) 2023-10-07 10:34:25 +08:00
takatost
373e90ee6d fix: detached model in completion thread (#1269) 2023-10-02 22:27:25 +08:00
takatost
41d4c5b424 fix: count down thread in completion db not commit (#1267) 2023-10-02 10:19:26 +08:00
takatost
86a9dea428 fix: db not commit when streaming output (#1266) 2023-10-01 16:41:52 +08:00
takatost
8606d80c66 fix: request timeout when openai completion (#1265) 2023-10-01 16:00:23 +08:00
324 changed files with 11618 additions and 2115 deletions

View File

@@ -1,11 +1,8 @@
FROM mcr.microsoft.com/devcontainers/anaconda:0-3
FROM mcr.microsoft.com/devcontainers/python:3.10
COPY . .
# Copy environment.yml (if found) to a temp location so we update the environment. Also
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
&& rm -rf /tmp/conda-tmp
# [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here>
# && apt-get -y install --no-install-recommends <your-package-list-here>

View File

@@ -1,13 +1,12 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{
"name": "Anaconda (Python 3)",
"name": "Python 3.10",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
"features": {
"ghcr.io/dhoeric/features/act:1": {},
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
"version": "lts"

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-api
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -31,7 +31,7 @@ jobs:
with:
images: langgenius/dify-web
tags: |
type=raw,value=latest,enable={{is_default_branch}}
type=raw,value=latest,enable=${{ startsWith(github.ref, 'refs/tags/') }}
type=ref,event=branch
type=sha,enable=true,priority=100,prefix=,suffix=,format=long
type=semver,pattern={{major}}.{{minor}}.{{patch}}

View File

@@ -1,37 +0,0 @@
import os
import re
from zhon.hanzi import punctuation
def has_chinese_characters(text):
for char in text:
if '\u4e00' <= char <= '\u9fff' or char in punctuation:
return True
return False
def check_file_for_chinese_comments(file_path):
with open(file_path, 'r', encoding='utf-8') as file:
for line_number, line in enumerate(file, start=1):
if has_chinese_characters(line):
print(f"Found Chinese characters in {file_path} on line {line_number}:")
print(line.strip())
return True
return False
def main():
has_chinese = False
excluded_files = ["model_template.py", 'stopwords.py', 'commands.py',
'indexing_runner.py', 'web_reader_tool.py', 'spark_provider.py',
'prompts.py']
for root, _, files in os.walk("."):
for file in files:
if file.endswith(".py") and file not in excluded_files:
file_path = os.path.join(root, file)
if check_file_for_chinese_comments(file_path):
has_chinese = True
if has_chinese:
raise Exception("Found Chinese characters in Python files. Please remove them.")
if __name__ == "__main__":
main()

View File

@@ -1,31 +0,0 @@
name: Check for Chinese comments
on:
push:
branches:
- 'main'
pull_request:
branches:
- main
jobs:
check-chinese-comments:
runs-on: ubuntu-latest
steps:
- name: Check out repository
uses: actions/checkout@v2
- name: Set up Python
uses: actions/setup-python@v2
with:
python-version: 3.9
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install zhon
- name: Run script to check for Chinese comments
run: |
python .github/workflows/check_no_chinese_comments.py

.gitignore
View File

@@ -144,6 +144,7 @@ docker/volumes/app/storage/*
docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
sdks/python-client/build
sdks/python-client/dist

View File

@@ -50,7 +50,7 @@ S3_REGION=your-region
WEB_API_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://127.0.0.1:3000,*
# Vector database configuration, support: weaviate, qdrant
# Vector database configuration, support: weaviate, qdrant, milvus
VECTOR_STORE=weaviate
# Weaviate configuration
@@ -59,9 +59,16 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=path:storage/qdrant
QDRANT_API_KEY=your-qdrant-api-key
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
# Milvus configuration
MILVUS_HOST=127.0.0.1
MILVUS_PORT=19530
MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# Mail configuration, support: resend
MAIL_TYPE=
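
The five MILVUS_* keys above are new in this release (and only take effect when VECTOR_STORE is set to milvus, per the comment above). A minimal connectivity sketch using them — assumes the pymilvus client; the values are just the defaults shown above, nothing Dify itself runs:

# Hypothetical smoke test for the new Milvus settings; not part of this diff.
import os
from pymilvus import connections, utility

connections.connect(
    alias="default",
    host=os.getenv("MILVUS_HOST", "127.0.0.1"),
    port=os.getenv("MILVUS_PORT", "19530"),
    user=os.getenv("MILVUS_USER", "root"),
    password=os.getenv("MILVUS_PASSWORD", "Milvus"),
    secure=os.getenv("MILVUS_SECURE", "false").lower() == "true",
)
print(utility.get_server_version())  # prints the server version if reachable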

View File

@@ -6,6 +6,9 @@ from werkzeug.exceptions import Unauthorized
if not os.environ.get("DEBUG") or os.environ.get("DEBUG").lower() != 'true':
from gevent import monkey
monkey.patch_all()
if os.environ.get("VECTOR_STORE") == 'milvus':
import grpc.experimental.gevent
grpc.experimental.gevent.init_gevent()
import logging
import json

View File

@@ -92,7 +92,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.24"
self.CURRENT_VERSION = "0.3.28"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -135,6 +135,14 @@ class Config:
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)

View File

@@ -31,6 +31,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -81,6 +82,7 @@ model_templates = {
'model': json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 512,
"temperature": 1,
@@ -137,10 +139,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="Please translate the following text into {{target_language}}:\n",
pre_prompt="Please translate the following text into {{target_language}}:\n{{query}}\ntranslate:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -169,6 +172,13 @@ demo_model_templates = {
'Italian',
]
}
},{
"paragraph": {
"label": "Query",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -200,6 +210,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,
@@ -255,10 +266,11 @@ demo_model_templates = {
},
opening_statement='',
suggested_questions=None,
pre_prompt="请将以下文本翻译为{{target_language}}:\n",
pre_prompt="请将以下文本翻译为{{target_language}}:\n{{query}}\n翻译:",
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo-instruct",
"mode": "completion",
"completion_params": {
"max_tokens": 1000,
"temperature": 0,
@@ -287,6 +299,13 @@ demo_model_templates = {
"意大利语",
]
}
},{
"paragraph": {
"label": "文本内容",
"variable": "query",
"required": True,
"default": ""
}
}
])
)
@@ -318,6 +337,7 @@ demo_model_templates = {
model=json.dumps({
"provider": "openai",
"name": "gpt-3.5-turbo",
"mode": "chat",
"completion_params": {
"max_tokens": 300,
"temperature": 0.8,

View File

@@ -9,7 +9,7 @@ api = ExternalApi(bp)
from . import setup, version, apikey, admin
# Import app controllers
from .app import app, site, completion, model_config, statistic, conversation, message, generator, audio
from .app import advanced_prompt_template, app, site, completion, model_config, statistic, conversation, message, generator, audio
# Import auth controllers
from .auth import login, oauth, data_source_oauth, activate

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
import flask_restful
from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden

View File

@@ -0,0 +1,25 @@
from flask_restful import Resource, reqparse
from controllers.console import api
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from libs.login import login_required
from services.advanced_prompt_template_service import AdvancedPromptTemplateService
class AdvancedPromptTemplateList(Resource):
@setup_required
@login_required
@account_initialization_required
def get(self):
parser = reqparse.RequestParser()
parser.add_argument('app_mode', type=str, required=True, location='args')
parser.add_argument('model_mode', type=str, required=True, location='args')
parser.add_argument('has_context', type=str, required=False, default='true', location='args')
parser.add_argument('model_name', type=str, required=True, location='args')
args = parser.parse_args()
return AdvancedPromptTemplateService.get_prompt(args)
api.add_resource(AdvancedPromptTemplateList, '/app/prompt-templates')
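
The new /app/prompt-templates route above is a thin wrapper over AdvancedPromptTemplateService.get_prompt. A hedged sketch of the equivalent direct call — argument names come from the parser above; the values are illustrative, not the only valid ones:

# Illustrative only: exercises the same service the console endpoint wraps.
from services.advanced_prompt_template_service import AdvancedPromptTemplateService

args = {
    "app_mode": "chat",                      # required
    "model_mode": "completion",              # required
    "model_name": "gpt-3.5-turbo-instruct",  # required
    "has_context": "true",                   # optional, defaults to 'true'
}
prompt_template = AdvancedPromptTemplateService.get_prompt(args)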

View File

@@ -3,10 +3,9 @@ import json
import logging
from datetime import datetime
import flask
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
from constants.model_template import model_templates, demo_model_templates
@@ -17,11 +16,9 @@ from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService

View File

@@ -2,8 +2,8 @@
import logging
from flask import request
from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
from libs.login import login_required
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
from core.login.login import login_required
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -2,8 +2,8 @@ from datetime import datetime
import pytz
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with
from flask_restful.inputs import int_range
from sqlalchemy import or_, func
from sqlalchemy.orm import joinedload
@@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
conversation_message_detail_fields, conversation_with_summary_pagination_fields
from libs.helper import TimestampField, datetime_string, uuid_value
from libs.helper import datetime_string
from extensions.ext_database import db
from models.model import Message, MessageAnnotation, Conversation

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api
@@ -12,35 +12,6 @@ from core.model_providers.error import ProviderTokenNotInitError, QuotaExceededE
LLMAPIUnavailableError, LLMRateLimitError, LLMAuthorizationError, ModelCurrentlyNotSupportError
class IntroductionGenerateApi(Resource):
@setup_required
@login_required
@account_initialization_required
def post(self):
parser = reqparse.RequestParser()
parser.add_argument('prompt_template', type=str, required=True, location='json')
args = parser.parse_args()
account = current_user
try:
answer = LLMGenerator.generate_introduction(
account.current_tenant_id,
args['prompt_template']
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
raise ProviderQuotaExceededError()
except ModelCurrentlyNotSupportError:
raise ProviderModelCurrentlyNotSupportError()
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, LLMAuthorizationError) as e:
raise CompletionRequestError(str(e))
return {'introduction': answer}
class RuleGenerateApi(Resource):
@setup_required
@login_required
@@ -72,5 +43,4 @@ class RuleGenerateApi(Resource):
return rules
api.add_resource(IntroductionGenerateApi, '/introduction-generate')
api.add_resource(RuleGenerateApi, '/rule-generate')

View File

@@ -16,9 +16,9 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value, TimestampField
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
@@ -295,8 +295,8 @@ class MessageSuggestedQuestionApi(Resource):
try:
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=current_user,
message_id=message_id,
user=current_user,
check_enabled=False
)
except MessageNotExistsError:
@@ -329,7 +329,7 @@ class MessageApi(Resource):
message_id = str(message_id)
# get app info
app_model = _get_app(app_id, 'chat')
app_model = _get_app(app_id)
message = db.session.query(Message).filter(
Message.id == message_id,

View File

@@ -1,5 +1,4 @@
# -*- coding:utf-8 -*-
import json
from flask import request
from flask_restful import Resource
@@ -9,7 +8,7 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.login.login import login_required
from libs.login import login_required
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.model import AppModelConfig

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
from controllers.console import api

View File

@@ -5,7 +5,7 @@ from datetime import datetime
import pytz
from flask import jsonify
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,16 +1,13 @@
import logging
from datetime import datetime
from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
from flask import request, redirect, current_app
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from core.login.login import login_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from controllers.console import api
from ..setup import setup_required

View File

@@ -2,10 +2,10 @@ import datetime
import json
from cachetools import TTLCache
from flask import request, current_app
from flask import request
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from libs.login import login_required
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -15,7 +15,6 @@ from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from libs.helper import TimestampField
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService

View File

@@ -4,8 +4,8 @@ from flask import request, current_app
from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api

View File

@@ -1,11 +1,10 @@
# -*- coding:utf-8 -*-
import random
from datetime import datetime
from typing import List
from flask import request, current_app
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
@@ -25,7 +24,6 @@ from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
from models.dataset import Document, DocumentSegment

View File

@@ -14,13 +14,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task

View File

@@ -2,8 +2,8 @@ from cachetools import TTLCache
from flask import request, current_app
import services
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields
from libs.login import login_required
from flask_restful import Resource, marshal_with
from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \

View File

@@ -1,7 +1,7 @@
import logging
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

View File

@@ -2,8 +2,8 @@
from datetime import datetime
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, inputs
from sqlalchemy import and_
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
@@ -12,7 +12,6 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.helper import TimestampField
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource
from functools import wraps

View File

@@ -2,7 +2,7 @@ import json
from functools import wraps
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required

View File

@@ -4,7 +4,7 @@ from datetime import datetime
import pytz
from flask import current_app, request
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
import services

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,7 +1,7 @@
import json
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -3,9 +3,8 @@ import logging
from flask import request
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.admin import admin_required

View File

@@ -54,6 +54,7 @@ class ConversationDetailApi(AppApiResource):
raise NotFound("Conversation Not Exists.")
return {"result": "success"}, 204
class ConversationRenameApi(AppApiResource):
@marshal_with(simple_conversation_fields)

View File

@@ -10,6 +10,8 @@ from controllers.service_api.app.error import NotChatAppError
from controllers.service_api.wraps import AppApiResource
from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Account, Message
class MessageListApi(AppApiResource):
@@ -96,5 +98,36 @@ class MessageFeedbackApi(AppApiResource):
return {'result': 'success'}
class MessageSuggestedApi(AppApiResource):
def get(self, app_model, end_user, message_id):
message_id = str(message_id)
if app_model.mode != 'chat':
raise NotChatAppError()
try:
message = db.session.query(Message).filter(
Message.id == message_id,
Message.app_id == app_model.id,
).first()
if end_user is None and message.from_account_id is not None:
user = db.session.get(Account, message.from_account_id)
elif end_user is None and message.from_end_user_id is not None:
user = create_or_update_end_user_for_user_id(app_model, message.from_end_user_id)
else:
user = end_user
questions = MessageService.get_suggested_questions_after_answer(
app_model=app_model,
user=user,
message_id=message_id
)
except services.errors.message.MessageNotExistsError:
raise NotFound("Message Not Exists.")
return {'result': 'success', 'data': questions}
api.add_resource(MessageListApi, '/messages')
api.add_resource(MessageFeedbackApi, '/messages/<uuid:message_id>/feedbacks')
api.add_resource(MessageSuggestedApi, '/messages/<uuid:message_id>/suggested')
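
A hedged usage sketch for the new suggested-questions route, assuming the service API's /v1 prefix and the Bearer app-key auth used by the neighbouring message endpoints; the base URL, key, and message id are placeholders:

# Hypothetical client call; endpoint path taken from the add_resource line above.
import requests

API_BASE = "https://api.dify.ai/v1"             # placeholder base URL
headers = {"Authorization": "Bearer app-xxxx"}  # placeholder app API key
message_id = "your-message-id"                  # placeholder

resp = requests.get(f"{API_BASE}/messages/{message_id}/suggested", headers=headers)
resp.raise_for_status()
print(resp.json())  # expected shape per the handler: {'result': 'success', 'data': [...]}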

View File

@@ -4,12 +4,9 @@ import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from libs.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from models.account import Account, TenantAccountJoin
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.provider_service import ProviderService

View File

@@ -1,8 +1,6 @@
import datetime
import json
import uuid
from flask import current_app, request
from flask import request
from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
@@ -13,13 +11,11 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.file_service import FileService

View File

@@ -1,7 +1,6 @@
from flask_login import current_user
from flask_restful import reqparse, marshal
from werkzeug.exceptions import NotFound
from controllers.service_api import api
from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.wraps import DatasetApiResource
@@ -9,8 +8,8 @@ from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestE
from core.model_providers.model_factory import ModelFactory
from extensions.ext_database import db
from fields.segment_fields import segment_fields
from models.dataset import Dataset
from services.dataset_service import DocumentService, SegmentService
from models.dataset import Dataset, DocumentSegment
from services.dataset_service import DatasetService, DocumentService, SegmentService
class SegmentApi(DatasetApiResource):
@@ -24,6 +23,8 @@ class SegmentApi(DatasetApiResource):
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
@@ -55,5 +56,146 @@ class SegmentApi(DatasetApiResource):
'doc_form': document.doc_form
}, 200
def get(self, tenant_id, dataset_id, document_id):
"""Create single segment."""
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset.id, document_id)
if not document:
raise NotFound('Document not found.')
# check embedding model setting
if dataset.indexing_technique == 'high_quality':
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
parser = reqparse.RequestParser()
parser.add_argument('status', type=str,
action='append', default=[], location='args')
parser.add_argument('keyword', type=str, default=None, location='args')
args = parser.parse_args()
status_list = args['status']
keyword = args['keyword']
query = DocumentSegment.query.filter(
DocumentSegment.document_id == str(document_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
)
if status_list:
query = query.filter(DocumentSegment.status.in_(status_list))
if keyword:
query = query.where(DocumentSegment.content.ilike(f'%{keyword}%'))
total = query.count()
segments = query.order_by(DocumentSegment.position).all()
return {
'data': marshal(segments, segment_fields),
'doc_form': document.doc_form,
'total': total
}, 200
class DatasetSegmentApi(DatasetApiResource):
def delete(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
# check segment
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
SegmentService.delete_segment(segment, document, dataset)
return {'result': 'success'}, 200
def post(self, tenant_id, dataset_id, document_id, segment_id):
# check dataset
dataset_id = str(dataset_id)
tenant_id = str(tenant_id)
dataset = db.session.query(Dataset).filter(
Dataset.tenant_id == tenant_id,
Dataset.id == dataset_id
).first()
if not dataset:
raise NotFound('Dataset not found.')
# check user's model setting
DatasetService.check_dataset_model_setting(dataset)
# check document
document_id = str(document_id)
document = DocumentService.get_document(dataset_id, document_id)
if not document:
raise NotFound('Document not found.')
if dataset.indexing_technique == 'high_quality':
# check embedding model setting
try:
ModelFactory.get_embedding_model(
tenant_id=current_user.current_tenant_id,
model_provider_name=dataset.embedding_model_provider,
model_name=dataset.embedding_model
)
except LLMBadRequestError:
raise ProviderNotInitializeError(
f"No Embedding Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
# check segment
segment_id = str(segment_id)
segment = DocumentSegment.query.filter(
DocumentSegment.id == str(segment_id),
DocumentSegment.tenant_id == current_user.current_tenant_id
).first()
if not segment:
raise NotFound('Segment not found.')
# validate args
parser = reqparse.RequestParser()
parser.add_argument('segments', type=dict, required=False, nullable=True, location='json')
args = parser.parse_args()
SegmentService.segment_create_args_validate(args['segments'], document)
segment = SegmentService.update_segment(args['segments'], segment, document, dataset)
return {
'data': marshal(segment, segment_fields),
'doc_form': document.doc_form
}, 200
api.add_resource(SegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments')
api.add_resource(DatasetSegmentApi, '/datasets/<uuid:dataset_id>/documents/<uuid:document_id>/segments/<uuid:segment_id>')
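
A hedged sketch of the three new segment operations registered above, assuming the dataset API's /v1 prefix and Bearer dataset-key auth; every id, the base URL, and the update body fields are placeholders, not confirmed by this diff:

# Hypothetical client calls; routes taken from the two add_resource lines above.
import requests

API_BASE = "https://api.dify.ai/v1"                 # placeholder base URL
headers = {"Authorization": "Bearer dataset-xxxx"}  # placeholder dataset API key
dataset_id, document_id, segment_id = "ds-id", "doc-id", "seg-id"  # placeholders
base = f"{API_BASE}/datasets/{dataset_id}/documents/{document_id}/segments"

# GET: list segments, optionally filtered via the 'status' / 'keyword' args parsed above
segments = requests.get(base, headers=headers, params={"keyword": "foo"}).json()

# POST: update one segment; the handler expects a 'segments' object in the JSON body
requests.post(f"{base}/{segment_id}", headers=headers,
              json={"segments": {"content": "updated text"}})

# DELETE: remove one segment
requests.delete(f"{base}/{segment_id}", headers=headers)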

View File

@@ -7,10 +7,9 @@ from flask_login import user_logged_in
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from core.login.login import _get_user
from libs.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.dataset import Dataset
from models.model import ApiToken, App

View File

@@ -115,7 +115,7 @@ class MessageMoreLikeThisApi(WebApiResource):
streaming = args['response_mode'] == 'streaming'
try:
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming)
response = CompletionService.generate_more_like_this(app_model, end_user, message_id, streaming, 'web_app')
return compact_response(response)
except MessageNotExistsError:
raise NotFound("Message Not Exists.")

View File

@@ -2,14 +2,18 @@ import json
from typing import Tuple, List, Any, Union, Sequence, Optional, cast
from langchain.agents import OpenAIFunctionsAgent, BaseSingleActionAgent
from langchain.agents.openai_functions_agent.base import _format_intermediate_steps, _parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema import AgentAction, AgentFinish, SystemMessage, Generation, LLMResult, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from pydantic import root_validator
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -24,6 +28,10 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
def should_use_agent(self, query: str):
"""
return should use agent
@@ -65,17 +73,57 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
return AgentFinish(return_values={"output": observation}, log=observation)
try:
agent_decision = super().plan(intermediate_steps, callbacks, **kwargs)
agent_decision = self.real_plan(intermediate_steps, callbacks, **kwargs)
if isinstance(agent_decision, AgentAction):
tool_inputs = agent_decision.tool_input
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
def real_plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
return agent_decision
async def aplan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
@@ -87,7 +135,7 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -96,11 +144,15 @@ class MultiDatasetRouterAgent(OpenAIFunctionsAgent):
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
**kwargs,
)

View File

@@ -5,21 +5,40 @@ from langchain.agents.openai_functions_agent.base import _parse_ai_message, \
_format_intermediate_steps
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.chat_models.openai import _convert_message_to_dict, _import_tiktoken
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema import AgentAction, AgentFinish, SystemMessage, AIMessage, HumanMessage, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from pydantic import root_validator
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.chain.llm_chain import LLMChain
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_model_instance: BaseLLM = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
@root_validator
def validate_llm(cls, values: dict) -> dict:
return values
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
@@ -28,12 +47,16 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
),
**kwargs: Any,
) -> BaseSingleActionAgent:
return super().from_llm_and_tools(
llm=llm,
prompt = cls.create_prompt(
extra_prompt_messages=extra_prompt_messages,
system_message=system_message,
)
return cls(
model_instance=model_instance,
llm=FakeLLM(response=''),
prompt=prompt,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
@@ -44,23 +67,26 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 40
original_max_tokens = self.model_instance.model_kwargs.max_tokens
self.model_instance.model_kwargs.max_tokens = 40
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
function_call = result.function_call
self.llm.max_tokens = original_max_tokens
self.model_instance.model_kwargs.max_tokens = original_max_tokens
return True if function_call else False
@@ -93,10 +119,19 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
functions=self.functions,
)
agent_decision = _parse_ai_message(predicted_message)
ai_message = AIMessage(
content=result.content,
additional_kwargs={
'function_call': result.function_call
}
)
agent_decision = _parse_ai_message(ai_message)
if isinstance(agent_decision, AgentAction) and agent_decision.tool == 'dataset':
tool_inputs = agent_decision.tool_input
@@ -122,3 +157,142 @@ class AutoSummarizingOpenAIFunctionCallAgent(OpenAIFunctionsAgent, OpenAIFunctio
return super().return_stopped_response(early_stopping_method, intermediate_steps, **kwargs)
except ValueError:
return AgentFinish({"output": "I'm sorry, I don't know how to respond to that."}, "")
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
if model_instance.model_provider.provider_name == 'azure_openai':
model = model_instance.base_model_name
model = model.replace("gpt-35", "gpt-3.5")
else:
model = model_instance.base_model_name
tiktoken_ = _import_tiktoken()
try:
encoding = tiktoken_.encoding_for_model(model)
except KeyError:
model = "cl100k_base"
encoding = tiktoken_.get_encoding(model)
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens

View File

@@ -1,140 +0,0 @@
from typing import cast, List
from langchain.chat_models import ChatOpenAI
from langchain.chat_models.openai import _convert_message_to_dict
from langchain.memory.summary import SummarizerMixin
from langchain.schema import SystemMessage, HumanMessage, BaseMessage, AIMessage
from langchain.schema.language_model import BaseLanguageModel
from pydantic import BaseModel
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError, CalcTokenMixin
from core.model_providers.models.llm.base import BaseLLM
class OpenAIFunctionCallSummarizeMixin(BaseModel, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
class Config:
"""Configuration for this pydantic object."""
arbitrary_types_allowed = True
def summarize_messages_if_needed(self, messages: List[BaseMessage], **kwargs) -> List[BaseMessage]:
# calculate rest tokens and summarize previous function observation messages if rest_tokens < 0
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages, **kwargs)
rest_tokens = rest_tokens - 20 # to deal with the inaccuracy of rest_tokens
if rest_tokens >= 0:
return messages
system_message = None
human_message = None
should_summary_messages = []
for message in messages:
if isinstance(message, SystemMessage):
system_message = message
elif isinstance(message, HumanMessage):
human_message = message
else:
should_summary_messages.append(message)
if len(should_summary_messages) > 2:
ai_message = should_summary_messages[-2]
function_message = should_summary_messages[-1]
should_summary_messages = should_summary_messages[self.moving_summary_index:-2]
self.moving_summary_index = len(should_summary_messages)
else:
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
new_messages = [system_message, human_message]
if self.moving_summary_index == 0:
should_summary_messages.insert(0, human_message)
summary_handler = SummarizerMixin(llm=self.summary_llm)
self.moving_summary_buffer = summary_handler.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
new_messages.append(AIMessage(content=self.moving_summary_buffer))
new_messages.append(ai_message)
new_messages.append(function_message)
return new_messages
def get_num_tokens_from_messages(self, model_instance: BaseLLM, messages: List[BaseMessage], **kwargs) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
llm = cast(ChatOpenAI, model_instance.client)
model, encoding = llm._get_encoding_model()
if model.startswith("gpt-3.5-turbo"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
for m in messages:
message = _convert_message_to_dict(m)
num_tokens += tokens_per_message
for key, value in message.items():
if key == "function_call":
for f_key, f_value in value.items():
num_tokens += len(encoding.encode(f_key))
num_tokens += len(encoding.encode(f_value))
else:
num_tokens += len(encoding.encode(value))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
if kwargs.get('functions'):
for function in kwargs.get('functions'):
num_tokens += len(encoding.encode('name'))
num_tokens += len(encoding.encode(function.get("name")))
num_tokens += len(encoding.encode('description'))
num_tokens += len(encoding.encode(function.get("description")))
parameters = function.get("parameters")
num_tokens += len(encoding.encode('parameters'))
if 'title' in parameters:
num_tokens += len(encoding.encode('title'))
num_tokens += len(encoding.encode(parameters.get("title")))
num_tokens += len(encoding.encode('type'))
num_tokens += len(encoding.encode(parameters.get("type")))
if 'properties' in parameters:
num_tokens += len(encoding.encode('properties'))
for key, value in parameters.get('properties').items():
num_tokens += len(encoding.encode(key))
for field_key, field_value in value.items():
num_tokens += len(encoding.encode(field_key))
if field_key == 'enum':
for enum_field in field_value:
num_tokens += 3
num_tokens += len(encoding.encode(enum_field))
else:
num_tokens += len(encoding.encode(field_key))
num_tokens += len(encoding.encode(str(field_value)))
if 'required' in parameters:
num_tokens += len(encoding.encode('required'))
for required_field in parameters['required']:
num_tokens += 3
num_tokens += len(encoding.encode(required_field))
return num_tokens
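For reference, a minimal sketch of the same per-message counting rule applied directly with tiktoken; the model name and the sample messages are illustrative and not part of this change:

import tiktoken

def count_chat_tokens(messages: list, model: str = "gpt-3.5-turbo") -> int:
    # Mirrors the rule above for gpt-3.5-turbo: 4 framing tokens per message,
    # -1 when a name is present (the role is omitted), plus 3 tokens priming the reply.
    encoding = tiktoken.encoding_for_model(model)
    tokens_per_message, tokens_per_name = 4, -1
    num_tokens = 0
    for message in messages:
        num_tokens += tokens_per_message
        for key, value in message.items():
            num_tokens += len(encoding.encode(value))
            if key == "name":
                num_tokens += tokens_per_name
    return num_tokens + 3

count_chat_tokens([
    {"role": "system", "content": "You are a helpful AI assistant."},
    {"role": "user", "content": "Hello!"},
])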

View File

@@ -1,107 +0,0 @@
from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain.agents import BaseMultiActionAgent
from langchain.agents.openai_functions_multi_agent.base import OpenAIMultiFunctionsAgent, _format_intermediate_steps, \
_parse_ai_message
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts.chat import BaseMessagePromptTemplate
from langchain.schema import AgentAction, AgentFinish, SystemMessage
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import BaseTool
from core.agent.agent.calc_token_mixin import ExceededLLMTokensLimitError
from core.agent.agent.openai_function_call_summarize_mixin import OpenAIFunctionCallSummarizeMixin
class AutoSummarizingOpenMultiAIFunctionCallAgent(OpenAIMultiFunctionsAgent, OpenAIFunctionCallSummarizeMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
extra_prompt_messages: Optional[List[BaseMessagePromptTemplate]] = None,
system_message: Optional[SystemMessage] = SystemMessage(
content="You are a helpful AI assistant."
),
**kwargs: Any,
) -> BaseMultiActionAgent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
extra_prompt_messages=extra_prompt_messages,
system_message=cls.get_system_message(),
**kwargs,
)
def should_use_agent(self, query: str):
"""
return should use agent
:param query:
:return:
"""
original_max_tokens = self.llm.max_tokens
self.llm.max_tokens = 15
prompt = self.prompt.format_prompt(input=query, agent_scratchpad=[])
messages = prompt.to_messages()
try:
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=None
)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
raise new_exception
function_call = predicted_message.additional_kwargs.get("function_call", {})
self.llm.max_tokens = original_max_tokens
return True if function_call else False
def plan(
self,
intermediate_steps: List[Tuple[AgentAction, str]],
callbacks: Callbacks = None,
**kwargs: Any,
) -> Union[AgentAction, AgentFinish]:
"""Given input, decided what to do.
Args:
intermediate_steps: Steps the LLM has taken to date, along with observations
**kwargs: User inputs.
Returns:
Action specifying what tool to use.
"""
agent_scratchpad = _format_intermediate_steps(intermediate_steps)
selected_inputs = {
k: kwargs[k] for k in self.prompt.input_variables if k != "agent_scratchpad"
}
full_inputs = dict(**selected_inputs, agent_scratchpad=agent_scratchpad)
prompt = self.prompt.format_prompt(**full_inputs)
messages = prompt.to_messages()
# summarize messages if rest_tokens < 0
try:
messages = self.summarize_messages_if_needed(messages, functions=self.functions)
except ExceededLLMTokensLimitError as e:
return AgentFinish(return_values={"output": str(e)}, log=str(e))
predicted_message = self.llm.predict_messages(
messages, functions=self.functions, callbacks=callbacks
)
agent_decision = _parse_ai_message(predicted_message)
return agent_decision
@classmethod
def get_system_message(cls):
# get current time
return SystemMessage(content="You are a helpful AI assistant.\n"
"The current date or current time you know is wrong.\n"
"Respond directly if appropriate.")

View File

@@ -4,7 +4,6 @@ from typing import List, Tuple, Any, Union, Sequence, Optional, cast
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
@@ -12,6 +11,7 @@ from langchain.schema import AgentAction, AgentFinish, OutputParserException
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
from core.tool.dataset_retriever_tool import DatasetRetrieverTool
@@ -49,7 +49,6 @@ Action:
class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
model_instance: BaseLLM
dataset_tools: Sequence[BaseTool]
class Config:
@@ -98,7 +97,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -108,6 +107,8 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
if isinstance(tool_inputs, dict) and 'query' in tool_inputs:
tool_inputs['query'] = kwargs['input']
agent_decision.tool_input = tool_inputs
else:
agent_decision.return_values['output'] = ''
return agent_decision
except OutputParserException:
return AgentFinish({"output": "I'm sorry, the answer of model is invalid, "
@@ -145,7 +146,7 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -157,17 +158,28 @@ class StructuredMultiDatasetRouterAgent(StructuredChatAgent):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
dataset_tools=tools,
**kwargs,
)

View File

@@ -4,16 +4,17 @@ from typing import List, Tuple, Any, Union, Sequence, Optional
from langchain import BasePromptTemplate
from langchain.agents import StructuredChatAgent, AgentOutputParser, Agent
from langchain.agents.structured_chat.base import HUMAN_MESSAGE_TEMPLATE
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.base import BaseCallbackManager
from langchain.callbacks.manager import Callbacks
from langchain.memory.summary import SummarizerMixin
from langchain.memory.prompt import SUMMARY_PROMPT
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, ChatPromptTemplate
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException
from langchain.schema import AgentAction, AgentFinish, AIMessage, HumanMessage, OutputParserException, BaseMessage, \
get_buffer_string
from langchain.tools import BaseTool
from langchain.agents.structured_chat.prompt import PREFIX, SUFFIX
from core.agent.agent.calc_token_mixin import CalcTokenMixin, ExceededLLMTokensLimitError
from core.chain.llm_chain import LLMChain
from core.model_providers.models.llm.base import BaseLLM
FORMAT_INSTRUCTIONS = """Use a json blob to specify a tool by providing an action key (tool name) and an action_input key (tool input).
@@ -52,8 +53,7 @@ Action:
class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
moving_summary_buffer: str = ""
moving_summary_index: int = 0
summary_llm: BaseLanguageModel = None
model_instance: BaseLLM
summary_model_instance: BaseLLM = None
class Config:
"""Configuration for this pydantic object."""
@@ -95,14 +95,14 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
if prompts:
messages = prompts[0].to_messages()
rest_tokens = self.get_message_rest_tokens(self.model_instance, messages)
rest_tokens = self.get_message_rest_tokens(self.llm_chain.model_instance, messages)
if rest_tokens < 0:
full_inputs = self.summarize_messages(intermediate_steps, **kwargs)
try:
full_output = self.llm_chain.predict(callbacks=callbacks, **full_inputs)
except Exception as e:
new_exception = self.model_instance.handle_exceptions(e)
new_exception = self.llm_chain.model_instance.handle_exceptions(e)
raise new_exception
try:
@@ -118,7 +118,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
"I don't know how to respond to that."}, "")
def summarize_messages(self, intermediate_steps: List[Tuple[AgentAction, str]], **kwargs):
if len(intermediate_steps) >= 2 and self.summary_llm:
if len(intermediate_steps) >= 2 and self.summary_model_instance:
should_summary_intermediate_steps = intermediate_steps[self.moving_summary_index:-1]
should_summary_messages = [AIMessage(content=observation)
for _, observation in should_summary_intermediate_steps]
@@ -130,11 +130,10 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
error_msg = "Exceeded LLM tokens limit, stopped."
raise ExceededLLMTokensLimitError(error_msg)
summary_handler = SummarizerMixin(llm=self.summary_llm)
if self.moving_summary_buffer and 'chat_history' in kwargs:
kwargs["chat_history"].pop()
self.moving_summary_buffer = summary_handler.predict_new_summary(
self.moving_summary_buffer = self.predict_new_summary(
messages=should_summary_messages,
existing_summary=self.moving_summary_buffer
)
@@ -144,6 +143,18 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
return self.get_full_inputs([intermediate_steps[-1]], **kwargs)
def predict_new_summary(
self, messages: List[BaseMessage], existing_summary: str
) -> str:
new_lines = get_buffer_string(
messages,
human_prefix="Human",
ai_prefix="AI",
)
chain = LLMChain(model_instance=self.summary_model_instance, prompt=SUMMARY_PROMPT)
return chain.predict(summary=existing_summary, new_lines=new_lines)
@classmethod
def create_prompt(
cls,
@@ -176,7 +187,7 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
@classmethod
def from_llm_and_tools(
cls,
llm: BaseLanguageModel,
model_instance: BaseLLM,
tools: Sequence[BaseTool],
callback_manager: Optional[BaseCallbackManager] = None,
output_parser: Optional[AgentOutputParser] = None,
@@ -188,16 +199,27 @@ class AutoSummarizingStructuredChatAgent(StructuredChatAgent, CalcTokenMixin):
memory_prompts: Optional[List[BasePromptTemplate]] = None,
**kwargs: Any,
) -> Agent:
return super().from_llm_and_tools(
llm=llm,
tools=tools,
callback_manager=callback_manager,
output_parser=output_parser,
"""Construct an agent from an LLM and tools."""
cls._validate_tools(tools)
prompt = cls.create_prompt(
tools,
prefix=prefix,
suffix=suffix,
human_message_template=human_message_template,
format_instructions=format_instructions,
input_variables=input_variables,
memory_prompts=memory_prompts,
)
llm_chain = LLMChain(
model_instance=model_instance,
prompt=prompt,
callback_manager=callback_manager,
)
tool_names = [tool.name for tool in tools]
_output_parser = output_parser
return cls(
llm_chain=llm_chain,
allowed_tools=tool_names,
output_parser=_output_parser,
**kwargs,
)

View File

@@ -10,7 +10,6 @@ from pydantic import BaseModel, Extra
from core.agent.agent.multi_dataset_router_agent import MultiDatasetRouterAgent
from core.agent.agent.openai_function_call import AutoSummarizingOpenAIFunctionCallAgent
from core.agent.agent.openai_multi_function_call import AutoSummarizingOpenMultiAIFunctionCallAgent
from core.agent.agent.output_parser.structured_chat import StructuredChatOutputParser
from core.agent.agent.structed_multi_dataset_router_agent import StructuredMultiDatasetRouterAgent
from core.agent.agent.structured_chat import AutoSummarizingStructuredChatAgent
@@ -27,7 +26,6 @@ class PlanningStrategy(str, enum.Enum):
REACT_ROUTER = 'react_router'
REACT = 'react'
FUNCTION_CALL = 'function_call'
MULTI_FUNCTION_CALL = 'multi_function_call'
class AgentConfiguration(BaseModel):
@@ -64,30 +62,18 @@ class AgentExecutor:
if self.configuration.strategy == PlanningStrategy.REACT:
agent = AutoSummarizingStructuredChatAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.FUNCTION_CALL:
agent = AutoSummarizingOpenAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
if self.configuration.summary_model_instance else None,
verbose=True
)
elif self.configuration.strategy == PlanningStrategy.MULTI_FUNCTION_CALL:
agent = AutoSummarizingOpenMultiAIFunctionCallAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None, # used for read chat histories memory
summary_llm=self.configuration.summary_model_instance.client
summary_model_instance=self.configuration.summary_model_instance
if self.configuration.summary_model_instance else None,
verbose=True
)
@@ -95,7 +81,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = MultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
extra_prompt_messages=self.configuration.memory.buffer if self.configuration.memory else None,
verbose=True
@@ -104,7 +89,6 @@ class AgentExecutor:
self.configuration.tools = [t for t in self.configuration.tools if isinstance(t, DatasetRetrieverTool)]
agent = StructuredMultiDatasetRouterAgent.from_llm_and_tools(
model_instance=self.configuration.model_instance,
llm=self.configuration.model_instance.client,
tools=self.configuration.tools,
output_parser=StructuredChatOutputParser(),
verbose=True

View File

@@ -0,0 +1,36 @@
from typing import List, Dict, Any, Optional
from langchain import LLMChain as LCLLMChain
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.schema import LLMResult, Generation
from langchain.schema.language_model import BaseLanguageModel
from core.model_providers.models.entity.message import to_prompt_messages
from core.model_providers.models.llm.base import BaseLLM
from core.third_party.langchain.llms.fake import FakeLLM
class LLMChain(LCLLMChain):
model_instance: BaseLLM
"""The language model instance to use."""
llm: BaseLanguageModel = FakeLLM(response="")
def generate(
self,
input_list: List[Dict[str, Any]],
run_manager: Optional[CallbackManagerForChainRun] = None,
) -> LLMResult:
"""Generate LLM result from inputs."""
prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
messages = prompts[0].to_messages()
prompt_messages = to_prompt_messages(messages)
result = self.model_instance.run(
messages=prompt_messages,
stop=stop
)
generations = [
[Generation(text=result.content)]
]
return LLMResult(generations=generations)
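A hedged usage sketch of this wrapper, mirroring how the agents above construct it; the prompt text is illustrative, and model_instance stands in for a BaseLLM obtained elsewhere (e.g. from ModelFactory):

from langchain.prompts import PromptTemplate
from core.chain.llm_chain import LLMChain

prompt = PromptTemplate.from_template("Summarize the following text:\n{text}")
chain = LLMChain(model_instance=model_instance, prompt=prompt)  # model_instance: BaseLLM (assumed)
summary = chain.predict(text="Milvus is a vector database for embedding retrieval ...")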

View File

@@ -1,4 +1,3 @@
import json
import logging
from typing import Optional, List, Union
@@ -16,10 +15,8 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompts import MORE_LIKE_THIS_GENERATE_PROMPT
from models.dataset import DocumentSegment, Dataset, Document
from models.model import App, AppModelConfig, Account, Conversation, Message, EndUser
from core.prompt.prompt_template import PromptTemplateParser
from models.model import App, AppModelConfig, Account, Conversation, EndUser
class Completion:
@@ -30,7 +27,7 @@ class Completion:
"""
errors: ProviderTokenNotInitError
"""
query = PromptBuilder.process_template(query)
query = PromptTemplateParser.remove_template_variables(query)
memory = None
if conversation:
@@ -160,14 +157,28 @@ class Completion:
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
fake_response: Optional[str]):
# get llm prompt
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = model_instance.get_prompt(
mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
else:
prompt_messages = model_instance.get_advanced_prompt(
app_mode=mode,
app_model_config=app_model_config,
inputs=inputs,
query=query,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory
)
model_config = app_model_config.model_dict
completion_params = model_config.get("completion_params", {})
stop_words = completion_params.get("stop", [])
cls.recale_llm_max_tokens(
model_instance=model_instance,
@@ -176,7 +187,7 @@ class Completion:
response = model_instance.run(
messages=prompt_messages,
stop=stop_words,
stop=stop_words if stop_words else None,
callbacks=[LLMCallbackHandler(model_instance, conversation_message_task)],
fake_response=fake_response
)
@@ -266,52 +277,3 @@ class Completion:
model_kwargs = model_instance.get_model_kwargs()
model_kwargs.max_tokens = max_tokens
model_instance.set_model_kwargs(model_kwargs)
@classmethod
def generate_more_like_this(cls, task_id: str, app: App, message: Message, pre_prompt: str,
app_model_config: AppModelConfig, user: Account, streaming: bool):
final_model_instance = ModelFactory.get_text_generation_model_from_model_config(
tenant_id=app.tenant_id,
model_config=app_model_config.model_dict,
streaming=streaming
)
# get llm prompt
old_prompt_messages, _ = final_model_instance.get_prompt(
mode='completion',
pre_prompt=pre_prompt,
inputs=message.inputs,
query=message.query,
context=None,
memory=None
)
original_completion = message.answer.strip()
prompt = MORE_LIKE_THIS_GENERATE_PROMPT
prompt = prompt.format(prompt=old_prompt_messages[0].content, original_completion=original_completion)
prompt_messages = [PromptMessage(content=prompt)]
conversation_message_task = ConversationMessageTask(
task_id=task_id,
app=app,
app_model_config=app_model_config,
user=user,
inputs=message.inputs,
query=message.query,
is_override=True if message.override_model_configs else False,
streaming=streaming,
model_instance=final_model_instance
)
cls.recale_llm_max_tokens(
model_instance=final_model_instance,
prompt_messages=prompt_messages
)
final_model_instance.run(
messages=prompt_messages,
callbacks=[LLMCallbackHandler(final_model_instance, conversation_message_task)]
)

View File

@@ -10,7 +10,7 @@ from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from events.message_event import message_was_created
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -74,10 +74,10 @@ class ConversationMessageTask:
if self.mode == 'chat':
introduction = self.app_model_config.opening_statement
if introduction:
prompt_template = JinjaPromptTemplate.from_template(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.input_variables if k in self.inputs}
prompt_template = PromptTemplateParser(template=introduction)
prompt_inputs = {k: self.inputs[k] for k in prompt_template.variable_keys if k in self.inputs}
try:
introduction = prompt_template.format(**prompt_inputs)
introduction = prompt_template.format(prompt_inputs)
except KeyError:
pass
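As a side note, a minimal illustrative sketch of the PromptTemplateParser usage introduced here (template and values are made up; the parser exposes variable_keys and format(dict), as used in this hunk):

parser = PromptTemplateParser(template="Hi {{name}}, welcome to {{app}}!")
parser.variable_keys                              # ['name', 'app']
parser.format({"name": "Alice", "app": "Dify"})   # 'Hi Alice, welcome to Dify!'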
@@ -94,7 +94,7 @@ class ConversationMessageTask:
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
app_model_config_id=self.app_model_config.id,
model_provider=self.provider_name,
model_id=self.model_name,
@@ -115,7 +115,7 @@ class ConversationMessageTask:
db.session.commit()
self.message = Message(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
@@ -150,12 +150,12 @@ class ConversationMessageTask:
message_tokens = llm_message.prompt_tokens
answer_tokens = llm_message.completion_tokens
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.HUMAN)
message_price_unit = self.model_instance.get_price_unit(MessageType.HUMAN)
message_unit_price = self.model_instance.get_tokens_unit_price(MessageType.USER)
message_price_unit = self.model_instance.get_price_unit(MessageType.USER)
answer_unit_price = self.model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
answer_price_unit = self.model_instance.get_price_unit(MessageType.ASSISTANT)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.HUMAN)
message_total_price = self.model_instance.calc_tokens_price(message_tokens, MessageType.USER)
answer_total_price = self.model_instance.calc_tokens_price(answer_tokens, MessageType.ASSISTANT)
total_price = message_total_price + answer_total_price
@@ -163,7 +163,7 @@ class ConversationMessageTask:
self.message.message_tokens = message_tokens
self.message.message_unit_price = message_unit_price
self.message.message_price_unit = message_price_unit
self.message.answer = PromptBuilder.process_template(
self.message.answer = PromptTemplateParser.remove_template_variables(
llm_message.completion.strip()) if llm_message.completion else ''
self.message.answer_tokens = answer_tokens
self.message.answer_unit_price = answer_unit_price
@@ -226,15 +226,15 @@ class ConversationMessageTask:
def on_agent_end(self, message_agent_thought: MessageAgentThought, agent_model_instance: BaseLLM,
agent_loop: AgentLoop):
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.HUMAN)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.HUMAN)
agent_message_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.USER)
agent_message_price_unit = agent_model_instance.get_price_unit(MessageType.USER)
agent_answer_unit_price = agent_model_instance.get_tokens_unit_price(MessageType.ASSISTANT)
agent_answer_price_unit = agent_model_instance.get_price_unit(MessageType.ASSISTANT)
loop_message_tokens = agent_loop.prompt_tokens
loop_answer_tokens = agent_loop.completion_tokens
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.HUMAN)
loop_message_total_price = agent_model_instance.calc_tokens_price(loop_message_tokens, MessageType.USER)
loop_answer_total_price = agent_model_instance.calc_tokens_price(loop_answer_tokens, MessageType.ASSISTANT)
loop_total_price = loop_message_total_price + loop_answer_total_price

View File

@@ -10,9 +10,8 @@ from core.model_providers.models.entity.model_params import ModelKwargs
from core.prompt.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.prompt.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.prompt.prompt_template import JinjaPromptTemplate, OutLinePromptTemplate
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, CONVERSATION_SUMMARY_PROMPT, INTRODUCTION_GENERATE_PROMPT, \
GENERATOR_QA_PROMPT
from core.prompt.prompt_template import PromptTemplateParser
from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
@@ -44,78 +43,19 @@ class LLMGenerator:
return answer.strip()
@classmethod
def generate_conversation_summary(cls, tenant_id: str, messages):
max_tokens = 200
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
model_kwargs=ModelKwargs(
max_tokens=max_tokens
)
)
prompt = CONVERSATION_SUMMARY_PROMPT
prompt_with_empty_context = prompt.format(context='')
prompt_tokens = model_instance.get_num_tokens([PromptMessage(content=prompt_with_empty_context)])
max_context_token_length = model_instance.model_rules.max_tokens.max
max_context_token_length = max_context_token_length if max_context_token_length else 1500
rest_tokens = max_context_token_length - prompt_tokens - max_tokens - 1
context = ''
for message in messages:
if not message.answer:
continue
if len(message.query) > 2000:
query = message.query[:300] + "...[TRUNCATED]..." + message.query[-300:]
else:
query = message.query
if len(message.answer) > 2000:
answer = message.answer[:300] + "...[TRUNCATED]..." + message.answer[-300:]
else:
answer = message.answer
message_qa_text = "\n\nHuman:" + query + "\n\nAssistant:" + answer
if rest_tokens - model_instance.get_num_tokens([PromptMessage(content=context + message_qa_text)]) > 0:
context += message_qa_text
if not context:
return '[message too long, no summary]'
prompt = prompt.format(context=context)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_introduction(cls, tenant_id: str, pre_prompt: str):
prompt = INTRODUCTION_GENERATE_PROMPT
prompt = prompt.format(prompt=pre_prompt)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id
)
prompts = [PromptMessage(content=prompt)]
response = model_instance.run(prompts)
answer = response.content
return answer.strip()
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):
output_parser = SuggestedQuestionsAfterAnswerOutputParser()
format_instructions = output_parser.get_format_instructions()
prompt = JinjaPromptTemplate(
template="{{histories}}\n{{format_instructions}}\nquestions:\n",
input_variables=["histories"],
partial_variables={"format_instructions": format_instructions}
prompt_template = PromptTemplateParser(
template="{{histories}}\n{{format_instructions}}\nquestions:\n"
)
_input = prompt.format_prompt(histories=histories)
prompt = prompt_template.format({
"histories": histories,
"format_instructions": format_instructions
})
try:
model_instance = ModelFactory.get_text_generation_model(
@@ -128,10 +68,10 @@ class LLMGenerator:
except ProviderTokenNotInitError:
return []
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
questions = output_parser.parse(output.content)
except LLMError:
questions = []
@@ -145,19 +85,21 @@ class LLMGenerator:
def generate_rule_config(cls, tenant_id: str, audiences: str, hoping_to_solve: str) -> dict:
output_parser = RuleConfigGeneratorOutputParser()
prompt = OutLinePromptTemplate(
template=output_parser.get_format_instructions(),
input_variables=["audiences", "hoping_to_solve"],
partial_variables={
"variable": '{variable}',
"lanA": '{lanA}',
"lanB": '{lanB}',
"topic": '{topic}'
},
validate_template=False
prompt_template = PromptTemplateParser(
template=output_parser.get_format_instructions()
)
_input = prompt.format_prompt(audiences=audiences, hoping_to_solve=hoping_to_solve)
prompt = prompt_template.format(
inputs={
"audiences": audiences,
"hoping_to_solve": hoping_to_solve,
"variable": "{{variable}}",
"lanA": "{{lanA}}",
"lanB": "{{lanB}}",
"topic": "{{topic}}"
},
remove_template_variables=False
)
model_instance = ModelFactory.get_text_generation_model(
tenant_id=tenant_id,
@@ -167,10 +109,10 @@ class LLMGenerator:
)
)
prompts = [PromptMessage(content=_input.to_string())]
prompt_messages = [PromptMessage(content=prompt)]
try:
output = model_instance.run(prompts)
output = model_instance.run(prompt_messages)
rule_config = output_parser.parse(output.content)
except LLMError as e:
raise e

View File

@@ -1,4 +1,5 @@
import logging
import random
import openai
@@ -16,19 +17,20 @@ def check_moderation(model_provider: BaseModelProvider, text: str) -> bool:
length = 2000
text_chunks = [text[i:i + length] for i in range(0, len(text), length)]
max_text_chunks = 32
chunks = [text_chunks[i:i + max_text_chunks] for i in range(0, len(text_chunks), max_text_chunks)]
if len(text_chunks) == 0:
return True
for text_chunk in chunks:
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
text_chunk = random.choice(text_chunks)
for result in moderation_result.results:
if result['flagged'] is True:
return False
try:
moderation_result = openai.Moderation.create(input=text_chunk,
api_key=hosted_model_providers.openai.api_key)
except Exception as ex:
logging.exception(ex)
raise LLMBadRequestError('Rate limit exceeded, please try again later.')
for result in moderation_result.results:
if result['flagged'] is True:
return False
return True

View File

@@ -0,0 +1,858 @@
"""Wrapper around the Milvus vector database."""
from __future__ import annotations
import logging
from typing import Any, Iterable, List, Optional, Tuple, Union, Sequence
from uuid import uuid4
import numpy as np
from langchain.docstore.document import Document
from langchain.embeddings.base import Embeddings
from langchain.vectorstores.base import VectorStore
from langchain.vectorstores.utils import maximal_marginal_relevance
logger = logging.getLogger(__name__)
DEFAULT_MILVUS_CONNECTION = {
"host": "localhost",
"port": "19530",
"user": "",
"password": "",
"secure": False,
}
class Milvus(VectorStore):
"""Initialize wrapper around the milvus vector database.
In order to use this you need to have `pymilvus` installed and a
running Milvus instance.
See the following documentation for how to run a Milvus instance:
https://milvus.io/docs/install_standalone-docker.md
If looking for a hosted Milvus, take a look at this documentation:
https://zilliz.com/cloud and make use of the Zilliz vectorstore found in
this project,
IF USING L2/IP metric IT IS HIGHLY SUGGESTED TO NORMALIZE YOUR DATA.
Args:
embedding_function (Embeddings): Function used to embed the text.
collection_name (str): Which Milvus collection to use. Defaults to
"LangChainCollection".
connection_args (Optional[dict[str, any]]): The connection args used for
this class come in the form of a dict.
consistency_level (str): The consistency level to use for a collection.
Defaults to "Session".
index_params (Optional[dict]): Which index params to use. Defaults to
HNSW/AUTOINDEX depending on service.
search_params (Optional[dict]): Which search params to use. Defaults to
default of index.
drop_old (Optional[bool]): Whether to drop the current collection. Defaults
to False.
The connection args used for this class come in the form of a dict;
here are a few of the options:
address (str): The actual address of Milvus
instance. Example address: "localhost:19530"
uri (str): The uri of Milvus instance. Example uri:
"http://randomwebsite:19530",
"tcp:foobarsite:19530",
"https://ok.s3.south.com:19530".
host (str): The host of Milvus instance. Default at "localhost",
PyMilvus will fill in the default host if only port is provided.
port (str/int): The port of Milvus instance. Default at 19530, PyMilvus
will fill in the default port if only host is provided.
user (str): The user to connect to the Milvus instance with. If user and
password are provided, the related header is added to every RPC call.
password (str): Required when user is provided. The password
corresponding to the user.
secure (bool): Default is false. If set to true, tls will be enabled.
client_key_path (str): If using tls two-way authentication, the
path to the client.key file.
client_pem_path (str): If using tls two-way authentication, the
path to the client.pem file.
ca_pem_path (str): If using tls two-way authentication, the path
to the ca.pem file.
server_pem_path (str): If using tls one-way authentication, the
path to the server.pem file.
server_name (str): If using tls, the common name of the server.
Example:
.. code-block:: python
from langchain import Milvus
from langchain.embeddings import OpenAIEmbeddings
embedding = OpenAIEmbeddings()
# Connect to a milvus instance on localhost
milvus_store = Milvus(
embedding_function = Embeddings,
collection_name = "LangChainCollection",
drop_old = True,
)
Raises:
ValueError: If the pymilvus python package is not installed.
"""
def __init__(
self,
embedding_function: Embeddings,
collection_name: str = "LangChainCollection",
connection_args: Optional[dict[str, Any]] = None,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: Optional[bool] = False,
):
"""Initialize the Milvus vector store."""
try:
from pymilvus import Collection, utility
except ImportError:
raise ValueError(
"Could not import pymilvus python package. "
"Please install it with `pip install pymilvus`."
)
# Default search params when one is not provided.
self.default_search_params = {
"IVF_FLAT": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_SQ8": {"metric_type": "L2", "params": {"nprobe": 10}},
"IVF_PQ": {"metric_type": "L2", "params": {"nprobe": 10}},
"HNSW": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_FLAT": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_SQ": {"metric_type": "L2", "params": {"ef": 10}},
"RHNSW_PQ": {"metric_type": "L2", "params": {"ef": 10}},
"IVF_HNSW": {"metric_type": "L2", "params": {"nprobe": 10, "ef": 10}},
"ANNOY": {"metric_type": "L2", "params": {"search_k": 10}},
"AUTOINDEX": {"metric_type": "L2", "params": {}},
}
self.embedding_func = embedding_function
self.collection_name = collection_name
self.index_params = index_params
self.search_params = search_params
self.consistency_level = consistency_level
# In order for a collection to be compatible, pk needs to be auto-id and int
self._primary_field = "id"
# In order for compatibility, the text field will need to be called "text"
self._text_field = "page_content"
# In order for compatibility, the vector field needs to be called "vector"
self._vector_field = "vectors"
# In order for compatibility, the metadata field will need to be called "metadata"
self._metadata_field = "metadata"
self.fields: list[str] = []
# Create the connection to the server
if connection_args is None:
connection_args = DEFAULT_MILVUS_CONNECTION
self.alias = self._create_connection_alias(connection_args)
self.col: Optional[Collection] = None
# Grab the existing collection if it exists
if utility.has_collection(self.collection_name, using=self.alias):
self.col = Collection(
self.collection_name,
using=self.alias,
)
# If need to drop old, drop it
if drop_old and isinstance(self.col, Collection):
self.col.drop()
self.col = None
# Initialize the vector store
self._init()
@property
def embeddings(self) -> Embeddings:
return self.embedding_func
def _create_connection_alias(self, connection_args: dict) -> str:
"""Create the connection to the Milvus server."""
from pymilvus import MilvusException, connections
# Grab the connection arguments that are used for checking existing connection
host: str = connection_args.get("host", None)
port: Union[str, int] = connection_args.get("port", None)
address: str = connection_args.get("address", None)
uri: str = connection_args.get("uri", None)
user = connection_args.get("user", None)
# Order of use is host/port, uri, address
if host is not None and port is not None:
given_address = str(host) + ":" + str(port)
elif uri is not None:
given_address = uri.split("https://")[1]
elif address is not None:
given_address = address
else:
given_address = None
logger.debug("Missing standard address type for reuse atttempt")
# User defaults to empty string when getting connection info
if user is not None:
tmp_user = user
else:
tmp_user = ""
# If a valid address was given, then check if a connection exists
if given_address is not None:
for con in connections.list_connections():
addr = connections.get_connection_addr(con[0])
if (
con[1]
and ("address" in addr)
and (addr["address"] == given_address)
and ("user" in addr)
and (addr["user"] == tmp_user)
):
logger.debug("Using previous connection: %s", con[0])
return con[0]
# Generate a new connection if one doesn't exist
alias = uuid4().hex
try:
connections.connect(alias=alias, **connection_args)
logger.debug("Created new connection using: %s", alias)
return alias
except MilvusException as e:
logger.error("Failed to create new connection using: %s", alias)
raise e
def _init(
self, embeddings: Optional[list] = None, metadatas: Optional[list[dict]] = None
) -> None:
if embeddings is not None:
self._create_collection(embeddings, metadatas)
self._extract_fields()
self._create_index()
self._create_search_params()
self._load()
def _create_collection(
self, embeddings: list, metadatas: Optional[list[dict]] = None
) -> None:
from pymilvus import (
Collection,
CollectionSchema,
DataType,
FieldSchema,
MilvusException,
)
from pymilvus.orm.types import infer_dtype_bydata
# Determine embedding dim
dim = len(embeddings[0])
fields = []
# Determine metadata schema
# if metadatas:
# # Create FieldSchema for each entry in metadata.
# for key, value in metadatas[0].items():
# # Infer the corresponding datatype of the metadata
# dtype = infer_dtype_bydata(value)
# # Datatype isn't compatible
# if dtype == DataType.UNKNOWN or dtype == DataType.NONE:
# logger.error(
# "Failure to create collection, unrecognized dtype for key: %s",
# key,
# )
# raise ValueError(f"Unrecognized datatype for {key}.")
# # Datatype is a string/varchar equivalent
# elif dtype == DataType.VARCHAR:
# fields.append(FieldSchema(key, DataType.VARCHAR, max_length=65_535))
# else:
# fields.append(FieldSchema(key, dtype))
if metadatas:
fields.append(FieldSchema(self._metadata_field, DataType.JSON, max_length=65_535))
# Create the text field
fields.append(
FieldSchema(self._text_field, DataType.VARCHAR, max_length=65_535)
)
# Create the primary key field
fields.append(
FieldSchema(
self._primary_field, DataType.INT64, is_primary=True, auto_id=True
)
)
# Create the vector field, supports binary or float vectors
fields.append(
FieldSchema(self._vector_field, infer_dtype_bydata(embeddings[0]), dim=dim)
)
# Create the schema for the collection
schema = CollectionSchema(fields)
# Create the collection
try:
self.col = Collection(
name=self.collection_name,
schema=schema,
consistency_level=self.consistency_level,
using=self.alias,
)
except MilvusException as e:
logger.error(
"Failed to create collection: %s error: %s", self.collection_name, e
)
raise e
def _extract_fields(self) -> None:
"""Grab the existing fields from the Collection"""
from pymilvus import Collection
if isinstance(self.col, Collection):
schema = self.col.schema
for x in schema.fields:
self.fields.append(x.name)
# Since primary field is auto-id, no need to track it
self.fields.remove(self._primary_field)
def _get_index(self) -> Optional[dict[str, Any]]:
"""Return the vector index information if it exists"""
from pymilvus import Collection
if isinstance(self.col, Collection):
for x in self.col.indexes:
if x.field_name == self._vector_field:
return x.to_dict()
return None
def _create_index(self) -> None:
"""Create a index on the collection"""
from pymilvus import Collection, MilvusException
if isinstance(self.col, Collection) and self._get_index() is None:
try:
# If no index params, use a default HNSW based one
if self.index_params is None:
self.index_params = {
"metric_type": "IP",
"index_type": "HNSW",
"params": {"M": 8, "efConstruction": 64},
}
try:
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
# If default did not work, most likely on Zilliz Cloud
except MilvusException:
# Use AUTOINDEX based index
self.index_params = {
"metric_type": "L2",
"index_type": "AUTOINDEX",
"params": {},
}
self.col.create_index(
self._vector_field,
index_params=self.index_params,
using=self.alias,
)
logger.debug(
"Successfully created an index on collection: %s",
self.collection_name,
)
except MilvusException as e:
logger.error(
"Failed to create an index on collection: %s", self.collection_name
)
raise e
def _create_search_params(self) -> None:
"""Generate search params based on the current index type"""
from pymilvus import Collection
if isinstance(self.col, Collection) and self.search_params is None:
index = self._get_index()
if index is not None:
index_type: str = index["index_param"]["index_type"]
metric_type: str = index["index_param"]["metric_type"]
self.search_params = self.default_search_params[index_type]
self.search_params["metric_type"] = metric_type
def _load(self) -> None:
"""Load the collection if available."""
from pymilvus import Collection
if isinstance(self.col, Collection) and self._get_index() is not None:
self.col.load()
def add_texts(
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
timeout: Optional[int] = None,
batch_size: int = 1000,
**kwargs: Any,
) -> List[str]:
"""Insert text data into Milvus.
Inserting data when the collection has not been made yet will result
in creating a new Collection. The data of the first entity decides
the schema of the new collection, the dim is extracted from the first
embedding and the columns are decided by the first metadata dict.
Metadata keys will need to be present for all inserted values. At
the moment there is no None equivalent in Milvus.
Args:
texts (Iterable[str]): The texts to embed, it is assumed
that they all fit in memory.
metadatas (Optional[List[dict]]): Metadata dicts attached to each of
the texts. Defaults to None.
timeout (Optional[int]): Timeout for each batch insert. Defaults
to None.
batch_size (int, optional): Batch size to use for insertion.
Defaults to 1000.
Raises:
MilvusException: Failure to add texts
Returns:
List[str]: The resulting keys for each inserted element.
"""
from pymilvus import Collection, MilvusException
texts = list(texts)
try:
embeddings = self.embedding_func.embed_documents(texts)
except NotImplementedError:
embeddings = [self.embedding_func.embed_query(x) for x in texts]
if len(embeddings) == 0:
logger.debug("Nothing to insert, skipping.")
return []
# If the collection hasn't been initialized yet, perform all steps to do so
if not isinstance(self.col, Collection):
self._init(embeddings, metadatas)
# Dict to hold all insert columns
insert_dict: dict[str, list] = {
self._text_field: texts,
self._vector_field: embeddings,
}
# Collect the metadata into the insert dict.
# if metadatas is not None:
# for d in metadatas:
# for key, value in d.items():
# if key in self.fields:
# insert_dict.setdefault(key, []).append(value)
if metadatas is not None:
for d in metadatas:
insert_dict.setdefault(self._metadata_field, []).append(d)
# Total insert count
vectors: list = insert_dict[self._vector_field]
total_count = len(vectors)
pks: list[str] = []
assert isinstance(self.col, Collection)
for i in range(0, total_count, batch_size):
# Grab end index
end = min(i + batch_size, total_count)
# Convert dict to list of lists batch for insertion
insert_list = [insert_dict[x][i:end] for x in self.fields]
# Insert into the collection.
try:
res: Collection
res = self.col.insert(insert_list, timeout=timeout, **kwargs)
pks.extend(res.primary_keys)
except MilvusException as e:
logger.error(
"Failed to insert batch starting at entity: %s/%s", i, total_count
)
raise e
return pks
def similarity_search(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
Args:
query (str): The text to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score(
query=query, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a similarity search against the query string.
Args:
embedding (List[float]): The embedding vector to search.
k (int, optional): How many results to return. Defaults to 4.
param (dict, optional): The search params for the index type.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return [doc for doc, _ in res]
def similarity_search_with_score(
self,
query: str,
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
query (str): The text being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
# Embed the query text.
embedding = self.embedding_func.embed_query(query)
res = self.similarity_search_with_score_by_vector(
embedding=embedding, k=k, param=param, expr=expr, timeout=timeout, **kwargs
)
return res
def _similarity_search_with_relevance_scores(
self,
query: str,
k: int = 4,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Return docs and relevance scores in the range [0, 1].
0 is dissimilar, 1 is most similar.
Args:
query: input text
k: Number of Documents to return. Defaults to 4.
**kwargs: kwargs to be passed to similarity search. Should include:
score_threshold: Optional, a floating point value between 0 to 1 to
filter the resulting set of retrieved docs
Returns:
List of Tuples of (doc, similarity_score)
"""
return self.similarity_search_with_score(query, k, **kwargs)
def similarity_search_with_score_by_vector(
self,
embedding: List[float],
k: int = 4,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Tuple[Document, float]]:
"""Perform a search on a query string and return results with score.
For more information about the search parameters, take a look at the pymilvus
documentation found here:
https://milvus.io/api-reference/pymilvus/v2.2.6/Collection/search().md
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): The amount of results to return. Defaults to 4.
param (dict): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Tuple[Document, float]]: Result doc and score.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ret = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta.get('metadata'))
pair = (doc, result.score)
ret.append(pair)
return ret
def max_marginal_relevance_search(
self,
query: str,
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
query (str): The text being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
embedding = self.embedding_func.embed_query(query)
return self.max_marginal_relevance_search_by_vector(
embedding=embedding,
k=k,
fetch_k=fetch_k,
lambda_mult=lambda_mult,
param=param,
expr=expr,
timeout=timeout,
**kwargs,
)
def max_marginal_relevance_search_by_vector(
self,
embedding: list[float],
k: int = 4,
fetch_k: int = 20,
lambda_mult: float = 0.5,
param: Optional[dict] = None,
expr: Optional[str] = None,
timeout: Optional[int] = None,
**kwargs: Any,
) -> List[Document]:
"""Perform a search and return results that are reordered by MMR.
Args:
embedding (List[float]): The embedding vector being searched.
k (int, optional): How many results to give. Defaults to 4.
fetch_k (int, optional): Total results to select k from.
Defaults to 20.
lambda_mult: Number between 0 and 1 that determines the degree
of diversity among the results with 0 corresponding
to maximum diversity and 1 to minimum diversity.
Defaults to 0.5
param (dict, optional): The search params for the specified index.
Defaults to None.
expr (str, optional): Filtering expression. Defaults to None.
timeout (int, optional): How long to wait before timeout error.
Defaults to None.
kwargs: Collection.search() keyword arguments.
Returns:
List[Document]: Document results for search.
"""
if self.col is None:
logger.debug("No existing collection to search.")
return []
if param is None:
param = self.search_params
# Determine result metadata fields.
output_fields = self.fields[:]
output_fields.remove(self._vector_field)
# Perform the search.
res = self.col.search(
data=[embedding],
anns_field=self._vector_field,
param=param,
limit=fetch_k,
expr=expr,
output_fields=output_fields,
timeout=timeout,
**kwargs,
)
# Organize results.
ids = []
documents = []
scores = []
for result in res[0]:
meta = {x: result.entity.get(x) for x in output_fields}
doc = Document(page_content=meta.pop(self._text_field), metadata=meta)
documents.append(doc)
scores.append(result.score)
ids.append(result.id)
vectors = self.col.query(
expr=f"{self._primary_field} in {ids}",
output_fields=[self._primary_field, self._vector_field],
timeout=timeout,
)
# Reorganize the results from query to match search order.
vectors = {x[self._primary_field]: x[self._vector_field] for x in vectors}
ordered_result_embeddings = [vectors[x] for x in ids]
# Get the new order of results.
new_ordering = maximal_marginal_relevance(
np.array(embedding), ordered_result_embeddings, k=k, lambda_mult=lambda_mult
)
# Reorder the values and return.
ret = []
for x in new_ordering:
# Function can return -1 index
if x == -1:
break
else:
ret.append(documents[x])
return ret
@classmethod
def from_texts(
cls,
texts: List[str],
embedding: Embeddings,
metadatas: Optional[List[dict]] = None,
collection_name: str = "LangChainCollection",
connection_args: dict[str, Any] = DEFAULT_MILVUS_CONNECTION,
consistency_level: str = "Session",
index_params: Optional[dict] = None,
search_params: Optional[dict] = None,
drop_old: bool = False,
batch_size: int = 100,
ids: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> Milvus:
"""Create a Milvus collection, indexes it with HNSW, and insert data.
Args:
texts (List[str]): Text data.
embedding (Embeddings): Embedding function.
metadatas (Optional[List[dict]]): Metadata for each text if it exists.
Defaults to None.
collection_name (str, optional): Collection name to use. Defaults to
"LangChainCollection".
connection_args (dict[str, Any], optional): Connection args to use. Defaults
to DEFAULT_MILVUS_CONNECTION.
consistency_level (str, optional): Which consistency level to use. Defaults
to "Session".
index_params (Optional[dict], optional): Which index_params to use. Defaults
to None.
search_params (Optional[dict], optional): Which search params to use.
Defaults to None.
drop_old (Optional[bool], optional): Whether to drop the collection with
that name if it exists. Defaults to False.
batch_size (int, optional): How many vectors to upload per request.
Defaults to 100.
ids (Optional[Sequence[str]], optional): Ids for the texts.
Defaults to None.
Returns:
Milvus: Milvus Vector Store
"""
vector_db = cls(
embedding_function=embedding,
collection_name=collection_name,
connection_args=connection_args,
consistency_level=consistency_level,
index_params=index_params,
search_params=search_params,
drop_old=drop_old,
**kwargs,
)
vector_db.add_texts(texts=texts, metadatas=metadatas, batch_size=batch_size)
return vector_db
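A hedged end-to-end sketch of this vector store; the texts, collection name, and connection values are illustrative, and it assumes a reachable Milvus instance plus credentials configured for OpenAIEmbeddings:

from langchain.embeddings import OpenAIEmbeddings

store = Milvus.from_texts(
    texts=["Dify is an LLMOps platform.", "Milvus stores embedding vectors."],
    embedding=OpenAIEmbeddings(),
    connection_args={"host": "localhost", "port": "19530"},
    collection_name="ExampleCollection",
    drop_old=True,
)
docs = store.similarity_search("What does Milvus store?", k=1)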

View File

@@ -9,30 +9,44 @@ from core.index.base import BaseIndex
from core.index.vector_index.base import BaseVectorIndex
from core.vector_store.milvus_vector_store import MilvusVectorStore
from core.vector_store.weaviate_vector_store import WeaviateVectorStore
from models.dataset import Dataset
from extensions.ext_database import db
from models.dataset import Dataset, DatasetCollectionBinding
class MilvusConfig(BaseModel):
endpoint: str
host: str
port: int
user: str
password: str
secure: bool = False
batch_size: int = 100
@root_validator()
def validate_config(cls, values: dict) -> dict:
if not values['endpoint']:
raise ValueError("config MILVUS_ENDPOINT is required")
if not values['host']:
raise ValueError("config MILVUS_HOST is required")
if not values['port']:
raise ValueError("config MILVUS_PORT is required")
if not values['user']:
raise ValueError("config MILVUS_USER is required")
if not values['password']:
raise ValueError("config MILVUS_PASSWORD is required")
return values
def to_milvus_params(self):
return {
'host': self.host,
'port': self.port,
'user': self.user,
'password': self.password,
'secure': self.secure
}
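A small sketch of how the config model above (post-change field set: host, port, user, password, secure) is meant to be used; the credential values are placeholders.
config = MilvusConfig(
    host="127.0.0.1",
    port=19530,
    user="root",
    password="Milvus",
    secure=False,
)
# The root_validator rejects empty host/port/user/password; to_milvus_params()
# produces the dict later passed to MilvusVectorStore as connection_args.
assert config.to_milvus_params() == {
    "host": "127.0.0.1",
    "port": 19530,
    "user": "root",
    "password": "Milvus",
    "secure": False,
}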
class MilvusVectorIndex(BaseVectorIndex):
def __init__(self, dataset: Dataset, config: MilvusConfig, embeddings: Embeddings):
super().__init__(dataset, embeddings)
self._client = self._init_client(config)
self._client_config = config
def get_type(self) -> str:
return 'milvus'
@@ -49,7 +63,6 @@ class MilvusVectorIndex(BaseVectorIndex):
dataset_id = dataset.id
return "Vector_index_" + dataset_id.replace("-", "_") + '_Node'
def to_index_struct(self) -> dict:
return {
"type": self.get_type(),
@@ -58,26 +71,29 @@ class MilvusVectorIndex(BaseVectorIndex):
def create(self, texts: list[Document], **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
index_params = {
'metric_type': 'IP',
'index_type': "HNSW",
'params': {"M": 8, "efConstruction": 64}
}
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=self.get_index_name(self.dataset),
uuids=uuids,
by_text=False
collection_name=self.get_index_name(self.dataset),
connection_args=self._client_config.to_milvus_params(),
index_params=index_params
)
return self
def create_with_collection_name(self, texts: list[Document], collection_name: str, **kwargs) -> BaseIndex:
uuids = self._get_uuids(texts)
self._vector_store = WeaviateVectorStore.from_documents(
self._vector_store = MilvusVectorStore.from_documents(
texts,
self._embeddings,
client=self._client,
index_name=collection_name,
uuids=uuids,
by_text=False
collection_name=collection_name,
ids=uuids,
content_payload_key='page_content'
)
return self
@@ -86,42 +102,53 @@ class MilvusVectorIndex(BaseVectorIndex):
"""Only for created index."""
if self._vector_store:
return self._vector_store
attributes = ['doc_id', 'dataset_id', 'document_id']
if self._is_origin():
attributes = ['doc_id']
return WeaviateVectorStore(
client=self._client,
index_name=self.get_index_name(self.dataset),
text_key='text',
embedding=self._embeddings,
attributes=attributes,
by_text=False
return MilvusVectorStore(
collection_name=self.get_index_name(self.dataset),
embedding_function=self._embeddings,
connection_args=self._client_config.to_milvus_params()
)
def _get_vector_store_class(self) -> type:
return MilvusVectorStore
def delete_by_document_id(self, document_id: str):
if self._is_origin():
self.recreate_dataset(self.dataset)
return
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_document_id(document_id)
if ids:
vector_store.del_texts({
'filter': f'id in {ids}'
})
def delete_by_ids(self, doc_ids: list[str]) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
ids = vector_store.get_ids_by_doc_ids(doc_ids)
vector_store.del_texts({
'filter': f' id in {ids}'
})
def delete_by_group_id(self, group_id: str) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
vector_store.del_texts({
"operator": "Equal",
"path": ["document_id"],
"valueText": document_id
})
vector_store.delete()
def _is_origin(self):
if self.dataset.index_struct_dict:
class_prefix: str = self.dataset.index_struct_dict['vector_store']['class_prefix']
if not class_prefix.endswith('_Node'):
# original class_prefix
return True
def delete(self) -> None:
vector_store = self._get_vector_store()
vector_store = cast(self._get_vector_store_class(), vector_store)
return False
from qdrant_client.http import models
vector_store.del_texts(models.Filter(
must=[
models.FieldCondition(
key="group_id",
match=models.MatchValue(value=self.dataset.id),
),
],
))

View File

@@ -47,6 +47,20 @@ class VectorIndex:
),
embeddings=embeddings
)
elif vector_type == "milvus":
from core.index.vector_index.milvus_vector_index import MilvusVectorIndex, MilvusConfig
return MilvusVectorIndex(
dataset=dataset,
config=MilvusConfig(
host=config.get('MILVUS_HOST'),
port=config.get('MILVUS_PORT'),
user=config.get('MILVUS_USER'),
password=config.get('MILVUS_PASSWORD'),
secure=config.get('MILVUS_SECURE'),
),
embeddings=embeddings
)
else:
raise ValueError(f"Vector store {config.get('VECTOR_STORE')} is not supported.")

View File

@@ -11,6 +11,7 @@ from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
from core.data_loader.loader.notion import NotionLoader
@@ -79,6 +80,8 @@ class IndexingRunner:
dataset_document.error = str(e.description)
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
except ObjectDeletedError:
logging.warning('Document deleted, document id: {}'.format(dataset_document.id))
except Exception as e:
logging.exception("consume document failed")
dataset_document.indexing_status = 'error'
@@ -276,13 +279,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -372,13 +376,14 @@ class IndexingRunner:
)
if len(preview_texts) > 0:
# qa model document
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0], doc_language)
response = LLMGenerator.generate_qa_document(current_user.current_tenant_id, preview_texts[0],
doc_language)
document_qa_list = self.format_split_text(response)
return {
"total_segments": total_segments * 20,
"tokens": total_segments * 2000,
"total_price": '{:f}'.format(
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.HUMAN)),
text_generation_model.calc_tokens_price(total_segments * 2000, MessageType.USER)),
"currency": embedding_model.get_currency(),
"qa_preview": document_qa_list,
"preview": preview_texts
@@ -582,7 +587,6 @@ class IndexingRunner:
all_qa_documents.extend(format_documents)
def _split_to_documents_for_estimate(self, text_docs: List[Document], splitter: TextSplitter,
processing_rule: DatasetProcessRule) -> List[Document]:
"""
@@ -734,6 +738,9 @@ class IndexingRunner:
count = DatasetDocument.query.filter_by(id=document_id, is_paused=True).count()
if count > 0:
raise DocumentIsPausedException()
document = DatasetDocument.query.filter_by(id=document_id).first()
if not document:
raise DocumentIsDeletedPausedException()
update_params = {
DatasetDocument.indexing_status: after_indexing_status
@@ -781,3 +788,7 @@ class IndexingRunner:
class DocumentIsPausedException(Exception):
pass
class DocumentIsDeletedPausedException(Exception):
pass

View File

@@ -31,7 +31,7 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.HUMAN))
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -51,6 +51,9 @@ class ModelProviderFactory:
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'baichuan':
from core.model_providers.providers.baichuan_provider import BaichuanProvider
return BaichuanProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider

View File

@@ -0,0 +1,22 @@
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class OpenLLMEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenLLMEmbeddings(
server_url=credentials['server_url']
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}")

View File

@@ -1,9 +1,7 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
class XinferenceEmbedding(BaseEmbedding):
@@ -21,7 +19,4 @@ class XinferenceEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
else:
return ex
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")

View File

@@ -1,6 +1,6 @@
import enum
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -9,26 +9,31 @@ class LLMRunResult(BaseModel):
prompt_tokens: int
completion_tokens: int
source: list = None
function_call: dict = None
class MessageType(enum.Enum):
HUMAN = 'human'
USER = 'user'
ASSISTANT = 'assistant'
SYSTEM = 'system'
class PromptMessage(BaseModel):
type: MessageType = MessageType.HUMAN
type: MessageType = MessageType.USER
content: str = ''
function_call: dict = None
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
lc_messages.append(AIMessage(content=message.content))
additional_kwargs = {}
if message.function_call:
additional_kwargs['function_call'] = message.function_call
lc_messages.append(AIMessage(content=message.content, additional_kwargs=additional_kwargs))
elif message.type == MessageType.SYSTEM:
lc_messages.append(SystemMessage(content=message.content))
@@ -39,11 +44,21 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.HUMAN))
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.ASSISTANT))
message_kwargs = {
'content': message.content,
'type': MessageType.ASSISTANT
}
if 'function_call' in message.additional_kwargs:
message_kwargs['function_call'] = message.additional_kwargs['function_call']
prompt_messages.append(PromptMessage(**message_kwargs))
elif isinstance(message, SystemMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.SYSTEM))
elif isinstance(message, FunctionMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
return prompt_messages
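A round-trip sketch of the two helpers above, including the new function_call passthrough; the function name and arguments are invented for illustration.
msgs = [
    PromptMessage(type=MessageType.SYSTEM, content="You are a helpful assistant."),
    PromptMessage(type=MessageType.USER, content="What's the weather in Paris?"),
    PromptMessage(
        type=MessageType.ASSISTANT,
        content="",
        function_call={"name": "get_weather", "arguments": '{"city": "Paris"}'},  # illustrative
    ),
]

lc_msgs = to_lc_messages(msgs)        # SystemMessage, HumanMessage, AIMessage(additional_kwargs={...})
round_trip = to_prompt_messages(lc_msgs)
assert round_trip[2].function_call["name"] == "get_weather"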

View File

@@ -81,7 +81,20 @@ class AzureOpenAIModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
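A hedged sketch of the routing above, assuming model is an already-initialized AzureOpenAIModel for a chat deployment; the function schema is invented for illustration.
# For chat models _get_prompt_from_messages returns BaseMessage objects, so the call
# goes through the `messages` kwarg; a `functions` entry in **kwargs is forwarded
# to the underlying client untouched.
result = model._run(
    messages=[PromptMessage(type=MessageType.USER, content="Weather in Paris?")],
    functions=[{
        "name": "get_current_weather",  # hypothetical tool definition
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    }],
)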
@property
def base_model_name(self) -> str:

View File

@@ -0,0 +1,67 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
class BaichuanModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return BaichuanChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
Run prediction from prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
Get the number of tokens in the prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Baichuan: {str(ex)}")
@property
def support_streaming(self):
return True

View File

@@ -13,11 +13,12 @@ from langchain.schema import LLMResult, SystemMessage, AIMessage, HumanMessage,
from core.callback_handler.std_out_callback_handler import DifyStreamingStdOutCallbackHandler, DifyStdOutCallbackHandler
from core.helper import moderation
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, LLMRunResult, to_prompt_messages, \
to_lc_messages
from core.model_providers.models.entity.model_params import ModelType, ModelKwargs, ModelMode, ModelKwargsRules
from core.model_providers.providers.base import BaseModelProvider
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
from core.third_party.langchain.llms.fake import FakeLLM
import logging
@@ -132,8 +133,6 @@ class BaseLLM(BaseProviderModel):
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
db.session.commit()
if not callbacks:
callbacks = self.callbacks
else:
@@ -159,8 +158,11 @@ class BaseLLM(BaseProviderModel):
except Exception as ex:
raise self.handle_exceptions(ex)
function_call = None
if isinstance(result.generations[0][0], ChatGeneration):
completion_content = result.generations[0][0].message.content
if 'function_call' in result.generations[0][0].message.additional_kwargs:
function_call = result.generations[0][0].message.additional_kwargs.get('function_call')
else:
completion_content = result.generations[0][0].text
@@ -193,7 +195,8 @@ class BaseLLM(BaseProviderModel):
return LLMRunResult(
content=completion_content,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens
completion_tokens=completion_tokens,
function_call=function_call
)
@abstractmethod
@@ -229,7 +232,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return:
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -247,7 +250,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.0001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
unit_price = self.price_config['prompt']
else:
unit_price = self.price_config['completion']
@@ -262,7 +265,7 @@ class BaseLLM(BaseProviderModel):
:param message_type:
:return: decimal.Decimal('0.000001')
"""
if message_type == MessageType.HUMAN or message_type == MessageType.SYSTEM:
if message_type == MessageType.USER or message_type == MessageType.SYSTEM:
price_unit = self.price_config['unit']
else:
price_unit = self.price_config['unit']
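A worked sketch of the pricing arithmetic these branches feed, using the Baichuan2-53B price config added later in this change (prompt 0.01, unit 0.001, RMB); the exact rounding applied inside calc_tokens_price is not shown in this hunk.
from decimal import Decimal

tokens = 1500
unit_price = Decimal('0.01')    # price_config['prompt'], used for USER/SYSTEM messages
price_unit = Decimal('0.001')   # price_config['unit'], i.e. priced per 1K tokens
total_price = Decimal(tokens) * unit_price * price_unit   # 1500 * 0.01 * 0.001 = 0.015 RMB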
@@ -327,6 +330,85 @@ class BaseLLM(BaseProviderModel):
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self, app_mode: str,
app_model_config: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory]) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
conversation_histories_role = {}
raw_prompt_list = []
prompt_messages = []
if app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
elif app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.CHAT.value:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
elif app_mode == 'completion' and model_mode == ModelMode.COMPLETION.value:
prompt_text = app_model_config.completion_prompt_config_dict['prompt']['text']
raw_prompt_list = [{
'role': MessageType.USER.value,
'text': prompt_text
}]
else:
raise Exception("app_mode or model_mode not support")
for prompt_item in raw_prompt_list:
prompt = prompt_item['text']
# set prompt template variables
prompt_template = PromptTemplateParser(template=prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
if '#context#' in prompt:
if context:
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
if '#query#' in prompt:
if query:
prompt_inputs['#query#'] = query
else:
prompt_inputs['#query#'] = ''
if '#histories#' in prompt:
if memory and app_mode == 'chat' and model_mode == ModelMode.COMPLETION.value:
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, 2000)
prompt_inputs['#histories#'] = histories
else:
prompt_inputs['#histories#'] = ''
prompt = prompt_template.format(
prompt_inputs
)
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
if memory and app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
memory.human_prefix = MessageType.USER.value
memory.ai_prefix = MessageType.ASSISTANT.value
histories = self._get_history_messages_list_from_memory(memory, 2000)
prompt_messages.extend(histories)
if app_mode == 'chat' and model_mode == ModelMode.CHAT.value:
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query))
return prompt_messages
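A sketch of the substitution this method performs for a completion-mode prompt, assuming the parser's double-brace placeholder syntax ({{#context#}}, {{#query#}}); the template text is invented for illustration.
template = (
    "Use the context to answer.\n"
    "<context>\n{{#context#}}\n</context>\n"
    "Question: {{#query#}}\n"
    "Answer:"
)
parser = PromptTemplateParser(template=template)
# parser.variable_keys would contain '#context#' and '#query#' here
prompt = parser.format({
    '#context#': "Milvus stores the dataset vectors.",
    '#query#': "Where are the vectors stored?",
})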
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'common_completion'
@@ -339,17 +421,17 @@ class BaseLLM(BaseProviderModel):
memory: Optional[BaseChatMemory]) -> Tuple[str, Optional[list]]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['context_prompt'])
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
context=context
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = JinjaPromptTemplate.from_template(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.input_variables if k in inputs}
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
**prompt_inputs
prompt_inputs
)
prompt = ''
@@ -382,10 +464,8 @@ class BaseLLM(BaseProviderModel):
memory.ai_prefix = prompt_rules['assistant_prefix'] if 'assistant_prefix' in prompt_rules else 'Assistant'
histories = self._get_history_messages_from_memory(memory, rest_tokens)
prompt_template = JinjaPromptTemplate.from_template(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format(
histories=histories
)
prompt_template = PromptTemplateParser(template=prompt_rules['histories_prompt'])
histories_prompt_content = prompt_template.format({'histories': histories})
prompt = ''
for order in prompt_rules['system_prompt_orders']:
@@ -396,10 +476,8 @@ class BaseLLM(BaseProviderModel):
elif order == 'histories_prompt':
prompt += histories_prompt_content
prompt_template = JinjaPromptTemplate.from_template(template=query_prompt)
query_prompt_content = prompt_template.format(
query=query
)
prompt_template = PromptTemplateParser(template=query_prompt)
query_prompt_content = prompt_template.format({'query': query})
prompt += query_prompt_content
@@ -430,6 +508,16 @@ class BaseLLM(BaseProviderModel):
external_context = memory.load_memory_variables({})
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
memory_key = memory.memory_variables[0]
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str, List[BaseMessage]]:
if not model_mode:
@@ -444,16 +532,7 @@ class BaseLLM(BaseProviderModel):
if len(messages) == 0:
return []
chat_messages = []
for message in messages:
if message.type == MessageType.HUMAN:
chat_messages.append(HumanMessage(content=message.content))
elif message.type == MessageType.ASSISTANT:
chat_messages.append(AIMessage(content=message.content))
elif message.type == MessageType.SYSTEM:
chat_messages.append(SystemMessage(content=message.content))
return chat_messages
return to_lc_messages(messages)
def _to_model_kwargs_input(self, model_rules: ModelKwargsRules, model_kwargs: ModelKwargs) -> dict:
"""

View File

@@ -1,26 +1,23 @@
import decimal
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.llms import Minimax
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
class MinimaxModel(BaseLLM):
model_mode: ModelMode = ModelMode.COMPLETION
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return Minimax(
return MinimaxChatLLM(
model=self.name,
model_kwargs={
'stream': False
},
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
@@ -49,7 +46,7 @@ class MinimaxModel(BaseLLM):
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens(prompts), 0)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def get_currency(self):
return 'RMB'
@@ -65,3 +62,7 @@ class MinimaxModel(BaseLLM):
return LLMBadRequestError(f"Minimax: {str(ex)}")
else:
return ex
@property
def support_streaming(self):
return True

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
@@ -32,7 +33,7 @@ MODEL_MAX_TOKENS = {
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
@@ -105,7 +106,21 @@ class OpenAIModel(BaseLLM):
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 is currently not supported.")
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
generate_kwargs = {
'stop': stop,
'callbacks': callbacks
}
if isinstance(prompts, str):
generate_kwargs['prompts'] = [prompts]
else:
generate_kwargs['messages'] = [prompts]
if 'functions' in kwargs:
generate_kwargs['functions'] = kwargs['functions']
return self._client.generate(**generate_kwargs)
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""

View File

@@ -18,7 +18,6 @@ class TongyiModel(BaseLLM):
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
del provider_model_kwargs['max_tokens']
return EnhanceTongyi(
model_name=self.name,
max_retries=1,
@@ -58,7 +57,6 @@ class TongyiModel(BaseLLM):
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
del provider_model_kwargs['max_tokens']
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)

View File

@@ -9,7 +9,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.anthropic_model import AnthropicModel
from core.model_providers.models.llm.base import ModelType
@@ -34,10 +34,12 @@ class AnthropicProvider(BaseModelProvider):
{
'id': 'claude-instant-1',
'name': 'claude-instant-1',
'mode': ModelMode.CHAT.value,
},
{
'id': 'claude-2',
'name': 'claude-2',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -46,6 +48,9 @@ class AnthropicProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -12,7 +12,7 @@ from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.azure_openai_embedding import AzureOpenAIEmbedding, \
AZURE_OPENAI_API_VERSION
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule
from core.model_providers.models.entity.model_params import ModelType, ModelKwargsRules, KwargRule, ModelMode
from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.llm.azure_openai_model import AzureOpenAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -61,6 +61,10 @@ class AzureOpenAIProvider(BaseModelProvider):
}
credentials = json.loads(provider_model.encrypted_config)
if provider_model.model_type == ModelType.TEXT_GENERATION.value:
model_dict['mode'] = self._get_text_generation_model_mode(credentials['base_model_name'])
if credentials['base_model_name'] in [
'gpt-4',
'gpt-4-32k',
@@ -77,12 +81,19 @@ class AzureOpenAIProvider(BaseModelProvider):
return model_list
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name == 'text-davinci-003':
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
models = [
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -90,6 +101,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -97,6 +109,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -104,6 +117,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -111,6 +125,7 @@ class AzureOpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]

View File

@@ -0,0 +1,171 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType
class BaichuanProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'baichuan'
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = BaichuanModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
if 'secret_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
'secret_key': credentials['secret_key'],
}
llm = BaichuanChatLLM(
temperature=0,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
'secret_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['secret_key']:
credentials['secret_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['secret_key']
)
if obfuscated:
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@@ -61,10 +61,19 @@ class BaseModelProvider(BaseModel, ABC):
ProviderModel.is_valid == True
).order_by(ProviderModel.created_at.asc()).all()
return [{
'id': provider_model.model_name,
'name': provider_model.model_name
} for provider_model in provider_models]
provider_model_list = []
for provider_model in provider_models:
provider_model_dict = {
'id': provider_model.model_name,
'name': provider_model.model_name
}
if model_type == ModelType.TEXT_GENERATION:
provider_model_dict['mode'] = self._get_text_generation_model_mode(provider_model.model_name)
provider_model_list.append(provider_model_dict)
return provider_model_list
@abstractmethod
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
@@ -76,6 +85,16 @@ class BaseModelProvider(BaseModel, ABC):
"""
raise NotImplementedError
@abstractmethod
def _get_text_generation_model_mode(self, model_name) -> str:
"""
get text generation model mode.
:param model_name:
:return:
"""
raise NotImplementedError
@abstractmethod
def get_model_class(self, model_type: ModelType) -> Type:
"""

View File

@@ -6,7 +6,7 @@ from langchain.llms import ChatGLM
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.chatglm_model import ChatGLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from models.provider import ProviderType
@@ -27,15 +27,20 @@ class ChatGLMProvider(BaseModelProvider):
{
'id': 'chatglm2-6b',
'name': 'ChatGLM2-6B',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'chatglm-6b',
'name': 'ChatGLM-6B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -5,7 +5,7 @@ import requests
from huggingface_hub import HfApi
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -29,6 +29,9 @@ class HuggingfaceHubProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -6,7 +6,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.embedding.localai_embedding import LocalAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule
from core.model_providers.models.entity.model_params import ModelKwargsRules, ModelType, KwargRule, ModelMode
from core.model_providers.models.llm.localai_model import LocalAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -27,6 +27,13 @@ class LocalAIProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
credentials = self.get_model_credentials(model_name, ModelType.TEXT_GENERATION)
if credentials['completion_type'] == 'chat_completion':
return ModelMode.CHAT.value
else:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -2,14 +2,15 @@ import json
from json import JSONDecodeError
from typing import Type
from langchain.llms import Minimax
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.minimax_embedding import MinimaxEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.minimax_model import MinimaxModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.minimax_llm import MinimaxChatLLM
from models.provider import ProviderType, ProviderQuotaType
@@ -28,10 +29,12 @@ class MinimaxProvider(BaseModelProvider):
{
'id': 'abab5.5-chat',
'name': 'abab5.5-chat',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'abab5-chat',
'name': 'abab5-chat',
'mode': ModelMode.COMPLETION.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -44,6 +47,9 @@ class MinimaxProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -98,14 +104,14 @@ class MinimaxProvider(BaseModelProvider):
'minimax_api_key': credentials['minimax_api_key'],
}
llm = Minimax(
llm = MinimaxChatLLM(
model='abab5.5-chat',
max_tokens=10,
temperature=0.01,
**credential_kwargs
)
llm("ping")
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -13,8 +13,8 @@ from core.model_providers.models.entity.provider import ModelFeature
from core.model_providers.models.speech2text.openai_whisper import OpenAIWhisper
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.openai_embedding import OpenAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.openai_model import OpenAIModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.openai_model import OpenAIModel, COMPLETION_MODELS
from core.model_providers.models.moderation.openai_moderation import OpenAIModeration
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.providers.hosted import hosted_model_providers
@@ -36,6 +36,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo',
'name': 'gpt-3.5-turbo',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -43,10 +44,12 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'gpt-3.5-turbo-16k',
'name': 'gpt-3.5-turbo-16k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -54,6 +57,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4',
'name': 'gpt-4',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -61,6 +65,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
@@ -68,6 +73,7 @@ class OpenAIProvider(BaseModelProvider):
{
'id': 'text-davinci-003',
'name': 'text-davinci-003',
'mode': ModelMode.COMPLETION.value,
}
]
@@ -100,6 +106,12 @@ class OpenAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
if model_name in COMPLETION_MODELS:
return ModelMode.COMPLETION.value
else:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -132,7 +144,7 @@ class OpenAIProvider(BaseModelProvider):
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 8192,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}

View File

@@ -2,11 +2,13 @@ import json
from typing import Type
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
@@ -22,6 +24,9 @@ class OpenLLMProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -31,6 +36,8 @@ class OpenLLMProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
elif model_type == ModelType.EMBEDDINGS:
model_class = OpenLLMEmbedding
else:
raise NotImplementedError
@@ -69,14 +76,21 @@ class OpenLLMProvider(BaseModelProvider):
'server_url': credentials['server_url']
}
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = OpenLLMEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -6,7 +6,8 @@ import replicate
from replicate.exceptions import ReplicateError
from core.helper import encrypter
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, KwargRuleType, ModelKwargsRules, ModelType, \
ModelMode
from core.model_providers.models.llm.replicate_model import ReplicateModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
@@ -26,6 +27,9 @@ class ReplicateProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value if model_name.endswith('-chat') else ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.spark_model import SparkModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.spark import ChatSpark
@@ -30,15 +30,20 @@ class SparkProvider(BaseModelProvider):
{
'id': 'spark',
'name': 'Spark V1.5',
'mode': ModelMode.CHAT.value,
},
{
'id': 'spark-v2',
'name': 'Spark V2.0',
'mode': ModelMode.CHAT.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.tongyi_model import TongyiModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.tongyi_llm import EnhanceTongyi
@@ -24,17 +24,22 @@ class TongyiProvider(BaseModelProvider):
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'qwen-v1',
'name': 'qwen-v1',
'id': 'qwen-turbo',
'name': 'qwen-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'qwen-plus-v1',
'name': 'qwen-plus-v1',
'id': 'qwen-plus',
'name': 'qwen-plus',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
@@ -58,16 +63,16 @@ class TongyiProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'qwen-v1': 1500,
'qwen-plus-v1': 6500
'qwen-turbo': 6000,
'qwen-plus': 6000
}
return ModelKwargsRules(
temperature=KwargRule[float](enabled=False),
top_p=KwargRule[float](min=0, max=1, default=0.8, precision=2),
temperature=KwargRule[float](min=0.01, max=1, default=1, precision=2),
top_p=KwargRule[float](min=0.01, max=0.99, default=0.5, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](min=10, max=model_max_tokens.get(model_name), default=1024, precision=0),
max_tokens=KwargRule[int](enabled=False, max=model_max_tokens.get(model_name)),
)
@classmethod
@@ -84,7 +89,7 @@ class TongyiProvider(BaseModelProvider):
}
llm = EnhanceTongyi(
model_name='qwen-v1',
model_name='qwen-turbo',
max_retries=1,
**credential_kwargs
)

View File

@@ -4,7 +4,7 @@ from typing import Type
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.wenxin_model import WenxinModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.wenxin import Wenxin
@@ -26,19 +26,25 @@ class WenxinProvider(BaseModelProvider):
{
'id': 'ernie-bot',
'name': 'ERNIE-Bot',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'ernie-bot-turbo',
'name': 'ERNIE-Bot-turbo',
'mode': ModelMode.COMPLETION.value,
},
{
'id': 'bloomz-7b',
'name': 'BLOOMZ-7B',
'mode': ModelMode.COMPLETION.value,
}
]
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -2,15 +2,15 @@ import json
from typing import Type
import requests
from langchain.embeddings import XinferenceEmbeddings
from core.helper import encrypter
from core.model_providers.models.embedding.xinference_embedding import XinferenceEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType, ModelMode
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbeddings
from core.third_party.langchain.llms.xinference_llm import XinferenceLLM
from models.provider import ProviderType
@@ -26,6 +26,9 @@ class XinferenceProvider(BaseModelProvider):
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.COMPLETION.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,7 +7,7 @@ from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.embedding.zhipuai_embedding import ZhipuAIEmbedding
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType, ModelMode
from core.model_providers.models.llm.zhipuai_model import ZhipuAIModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.zhipuai_llm import ZhipuAIChatLLM
@@ -29,18 +29,22 @@ class ZhipuAIProvider(BaseModelProvider):
{
'id': 'chatglm_pro',
'name': 'chatglm_pro',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_std',
'name': 'chatglm_std',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite',
'name': 'chatglm_lite',
'mode': ModelMode.CHAT.value,
},
{
'id': 'chatglm_lite_32k',
'name': 'chatglm_lite_32k',
'mode': ModelMode.CHAT.value,
}
]
elif model_type == ModelType.EMBEDDINGS:
@@ -53,6 +57,9 @@ class ZhipuAIProvider(BaseModelProvider):
else:
return []
def _get_text_generation_model_mode(self, model_name) -> str:
return ModelMode.CHAT.value
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.

View File

@@ -7,10 +7,11 @@
"spark",
"wenxin",
"zhipuai",
"baichuan",
"chatglm",
"replicate",
"huggingface_hub",
"xinference",
"openllm",
"localai"
]
]

View File

@@ -0,0 +1,15 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"baichuan2-53b": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -3,5 +3,19 @@
"custom"
],
"system_config": null,
"model_flexibility": "fixed"
"model_flexibility": "fixed",
"price_config": {
"qwen-turbo": {
"prompt": "0.012",
"completion": "0.012",
"unit": "0.001",
"currency": "RMB"
},
"qwen-plus": {
"prompt": "0.14",
"completion": "0.14",
"unit": "0.001",
"currency": "RMB"
}
}
}

View File

@@ -1,7 +1,5 @@
import math
from typing import Optional
from flask import current_app
from langchain import WikipediaAPIWrapper
from langchain.callbacks.manager import Callbacks
from langchain.memory.chat_memory import BaseChatMemory
@@ -27,7 +25,6 @@ from core.tool.web_reader_tool import WebReaderTool
from extensions.ext_database import db
from models.dataset import Dataset, DatasetProcessRule
from models.model import AppModelConfig
from models.provider import ProviderType
class OrchestratorRuleParser:
@@ -39,18 +36,20 @@ class OrchestratorRuleParser:
def to_agent_executor(self, conversation_message_task: ConversationMessageTask, memory: Optional[BaseChatMemory],
rest_tokens: int, chain_callback: MainChainGatherCallbackHandler,
return_resource: bool = False, retriever_from: str = 'dev') -> Optional[AgentExecutor]:
retriever_from: str = 'dev') -> Optional[AgentExecutor]:
if not self.app_model_config.agent_mode_dict:
return None
agent_mode_config = self.app_model_config.agent_mode_dict
model_dict = self.app_model_config.model_dict
return_resource = self.app_model_config.retriever_resource_dict.get('enabled', False)
chain = None
if agent_mode_config and agent_mode_config.get('enabled'):
tool_configs = agent_mode_config.get('tools', [])
agent_provider_name = model_dict.get('provider', 'openai')
agent_model_name = model_dict.get('name', 'gpt-4')
dataset_configs = self.app_model_config.dataset_configs_dict
agent_model_instance = ModelFactory.get_text_generation_model(
tenant_id=self.tenant_id,
@@ -77,7 +76,7 @@ class OrchestratorRuleParser:
# only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if planning_strategy in [PlanningStrategy.FUNCTION_CALL, PlanningStrategy.MULTI_FUNCTION_CALL]:
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:
planning_strategy = PlanningStrategy.REACT_ROUTER
@@ -97,13 +96,14 @@ class OrchestratorRuleParser:
summary_model_instance = None
tools = self.to_tools(
agent_model_instance=agent_model_instance,
tool_configs=tool_configs,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
agent_model_instance=agent_model_instance,
conversation_message_task=conversation_message_task,
rest_tokens=rest_tokens,
callbacks=[agent_callback, DifyStdOutCallbackHandler()],
return_resource=return_resource,
retriever_from=retriever_from
retriever_from=retriever_from,
dataset_configs=dataset_configs
)
if len(tools) == 0:
@@ -171,20 +171,12 @@ class OrchestratorRuleParser:
return None
def to_tools(self, agent_model_instance: BaseLLM, tool_configs: list,
conversation_message_task: ConversationMessageTask,
rest_tokens: int, callbacks: Callbacks = None, return_resource: bool = False,
retriever_from: str = 'dev') -> list[BaseTool]:
def to_tools(self, tool_configs: list, callbacks: Callbacks = None, **kwargs) -> list[BaseTool]:
"""
Convert app agent tool configs to tools
:param agent_model_instance:
:param rest_tokens:
:param tool_configs: app agent tool configs
:param conversation_message_task:
:param callbacks:
:param return_resource:
:param retriever_from:
:return:
"""
tools = []
@@ -196,29 +188,35 @@ class OrchestratorRuleParser:
tool = None
if tool_type == "dataset":
tool = self.to_dataset_retriever_tool(tool_val, conversation_message_task, rest_tokens, return_resource, retriever_from)
tool = self.to_dataset_retriever_tool(tool_config=tool_val, **kwargs)
elif tool_type == "web_reader":
tool = self.to_web_reader_tool(agent_model_instance)
tool = self.to_web_reader_tool(tool_config=tool_val, **kwargs)
elif tool_type == "google_search":
tool = self.to_google_search_tool()
tool = self.to_google_search_tool(tool_config=tool_val, **kwargs)
elif tool_type == "wikipedia":
tool = self.to_wikipedia_tool()
tool = self.to_wikipedia_tool(tool_config=tool_val, **kwargs)
elif tool_type == "current_datetime":
tool = self.to_current_datetime_tool()
tool = self.to_current_datetime_tool(tool_config=tool_val, **kwargs)
if tool:
tool.callbacks.extend(callbacks)
if tool.callbacks is not None:
tool.callbacks.extend(callbacks)
else:
tool.callbacks = callbacks
tools.append(tool)
return tools
def to_dataset_retriever_tool(self, tool_config: dict, conversation_message_task: ConversationMessageTask,
rest_tokens: int, return_resource: bool = False, retriever_from: str = 'dev') \
dataset_configs: dict, rest_tokens: int,
return_resource: bool = False, retriever_from: str = 'dev',
**kwargs) \
-> Optional[BaseTool]:
"""
A dataset tool is a tool that can be used to retrieve information from a dataset
:param rest_tokens:
:param tool_config:
:param dataset_configs:
:param conversation_message_task:
:param return_resource:
:param retriever_from:
@@ -236,10 +234,20 @@ class OrchestratorRuleParser:
if dataset and dataset.available_document_count == 0:
return None
k = self._dynamic_calc_retrieve_k(dataset, rest_tokens)
top_k = dataset_configs.get("top_k", 2)
# dynamically adjust top_k when the remaining token number is not enough to support top_k
top_k = self._dynamic_calc_retrieve_k(dataset=dataset, top_k=top_k, rest_tokens=rest_tokens)
score_threshold = None
score_threshold_config = dataset_configs.get("score_threshold")
if score_threshold_config and score_threshold_config.get("enable"):
score_threshold = score_threshold_config.get("value")
tool = DatasetRetrieverTool.from_dataset(
dataset=dataset,
k=k,
top_k=top_k,
score_threshold=score_threshold,
callbacks=[DatasetToolCallbackHandler(conversation_message_task)],
conversation_message_task=conversation_message_task,
return_resource=return_resource,
@@ -248,7 +256,7 @@ class OrchestratorRuleParser:
return tool
def to_web_reader_tool(self, agent_model_instance: BaseLLM) -> Optional[BaseTool]:
def to_web_reader_tool(self, tool_config: dict, agent_model_instance: BaseLLM, **kwargs) -> Optional[BaseTool]:
"""
A tool for reading web pages
@@ -269,15 +277,14 @@ class OrchestratorRuleParser:
summary_model_instance = None
tool = WebReaderTool(
llm=summary_model_instance.client if summary_model_instance else None,
model_instance=summary_model_instance if summary_model_instance else None,
max_chunk_length=4000,
continue_reading=True,
callbacks=[DifyStdOutCallbackHandler()]
continue_reading=True
)
return tool
def to_google_search_tool(self) -> Optional[BaseTool]:
def to_google_search_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool_provider = SerpAPIToolProvider(tenant_id=self.tenant_id)
func_kwargs = tool_provider.credentials_to_func_kwargs()
if not func_kwargs:
@@ -290,47 +297,39 @@ class OrchestratorRuleParser:
"is not up to date. "
"Input should be a search query.",
func=OptimizedSerpAPIWrapper(**func_kwargs).run,
args_schema=OptimizedSerpAPIInput,
callbacks=[DifyStdOutCallbackHandler()]
args_schema=OptimizedSerpAPIInput
)
return tool
def to_current_datetime_tool(self) -> Optional[BaseTool]:
tool = DatetimeTool(
callbacks=[DifyStdOutCallbackHandler()]
)
def to_current_datetime_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
tool = DatetimeTool()
return tool
def to_wikipedia_tool(self) -> Optional[BaseTool]:
def to_wikipedia_tool(self, tool_config: dict, **kwargs) -> Optional[BaseTool]:
class WikipediaInput(BaseModel):
query: str = Field(..., description="search query.")
return WikipediaQueryRun(
name="wikipedia",
api_wrapper=WikipediaAPIWrapper(doc_content_chars_max=4000),
args_schema=WikipediaInput,
callbacks=[DifyStdOutCallbackHandler()]
args_schema=WikipediaInput
)
@classmethod
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, rest_tokens: int) -> int:
DEFAULT_K = 2
CONTEXT_TOKENS_PERCENT = 0.3
MAX_K = 10
def _dynamic_calc_retrieve_k(cls, dataset: Dataset, top_k: int, rest_tokens: int) -> int:
if rest_tokens == -1:
return DEFAULT_K
return top_k
processing_rule = dataset.latest_process_rule
if not processing_rule:
return DEFAULT_K
return top_k
if processing_rule.mode == "custom":
rules = processing_rule.rules_dict
if not rules:
return DEFAULT_K
return top_k
segmentation = rules["segmentation"]
segment_max_tokens = segmentation["max_tokens"]
@@ -338,14 +337,7 @@ class OrchestratorRuleParser:
segment_max_tokens = DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens']
# when rest_tokens is less than default context tokens
if rest_tokens < segment_max_tokens * DEFAULT_K:
if rest_tokens < segment_max_tokens * top_k:
return rest_tokens // segment_max_tokens
context_limit_tokens = math.floor(rest_tokens * CONTEXT_TOKENS_PERCENT)
# when context_limit_tokens is less than default context tokens, use default_k
if context_limit_tokens <= segment_max_tokens * DEFAULT_K:
return DEFAULT_K
# Expand the k value when there's still some room left in the 30% rest tokens space, but less than the MAX_K
return min(context_limit_tokens // segment_max_tokens, MAX_K)
return min(top_k, 10)
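A standalone restatement of the adjusted retrieval logic above, simplified for readability; segment_max_tokens would normally come from the dataset's processing rule and top_k from dataset_configs (the example values are made up):

# Simplified sketch of the new behaviour of _dynamic_calc_retrieve_k:
def dynamic_top_k(top_k: int, segment_max_tokens: int, rest_tokens: int) -> int:
    if rest_tokens == -1:
        return top_k
    # shrink top_k when the remaining context cannot hold top_k full segments
    if rest_tokens < segment_max_tokens * top_k:
        return rest_tokens // segment_max_tokens
    # otherwise honour the configured top_k, capped at 10
    return min(top_k, 10)

# e.g. top_k=4 with 500-token segments and only 1,200 tokens left -> 2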

View File

@@ -0,0 +1,83 @@
CONTEXT = "Use the following context as your learned knowledge, inside <context></context> XML tags.\n\n<context>\n{{#context#}}\n</context>\n\nWhen answer to user:\n- If you don't know, just say that you don't know.\n- If you don't know when you are not sure, ask for clarification.\nAvoid mentioning that you obtained the information from the context.\nAnd answer according to the language of the user's question.\n"
BAICHUAN_CONTEXT = "用户在与一个客观的助手对话。助手会尊重找到的材料,给出全面专业的解释,但不会过度演绎。同时回答中不会暴露引用的材料:\n\n```\n{{#context#}}\n```\n\n"
CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\nHere is the chat histories between human and assistant, inside <histories></histories> XML tags.\n\n<histories>\n{{#histories#}}\n</histories>\n\n\nHuman: {{#query#}}\n\nAssistant: "
},
"conversation_histories_role": {
"user_prefix": "Human",
"assistant_prefix": "Assistant"
}
},
"stop": ["Human:"]
}
CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["Human:"]
}
BAICHUAN_CHAT_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}\n\n用户和助手的历史对话内容如下:\n```\n{{#histories#}}\n```\n\n\n\n用户:{{#query#}}"
},
"conversation_histories_role": {
"user_prefix": "用户",
"assistant_prefix": "助手"
}
},
"stop": ["用户:"]
}
BAICHUAN_CHAT_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "system",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_CHAT_PROMPT_CONFIG = {
"chat_prompt_config": {
"prompt": [{
"role": "user",
"text": "{{#pre_prompt#}}"
}]
}
}
BAICHUAN_COMPLETION_APP_COMPLETION_PROMPT_CONFIG = {
"completion_prompt_config": {
"prompt": {
"text": "{{#pre_prompt#}}"
}
},
"stop": ["用户:"]
}
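A rough sketch of how a completion-style config such as CHAT_APP_COMPLETION_PROMPT_CONFIG could be rendered; the render_completion_prompt helper is hypothetical and simply swaps the {{#...#}} placeholders for the pre-prompt, the serialized histories and the current query:

import re

# Hypothetical helper, not the library's own assembly code.
def render_completion_prompt(config: dict, pre_prompt: str, histories: str, query: str) -> str:
    text = config["completion_prompt_config"]["prompt"]["text"]
    values = {"#pre_prompt#": pre_prompt, "#histories#": histories, "#query#": query}
    return re.sub(r"\{\{(#\w+#)\}\}", lambda m: values.get(m.group(1), m.group(0)), text)

# render_completion_prompt(CHAT_APP_COMPLETION_PROMPT_CONFIG,
#                          "You are a helpful bot.",
#                          "Human: hi\nAssistant: hello",
#                          "How are you?")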

View File

@@ -1,38 +1,24 @@
import re
from langchain.schema import BaseMessage, SystemMessage, AIMessage, HumanMessage
from langchain.prompts import SystemMessagePromptTemplate, HumanMessagePromptTemplate, AIMessagePromptTemplate
from langchain.schema import BaseMessage
from core.prompt.prompt_template import JinjaPromptTemplate
from core.prompt.prompt_template import PromptTemplateParser
class PromptBuilder:
@classmethod
def parse_prompt(cls, prompt: str, inputs: dict) -> str:
prompt_template = PromptTemplateParser(prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
prompt = prompt_template.format(prompt_inputs)
return prompt
@classmethod
def to_system_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
system_prompt_template = SystemMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in system_prompt_template.input_variables if k in inputs}
system_message = system_prompt_template.format(**prompt_inputs)
return system_message
return SystemMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_ai_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
ai_prompt_template = AIMessagePromptTemplate(prompt=prompt_template)
prompt_inputs = {k: inputs[k] for k in ai_prompt_template.input_variables if k in inputs}
ai_message = ai_prompt_template.format(**prompt_inputs)
return ai_message
return AIMessage(content=cls.parse_prompt(prompt_content, inputs))
@classmethod
def to_human_message(cls, prompt_content: str, inputs: dict) -> BaseMessage:
prompt_template = JinjaPromptTemplate.from_template(prompt_content)
human_prompt_template = HumanMessagePromptTemplate(prompt=prompt_template)
human_message = human_prompt_template.format(**inputs)
return human_message
@classmethod
def process_template(cls, template: str):
processed_template = re.sub(r'\{{2}(.+)\}{2}', r'{\1}', template)
# processed_template = re.sub(r'\{([a-zA-Z_]\w+?)\}', r'\1', template)
# processed_template = re.sub(r'\{\{([a-zA-Z_]\w+?)\}\}', r'{\1}', processed_template)
return processed_template
return HumanMessage(content=cls.parse_prompt(prompt_content, inputs))
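Usage sketch for the simplified PromptBuilder (the inputs are made up): every message type now funnels through parse_prompt, which delegates {{variable}} substitution to PromptTemplateParser.

inputs = {"name": "Dify", "#context#": "some retrieved passage"}

system_msg = PromptBuilder.to_system_message("You are {{name}}.", inputs)
human_msg = PromptBuilder.to_human_message("Context: {{#context#}}", inputs)
# system_msg.content -> "You are Dify."
# human_msg.content  -> "Context: some retrieved passage"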

View File

@@ -1,79 +1,39 @@
import re
from typing import Any
from jinja2 import Environment, meta
from langchain import PromptTemplate
from langchain.formatting import StrictFormatter
REGEX = re.compile(r"\{\{([a-zA-Z_][a-zA-Z0-9_]{1,29}|#histories#|#query#|#context#)\}\}")
class JinjaPromptTemplate(PromptTemplate):
template_format: str = "jinja2"
"""The format of the prompt template. Options are: 'f-string', 'jinja2'."""
class PromptTemplateParser:
"""
Rules:
1. Template variables must be enclosed in `{{}}`.
2. The template variable Key can only contain letters, numbers and underscores, with a maximum length of 30 characters,
and must start with a letter or an underscore.
3. The template variable Key cannot contain new lines or spaces, and must comply with rule 2.
4. In addition to the above, 3 types of special template variable Keys are accepted:
`{{#histories#}}` `{{#query#}}` `{{#context#}}`. No other `{{##}}` template variables are allowed.
"""
def __init__(self, template: str):
self.template = template
self.variable_keys = self.extract()
def extract(self) -> list:
# Regular expression to match the template rules
return re.findall(REGEX, self.template)
def format(self, inputs: dict, remove_template_variables: bool = True) -> str:
def replacer(match):
key = match.group(1)
value = inputs.get(key, match.group(0)) # return original matched string if key not found
if remove_template_variables:
return PromptTemplateParser.remove_template_variables(value)
return value
return re.sub(REGEX, replacer, self.template)
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
env = Environment()
template = template.replace("{{}}", "{}")
ast = env.parse(template)
input_variables = meta.find_undeclared_variables(ast)
if "partial_variables" in kwargs:
partial_variables = kwargs["partial_variables"]
input_variables = {
var for var in input_variables if var not in partial_variables
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
class OutLinePromptTemplate(PromptTemplate):
@classmethod
def from_template(cls, template: str, **kwargs: Any) -> PromptTemplate:
"""Load a prompt template from a template."""
input_variables = {
v for _, v, _, _ in OneLineFormatter().parse(template) if v is not None
}
return cls(
input_variables=list(sorted(input_variables)), template=template, **kwargs
)
def format(self, **kwargs: Any) -> str:
"""Format the prompt with the inputs.
Args:
kwargs: Any arguments to be passed to the prompt template.
Returns:
A formatted string.
Example:
.. code-block:: python
prompt.format(variable1="foo")
"""
kwargs = self._merge_partial_and_user_variables(**kwargs)
return OneLineFormatter().format(self.template, **kwargs)
class OneLineFormatter(StrictFormatter):
def parse(self, format_string):
last_end = 0
results = []
for match in re.finditer(r"{([a-zA-Z_]\w*)}", format_string):
field_name = match.group(1)
start, end = match.span()
literal_text = format_string[last_end:start]
last_end = end
results.append((literal_text, field_name, '', None))
remaining_literal_text = format_string[last_end:]
if remaining_literal_text:
results.append((remaining_literal_text, None, None, None))
return results
def remove_template_variables(cls, text: str):
return re.sub(REGEX, r'{\1}', text)
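A quick usage sketch of the new parser (values are made up):

parser = PromptTemplateParser("Hello {{name}}, context:\n{{#context#}}")
parser.variable_keys
# -> ['name', '#context#']

parser.format({"name": "Alice", "#context#": "retrieved text"})
# -> "Hello Alice, context:\nretrieved text"

# remove_template_variables (on by default) also rewrites any {{...}} that
# appears inside a substituted value to single braces, so user-supplied input
# cannot inject further template variables.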

View File

@@ -61,36 +61,6 @@ User Input: yo, 你今天咋样?
User Input:
"""
CONVERSATION_SUMMARY_PROMPT = (
"Please generate a short summary of the following conversation.\n"
"If the following conversation communicating in English, you should only return an English summary.\n"
"If the following conversation communicating in Chinese, you should only return a Chinese summary.\n"
"[Conversation Start]\n"
"{context}\n"
"[Conversation End]\n\n"
"summary:"
)
INTRODUCTION_GENERATE_PROMPT = (
"I am designing a product for users to interact with an AI through dialogue. "
"The Prompt given to the AI before the conversation is:\n\n"
"```\n{prompt}\n```\n\n"
"Please generate a brief introduction of no more than 50 words that greets the user, based on this Prompt. "
"Do not reveal the developer's motivation or deep logic behind the Prompt, "
"but focus on building a relationship with the user:\n"
)
MORE_LIKE_THIS_GENERATE_PROMPT = (
"-----\n"
"{original_completion}\n"
"-----\n\n"
"Please use the above content as a sample for generating the result, "
"and include key information points related to the original sample in the result. "
"Try to rephrase this information in different ways and predict according to the rules below.\n\n"
"-----\n"
"{prompt}\n"
)
SUGGESTED_QUESTIONS_AFTER_ANSWER_INSTRUCTION_PROMPT = (
"Please help me predict the three most likely questions that human would ask, "
"and keeping each question under 20 characters.\n"
@@ -157,10 +127,10 @@ and fill in variables, with a welcome sentence, and keep TLDR.
```
<< MY INTENDED AUDIENCES >>
{audiences}
{{audiences}}
<< HOPING TO SOLVE >>
{hoping_to_solve}
{{hoping_to_solve}}
<< OUTPUT >>
"""

View File

@@ -0,0 +1,67 @@
"""Wrapper around OpenLLM embedding models."""
from typing import Any, List, Optional
import requests
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
class OpenLLMEmbeddings(BaseModel, Embeddings):
"""Wrapper around OpenLLM embedding models.
"""
client: Any #: :meta private:
server_url: Optional[str] = None
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to OpenLLM's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
result = self.invoke_embedding(text=text)
embeddings.append(result)
return [list(map(float, e)) for e in embeddings]
def invoke_embedding(self, text):
params = [
text
]
headers = {"Content-Type": "application/json"}
response = requests.post(
f'{self.server_url}/v1/embeddings',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
json_response = response.json()
return json_response[0]["embeddings"][0]
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenLLM's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
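Usage sketch for the new wrapper (the server URL is a placeholder; an OpenLLM server started with 'openllm start' must be reachable there):

embeddings = OpenLLMEmbeddings(server_url="http://127.0.0.1:3333")

vectors = embeddings.embed_documents(["hello world", "dify"])
query_vector = embeddings.embed_query("hello world")
# len(vectors) == 2; each entry is a list of floats taken from the
# /v1/embeddings response.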

Some files were not shown because too many files have changed in this diff.