Compare commits
20 Commits
| Author | SHA1 | Date |
|---|---|---|
| | 8835435558 | |
| | a80d8286c2 | |
| | 6b15827246 | |
| | 41d0a8b295 | |
| | d0e1ea8f06 | |
| | f3b9647bb4 | |
| | 9de67c586f | |
| | 92f594f5e7 | |
| | 06d5273217 | |
| | 94d7babbf1 | |
| | 306216dbe5 | |
| | ab2e20ee0a | |
| | 146e95d88f | |
| | d7ae86799c | |
| | 7b26c9e2ef | |
| | 6bcafdbc87 | |
| | 059c089f93 | |
| | c1e7193c4b | |
| | 2423563d45 | |
| | 260672986e | |
```diff
@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
 APP_API_URL=http://127.0.0.1:5001
 APP_WEB_URL=http://127.0.0.1:3000
 
+# Files URL
+FILES_URL=http://127.0.0.1:5001
+
 # celery configuration
 CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
 
@@ -70,6 +73,14 @@ MILVUS_USER=root
 MILVUS_PASSWORD=Milvus
 MILVUS_SECURE=false
 
 # Upload configuration
 UPLOAD_FILE_SIZE_LIMIT=15
 UPLOAD_FILE_BATCH_LIMIT=5
+UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
+
+# Model Configuration
+MULTIMODAL_SEND_IMAGE_FORMAT=base64
+
+# Mail configuration, support: resend
+MAIL_TYPE=
+MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>
```
```diff
@@ -126,6 +126,7 @@ def register_blueprints(app):
     from controllers.service_api import bp as service_api_bp
     from controllers.web import bp as web_bp
     from controllers.console import bp as console_app_bp
+    from controllers.files import bp as files_bp
 
     CORS(service_api_bp,
          allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):
 
     app.register_blueprint(console_app_bp)
 
+    CORS(files_bp,
+         allow_headers=['Content-Type'],
+         methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
+         )
+    app.register_blueprint(files_bp)
+
 
 # create app
 app = create_app()
```
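The two hunks above wire a new `files` blueprint into `register_blueprints` with a deliberately narrower CORS policy than the service API. For reference, a minimal self-contained sketch of that registration pattern, assuming Flask and flask-cors as the diff implies (the route body is a placeholder, not the real controller):

```python
from flask import Blueprint, Flask
from flask_cors import CORS

# Placeholder blueprint standing in for controllers.files.bp.
files_bp = Blueprint('files', __name__)


@files_bp.route('/files/<file_id>/image-preview')
def image_preview(file_id):
    # The real handler streams the image; this stub just echoes the id.
    return {'file_id': file_id}


app = Flask(__name__)

# CORS is scoped to the blueprint, not the whole app: previews carry no
# auth headers, so only Content-Type is allowed here.
CORS(files_bp,
     allow_headers=['Content-Type'],
     methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH'])
app.register_blueprint(files_bp)
```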
api/config.py (231 changed lines)
```diff
@@ -26,6 +26,7 @@ DEFAULTS = {
     'SERVICE_API_URL': 'https://api.dify.ai',
     'APP_WEB_URL': 'https://udify.app',
     'APP_API_URL': 'https://udify.app',
+    'FILES_URL': '',
     'STORAGE_TYPE': 'local',
     'STORAGE_LOCAL_PATH': 'storage',
     'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
     'CLEAN_DAY_SETTING': 30,
     'UPLOAD_FILE_SIZE_LIMIT': 15,
     'UPLOAD_FILE_BATCH_LIMIT': 5,
-    'OUTPUT_MODERATION_BUFFER_SIZE': 300
+    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
+    'OUTPUT_MODERATION_BUFFER_SIZE': 300,
+    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
 }
```
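The `DEFAULTS` entries above only take effect through the module's env helper, which the diff does not show. A sketch of the usual pattern, with the helper's behavior assumed from the `get_env` calls elsewhere in these hunks:

```python
import os

# Subset of the DEFAULTS map shown in the hunk above.
DEFAULTS = {
    'FILES_URL': '',
    'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
    'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64',
}


def get_env(key):
    # An environment variable wins; otherwise fall back to the default.
    return os.environ.get(key, DEFAULTS.get(key))


# With no env override, the new image size limit resolves to 10 (MB).
print(int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT')))
```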
```diff
@@ -84,86 +87,65 @@ class Config:
     """Application configuration class."""
 
     def __init__(self):
-        # app settings
-        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
-        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
-        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
-        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
-        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
-        self.CONSOLE_URL = get_env('CONSOLE_URL')
-        self.API_URL = get_env('API_URL')
-        self.APP_URL = get_env('APP_URL')
-        self.CURRENT_VERSION = "0.3.29"
+        # ------------------------
+        # General Configurations.
+        # ------------------------
+        self.CURRENT_VERSION = "0.3.30"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
         self.TESTING = False
         self.LOG_LEVEL = get_env('LOG_LEVEL')
 
+        # The backend URL prefix of the console API.
+        # used to concatenate the login authorization callback or notion integration callback.
+        self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
+
+        # The front-end URL prefix of the console web.
+        # used to concatenate some front-end addresses and for CORS configuration use.
+        self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
+
+        # WebApp API backend Url prefix.
+        # used to declare the back-end URL for the front-end API.
+        self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
+
+        # WebApp Url prefix.
+        # used to display WebAPP API Base Url to the front-end.
+        self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
+
+        # Service API Url prefix.
+        # used to display Service API Base Url to the front-end.
+        self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
+
+        # File preview or download Url prefix.
+        # used to display File preview or download Url to the front-end or as Multi-model inputs;
+        # Url is signed and has expiration time.
+        self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
+
+        # Fallback Url prefix.
+        # Will be deprecated in the future.
+        self.CONSOLE_URL = get_env('CONSOLE_URL')
+        self.API_URL = get_env('API_URL')
+        self.APP_URL = get_env('APP_URL')
+
         # Your App secret key will be used for securely signing the session cookie
         # Make sure you are changing this key for your deployment with a strong key.
         # You can generate a strong key using `openssl rand -base64 42`.
         # Alternatively you can set it with `SECRET_KEY` environment variable.
         self.SECRET_KEY = get_env('SECRET_KEY')
 
-        # redis settings
-        self.REDIS_HOST = get_env('REDIS_HOST')
-        self.REDIS_PORT = get_env('REDIS_PORT')
-        self.REDIS_USERNAME = get_env('REDIS_USERNAME')
-        self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
-        self.REDIS_DB = get_env('REDIS_DB')
-        self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
-
-        # storage settings
-        self.STORAGE_TYPE = get_env('STORAGE_TYPE')
-        self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
-        self.S3_ENDPOINT = get_env('S3_ENDPOINT')
-        self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
-        self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
-        self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
-        self.S3_REGION = get_env('S3_REGION')
-
-        # vector store settings, only support weaviate, qdrant
-        self.VECTOR_STORE = get_env('VECTOR_STORE')
-
-        # weaviate settings
-        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
-        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
-        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
-        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
-
-        # qdrant settings
-        self.QDRANT_URL = get_env('QDRANT_URL')
-        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
-
-        # milvus setting
-        self.MILVUS_HOST = get_env('MILVUS_HOST')
-        self.MILVUS_PORT = get_env('MILVUS_PORT')
-        self.MILVUS_USER = get_env('MILVUS_USER')
-        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
-        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
-
-
         # cors settings
         self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
             'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
         self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
             'WEB_API_CORS_ALLOW_ORIGINS', '*')
 
-        # mail settings
-        self.MAIL_TYPE = get_env('MAIL_TYPE')
-        self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
-        self.RESEND_API_KEY = get_env('RESEND_API_KEY')
-
-        # sentry settings
-        self.SENTRY_DSN = get_env('SENTRY_DSN')
-        self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
-        self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
-
         # check update url
         self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
 
-        # database settings
+        # ------------------------
+        # Database Configurations.
+        # ------------------------
         db_credentials = {
             key: get_env(key) for key in
             ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
```
```diff
@@ -177,14 +159,102 @@ class Config:
 
         self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
 
-        # celery settings
+        # ------------------------
+        # Redis Configurations.
+        # ------------------------
+        self.REDIS_HOST = get_env('REDIS_HOST')
+        self.REDIS_PORT = get_env('REDIS_PORT')
+        self.REDIS_USERNAME = get_env('REDIS_USERNAME')
+        self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
+        self.REDIS_DB = get_env('REDIS_DB')
+        self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
+
+        # ------------------------
+        # Celery worker Configurations.
+        # ------------------------
         self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
         self.CELERY_BACKEND = get_env('CELERY_BACKEND')
         self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
             if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
         self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
 
-        # hosted provider credentials
+        # ------------------------
+        # File Storage Configurations.
+        # ------------------------
+        self.STORAGE_TYPE = get_env('STORAGE_TYPE')
+        self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
+        self.S3_ENDPOINT = get_env('S3_ENDPOINT')
+        self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
+        self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
+        self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
+        self.S3_REGION = get_env('S3_REGION')
+
+        # ------------------------
+        # Vector Store Configurations.
+        # Currently, only support: qdrant, milvus, zilliz, weaviate
+        # ------------------------
+        self.VECTOR_STORE = get_env('VECTOR_STORE')
+
+        # qdrant settings
+        self.QDRANT_URL = get_env('QDRANT_URL')
+        self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
+
+        # milvus / zilliz setting
+        self.MILVUS_HOST = get_env('MILVUS_HOST')
+        self.MILVUS_PORT = get_env('MILVUS_PORT')
+        self.MILVUS_USER = get_env('MILVUS_USER')
+        self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
+        self.MILVUS_SECURE = get_env('MILVUS_SECURE')
+
+        # weaviate settings
+        self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
+        self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
+        self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
+        self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
+
+        # ------------------------
+        # Mail Configurations.
+        # ------------------------
+        self.MAIL_TYPE = get_env('MAIL_TYPE')
+        self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
+        self.RESEND_API_KEY = get_env('RESEND_API_KEY')
+
+        # ------------------------
+        # Sentry Configurations.
+        # ------------------------
+        self.SENTRY_DSN = get_env('SENTRY_DSN')
+        self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
+        self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
+
+        # ------------------------
+        # Business Configurations.
+        # ------------------------
+
+        # multi model send image format, support base64, url, default is base64
+        self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
+
+        # Dataset Configurations.
+        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
+        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
+
+        # File upload Configurations.
+        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
+        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
+        self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
+
+        # Moderation in app Configurations.
+        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
+
+        # Notion integration setting
+        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
+        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
+        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
+        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
+        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
+
+        # ------------------------
+        # Platform Configurations.
+        # ------------------------
         self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
         self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
         self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
```
```diff
@@ -212,26 +282,6 @@ class Config:
         self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
         self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
 
-        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
-        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
-
-        # notion import setting
-        self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
-        self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
-        self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
-        self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
-        self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
-
-        self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
-        self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
-
-        # uploading settings
-        self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
-        self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
-
-        # moderation settings
-        self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
-
 
 class CloudEditionConfig(Config):
 
```
```diff
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
         self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
         self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
 
-
-class TestConfig(Config):
-
-    def __init__(self):
-        super().__init__()
-
-        self.EDITION = "SELF_HOSTED"
-        self.TESTING = True
-
-        db_credentials = {
-            key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
-        }
-
-        # use a different database for testing: dify_test
-        self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
+        self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
+        self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
```
```diff
@@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
         account = flask_login.current_user
 
@@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('model_config', type=dict, required=True, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
@@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
         args = parser.parse_args()
 
         streaming = args['response_mode'] != 'blocking'
+        args['auto_generate_name'] = False
 
         account = flask_login.current_user
 
```
```diff
@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
 
         return _get_conversation(app_id, conversation_id, 'completion')
 
 
     @setup_required
     @login_required
     @account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
         conversation_id = str(conversation_id)
 
         return _get_conversation(app_id, conversation_id, 'chat')
 
 
     @setup_required
     @login_required
     @account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
         return {'result': 'success'}, 204
 
 
-
-
 api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
 api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
 api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')
```
```diff
@@ -1,7 +1,6 @@
 import datetime
 import json
 
-from cachetools import TTLCache
 from flask import request
 from flask_login import current_user
 from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
 from services.dataset_service import DatasetService, DocumentService
 from tasks.document_indexing_sync_task import document_indexing_sync_task
 
-cache = TTLCache(maxsize=None, ttl=30)
-
 
 class DataSourceApi(Resource):
 
```
```diff
@@ -1,5 +1,5 @@
-from cachetools import TTLCache
 from flask import request, current_app
 from flask_login import current_user
 
 import services
 from libs.login import login_required
@@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields
 
 from services.file_service import FileService
 
-cache = TTLCache(maxsize=None, ttl=30)
-
 ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
 PREVIEW_WORDS_LIMIT = 3000
 
 
@@ -30,9 +27,11 @@ class FileApi(Resource):
     def get(self):
         file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
         batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
+        image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
         return {
             'file_size_limit': file_size_limit,
-            'batch_count_limit': batch_count_limit
+            'batch_count_limit': batch_count_limit,
+            'image_file_size_limit': image_file_size_limit
         }, 200
 
     @setup_required
@@ -51,7 +50,7 @@ class FileApi(Resource):
         if len(request.files) > 1:
             raise TooManyFilesError()
         try:
-            upload_file = FileService.upload_file(file)
+            upload_file = FileService.upload_file(file, current_user)
         except services.errors.file.FileTooLargeError as file_too_large_error:
             raise FileTooLargeError(file_too_large_error.description)
         except services.errors.file.UnsupportedFileTypeError:
```
```diff
@@ -1,6 +1,7 @@
 # -*- coding:utf-8 -*-
 import json
 import logging
+from datetime import datetime
 from typing import Generator, Union
 
 from flask import Response, stream_with_context
@@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from core.conversation_message_task import PubHandler
 from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
     LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
+from extensions.ext_database import db
 from libs.helper import uuid_value
 from services.completion_service import CompletionService
 
@@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
         try:
             response = CompletionService.completion(
@@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
+
+        installed_app.last_used_at = datetime.utcnow()
+        db.session.commit()
 
         try:
             response = CompletionService.completion(
```
```diff
@@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
                 user=current_user,
                 last_id=args['last_id'],
                 limit=args['limit'],
-                pinned=pinned
+                pinned=pinned,
+                exclude_debug_conversation=True
             )
         except LastConversationNotExistsError:
             raise NotFound("Last Conversation Not Exists.")
@@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
@@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
             }
             for installed_app in installed_apps
         ]
-        installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
-                            if app['last_used_at'] is not None else datetime.min))
+        installed_apps.sort(key=lambda app: (-app['is_pinned'],
+                                             app['last_used_at'] is None,
+                                             -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
 
         return {'installed_apps': installed_apps}
```
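The sort-key change in the last hunk is worth unpacking: the old key ordered unpinned apps by raw `last_used_at` ascending (oldest first, with never-used apps mapped to `datetime.min`), while the new key puts pinned apps first, most recently used next, and never-used apps last. A small runnable illustration (the sample rows are hypothetical):

```python
from datetime import datetime

# Hypothetical rows mirroring only the fields the sort key touches.
apps = [
    {'id': 'a', 'is_pinned': False, 'last_used_at': datetime(2023, 10, 1)},
    {'id': 'b', 'is_pinned': True,  'last_used_at': None},
    {'id': 'c', 'is_pinned': False, 'last_used_at': datetime(2023, 10, 5)},
]

# New key: pinned apps first (-True sorts before -False), then apps with a
# last_used_at before never-used ones, then most recent usage first.
apps.sort(key=lambda app: (-app['is_pinned'],
                           app['last_used_at'] is None,
                           -app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))

print([app['id'] for app in apps])  # ['b', 'c', 'a']
```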
```diff
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 from controllers.console import api
 from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
```
```diff
@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
 
 feedback_fields = {
     'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'query': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
 }
```
```diff
@@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):
 
         parser = reqparse.RequestParser()
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('provider', type=str, required=True, location='json')
         parser.add_argument('model', type=str, required=True, location='json')
@@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
         del args['model']
         del args['tools']
 
+        args['auto_generate_name'] = False
+
         try:
             response = CompletionService.completion(
                 app_model=app_model,
@@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                current_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
```
api/controllers/files/__init__.py (new file, 10 lines)

```diff
@@ -0,0 +1,10 @@
+# -*- coding:utf-8 -*-
+from flask import Blueprint
+
+from libs.external_api import ExternalApi
+
+bp = Blueprint('files', __name__)
+api = ExternalApi(bp)
+
+
+from . import image_preview
```
api/controllers/files/image_preview.py (new file, 40 lines)

```diff
@@ -0,0 +1,40 @@
+from flask import request, Response
+from flask_restful import Resource
+
+import services
+from controllers.files import api
+from libs.exception import BaseHTTPException
+from services.file_service import FileService
+
+
+class ImagePreviewApi(Resource):
+    def get(self, file_id):
+        file_id = str(file_id)
+
+        timestamp = request.args.get('timestamp')
+        nonce = request.args.get('nonce')
+        sign = request.args.get('sign')
+
+        if not timestamp or not nonce or not sign:
+            return {'content': 'Invalid request.'}, 400
+
+        try:
+            generator, mimetype = FileService.get_image_preview(
+                file_id,
+                timestamp,
+                nonce,
+                sign
+            )
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return Response(generator, mimetype=mimetype)
+
+
+api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
```
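`ImagePreviewApi` rejects any request missing `timestamp`, `nonce`, or `sign` and delegates verification to `FileService.get_image_preview`, which matches the config comment that file URLs "are signed and have an expiration time". The signer itself is not part of this diff; below is a plausible HMAC-based sketch of how such URLs could be produced. The message layout, helper name, and defaults are assumptions, not the repository's confirmed scheme:

```python
import base64
import hashlib
import hmac
import os
import time


SECRET_KEY = os.environ.get('SECRET_KEY', 'change-me')


def sign_image_preview_url(files_url: str, file_id: str) -> str:
    # Hypothetical signer matching the timestamp/nonce/sign query
    # parameters the endpoint checks.
    timestamp = str(int(time.time()))
    nonce = os.urandom(8).hex()
    msg = f'image-preview|{file_id}|{timestamp}|{nonce}'
    digest = hmac.new(SECRET_KEY.encode(), msg.encode(), hashlib.sha256).digest()
    sign = base64.urlsafe_b64encode(digest).decode()
    return (f'{files_url}/files/{file_id}/image-preview'
            f'?timestamp={timestamp}&nonce={nonce}&sign={sign}')


print(sign_image_preview_url('http://127.0.0.1:5001', 'example-file-id'))
```

The verifier would recompute the HMAC from the query parameters and additionally reject signatures older than the configured expiration window.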
```diff
@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
 api = ExternalApi(bp)
 
 
-from .app import completion, app, conversation, message, audio
+from .app import completion, app, conversation, message, audio, file
 
 from .dataset import document, segment, dataset
```
```diff
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import fields, marshal_with
+from flask import current_app
 
 from controllers.service_api import api
 from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
```
```diff
@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
+        args['auto_generate_name'] = False
+
         try:
             response = CompletionService.completion(
                 app_model=app_model,
                 user=end_user,
                 args=args,
                 from_source='api',
-                streaming=streaming
+                streaming=streaming,
             )
 
             return compact_response(response)
@@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('user', type=str, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
+        parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
 
         args = parser.parse_args()
 
```
```diff
@@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
         parser.add_argument('user', type=str, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         if end_user is None and args['user'] is not None:
             end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
 
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except services.errors.conversation.ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
```
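With `name` now optional and `auto_generate` added, a Service API client can ask the server to title the conversation itself. A usage sketch with the `requests` library; the host, API key, and exact rename route are assumptions inferred from the resource shown, not confirmed by the diff:

```python
import requests

# Hypothetical endpoint for ConversationRenameApi; the conversation id
# placeholder must be replaced with a real UUID.
resp = requests.post(
    'http://127.0.0.1:5001/v1/conversations/<conversation_id>/name',
    headers={'Authorization': 'Bearer <service-api-key>'},
    # No 'name' given: auto_generate=True asks the server to generate one.
    json={'auto_generate': True, 'user': 'end-user-123'},
)
print(resp.json())
```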
```diff
@@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     description = "Provider not support speech to text."
     code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
```
api/controllers/service_api/app/file.py (new file, 42 lines)

```diff
@@ -0,0 +1,42 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.service_api import api
+from controllers.service_api.wraps import AppApiResource
+from controllers.service_api.app import create_or_update_end_user_for_user_id
+from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(AppApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+
+        file = request.files['file']
+        user_args = request.form.get('user')
+
+        if end_user is None and user_args is not None:
+            end_user = create_or_update_end_user_for_user_id(app_model, user_args)
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')
```
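The new endpoint accepts a single multipart `file` plus an optional `user` form field and returns 201 with the marshalled `file_fields`. The route is `/v1/files/upload` (the blueprint's `url_prefix='/v1'` plus the registration above). A usage sketch with the `requests` library; the host, API key, and file are placeholders:

```python
import requests

with open('photo.png', 'rb') as f:
    resp = requests.post(
        'http://127.0.0.1:5001/v1/files/upload',
        headers={'Authorization': 'Bearer <service-api-key>'},
        files={'file': ('photo.png', f, 'image/png')},
        data={'user': 'end-user-123'},
    )

print(resp.status_code)  # 201 on success, per the handler's return value
print(resp.json())       # the uploaded file record, marshalled with file_fields
```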
```diff
@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
 from services.message_service import MessageService
 from extensions.ext_database import db
 from models.model import Message, EndUser
-
+from fields.conversation_fields import message_file_fields
 
 class MessageListApi(AppApiResource):
     feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
         'inputs': fields.Raw,
         'query': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
```
```diff
@@ -2,6 +2,7 @@ import json
 
 from flask import request
 from flask_restful import reqparse, marshal
+from flask_login import current_user
 from sqlalchemy import desc
 from werkzeug.exceptions import NotFound
 
@@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
         if len(request.files) > 1:
             raise TooManyFilesError()
 
-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
         data_source = {
             'type': 'upload_file',
             'info_list': {
@@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
         if len(request.files) > 1:
             raise TooManyFilesError()
 
-        upload_file = FileService.upload_file(file)
+        upload_file = FileService.upload_file(file, current_user)
         data_source = {
             'type': 'upload_file',
             'info_list': {
```
```diff
@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
 api = ExternalApi(bp)
 
 
-from . import completion, app, conversation, message, site, saved_message, audio, passport
+from . import completion, app, conversation, message, site, saved_message, audio, passport, file
```
```diff
@@ -1,5 +1,6 @@
 # -*- coding:utf-8 -*-
 from flask_restful import marshal_with, fields
+from flask import current_app
 
 from controllers.web import api
 from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
         'options': fields.List(fields.String)
     }
 
+    system_parameters_fields = {
+        'image_file_size_limit': fields.String
+    }
+
     parameters_fields = {
         'opening_statement': fields.String,
         'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
         'retriever_resource': fields.Raw,
         'more_like_this': fields.Raw,
         'user_input_form': fields.Raw,
-        'sensitive_word_avoidance': fields.Raw
+        'sensitive_word_avoidance': fields.Raw,
+        'file_upload': fields.Raw,
+        'system_parameters': fields.Nested(system_parameters_fields)
     }
 
     @marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
             'retriever_resource': app_model_config.retriever_resource_dict,
             'more_like_this': app_model_config.more_like_this_dict,
             'user_input_form': app_model_config.user_input_form_list,
-            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
+            'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
+            'file_upload': app_model_config.file_upload_dict,
+            'system_parameters': {
+                'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
+            }
         }
 
```
```diff
@@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, location='json', default='')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
 
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
         try:
             response = CompletionService.completion(
@@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
         parser = reqparse.RequestParser()
         parser.add_argument('inputs', type=dict, required=True, location='json')
         parser.add_argument('query', type=str, required=True, location='json')
+        parser.add_argument('files', type=list, required=False, location='json')
         parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
         parser.add_argument('conversation_id', type=uuid_value, location='json')
         parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
         args = parser.parse_args()
 
         streaming = args['response_mode'] == 'streaming'
+        args['auto_generate_name'] = False
 
         try:
             response = CompletionService.completion(
```
```diff
@@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
         conversation_id = str(c_id)
 
         parser = reqparse.RequestParser()
-        parser.add_argument('name', type=str, required=True, location='json')
+        parser.add_argument('name', type=str, required=False, location='json')
+        parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
         args = parser.parse_args()
 
         try:
-            return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
+            return ConversationService.rename(
+                app_model,
+                conversation_id,
+                end_user,
+                args['name'],
+                args['auto_generate']
+            )
         except ConversationNotExistsError:
             raise NotFound("Conversation Not Exists.")
 
```
```diff
@@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
 class ProviderNotSupportSpeechToTextError(BaseHTTPException):
     error_code = 'provider_not_support_speech_to_text'
     description = "Provider not support speech to text."
-    code = 400
+    code = 400
+
+
+class NoFileUploadedError(BaseHTTPException):
+    error_code = 'no_file_uploaded'
+    description = "Please upload your file."
+    code = 400
+
+
+class TooManyFilesError(BaseHTTPException):
+    error_code = 'too_many_files'
+    description = "Only one file is allowed."
+    code = 400
+
+
+class FileTooLargeError(BaseHTTPException):
+    error_code = 'file_too_large'
+    description = "File size exceeded. {message}"
+    code = 413
+
+
+class UnsupportedFileTypeError(BaseHTTPException):
+    error_code = 'unsupported_file_type'
+    description = "File type not allowed."
+    code = 415
```
api/controllers/web/file.py (new file, 36 lines)

```diff
@@ -0,0 +1,36 @@
+from flask import request
+from flask_restful import marshal_with
+
+from controllers.web import api
+from controllers.web.wraps import WebApiResource
+from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
+    UnsupportedFileTypeError
+import services
+from services.file_service import FileService
+from fields.file_fields import file_fields
+
+
+class FileApi(WebApiResource):
+
+    @marshal_with(file_fields)
+    def post(self, app_model, end_user):
+        # get file from request
+        file = request.files['file']
+
+        # check file
+        if 'file' not in request.files:
+            raise NoFileUploadedError()
+
+        if len(request.files) > 1:
+            raise TooManyFilesError()
+        try:
+            upload_file = FileService.upload_file(file, end_user)
+        except services.errors.file.FileTooLargeError as file_too_large_error:
+            raise FileTooLargeError(file_too_large_error.description)
+        except services.errors.file.UnsupportedFileTypeError:
+            raise UnsupportedFileTypeError()
+
+        return upload_file, 201
+
+
+api.add_resource(FileApi, '/files/upload')
```
```diff
@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
 from services.errors.conversation import ConversationNotExistsError
 from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
 from services.message_service import MessageService
+from fields.conversation_fields import message_file_fields
 
 
 class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
         'inputs': fields.Raw,
         'query': fields.String,
         'answer': fields.String,
+        'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
         'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
         'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
         'created_at': TimestampField
```
```diff
@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
 from libs.helper import uuid_value, TimestampField
 from services.errors.message import MessageNotExistsError
 from services.saved_message_service import SavedMessageService
+from fields.conversation_fields import message_file_fields
+
 
 feedback_fields = {
     'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
     'inputs': fields.Raw,
     'query': fields.String,
     'answer': fields.String,
+    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
     'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
     'created_at': TimestampField
 }
```
```diff
@@ -11,7 +11,8 @@ from pydantic import BaseModel
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
-from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
+from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
+    ImagePromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.moderation.base import ModerationOutputsResult, ModerationAction
 from core.moderation.factory import ModerationFactory
@@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):
 
             real_prompts.append({
                 "role": role,
-                "text": message.content
+                "text": message.content,
+                "files": [{
+                    "type": file.type.value,
+                    "data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
+                    "detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
+                } for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
             })
 
         self.llm_message.prompt = real_prompts
```
```diff
@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
 from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
     ConversationTaskInterruptException
 from core.external_data_tool.factory import ExternalDataToolFactory
+from core.file.file_obj import FileObj
 from core.model_providers.error import LLMBadRequestError
 from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
     ReadOnlyConversationTokenDBBufferSharedMemory
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import PromptMessage
+from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.orchestrator_rule_parser import OrchestratorRuleParser
 from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory
 class Completion:
     @classmethod
     def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
-                 user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
-                 is_override: bool = False, retriever_from: str = 'dev'):
+                 files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
+                 streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
+                 auto_generate_name: bool = True):
         """
         errors: ProviderTokenNotInitError
         """
@@ -64,16 +66,21 @@ class Completion:
             is_override=is_override,
             inputs=inputs,
            query=query,
+            files=files,
             streaming=streaming,
-            model_instance=final_model_instance
+            model_instance=final_model_instance,
+            auto_generate_name=auto_generate_name
         )
 
+        prompt_message_files = [file.prompt_message_file for file in files]
+
         rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
             mode=app.mode,
             model_instance=final_model_instance,
             app_model_config=app_model_config,
             query=query,
-            inputs=inputs
+            inputs=inputs,
+            files=prompt_message_files
         )
 
         # init orchestrator rule parser
@@ -95,6 +102,7 @@ class Completion:
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
+                files=prompt_message_files,
                 agent_execute_result=None,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
@@ -146,6 +154,7 @@ class Completion:
                 app_model_config=app_model_config,
                 query=query,
                 inputs=inputs,
+                files=prompt_message_files,
                 agent_execute_result=agent_execute_result,
                 conversation_message_task=conversation_message_task,
                 memory=memory,
@@ -205,21 +214,20 @@ class Completion:
         results = {}
         with ThreadPoolExecutor() as executor:
             futures = {}
-            for tools in grouped_tools.values():
-                # Only query the first tool in each group
-                first_tool = tools[0]
+            for tool in external_data_tools:
+                if not tool.get("enabled"):
+                    continue
 
                 future = executor.submit(
-                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
+                    cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
                     inputs, query
                 )
-                for tool in tools:
-                    futures[future] = tool
+
+                futures[future] = tool
 
             for future in concurrent.futures.as_completed(futures):
-                tool_key, result = future.result()
-                if tool_key in grouped_tools:
-                    for tool in grouped_tools[tool_key]:
-                        results[tool['variable']] = result
+                tool_variable, result = future.result()
+                results[tool_variable] = result
 
         inputs.update(results)
         return inputs
```
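The rework above drops the old group-by-tool-type batching: every enabled external data tool now gets its own future, and results are keyed by the tool's `variable` instead of a type/config tuple. A condensed, runnable sketch of the new fan-out shape, where `query_tool` stands in for `cls.query_external_data_tool`:

```python
from concurrent.futures import ThreadPoolExecutor, as_completed


def query_tool(tool, inputs, query):
    # Stand-in for the real tool invocation; returns the same
    # (variable, result) pair shape the hunk above consumes.
    return tool['variable'], f"result for {tool['variable']}"


def fill_inputs(external_data_tools, inputs, query):
    results = {}
    with ThreadPoolExecutor() as executor:
        # One submission per enabled tool, no grouping.
        futures = [
            executor.submit(query_tool, tool, inputs, query)
            for tool in external_data_tools if tool.get('enabled')
        ]
        for future in as_completed(futures):
            tool_variable, result = future.result()
            results[tool_variable] = result
    inputs.update(results)
    return inputs


print(fill_inputs([{'variable': 'weather', 'enabled': True}], {}, 'q'))
```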
```diff
@@ -246,9 +254,7 @@ class Completion:
             query=query
         )
 
-        tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
-
-        return tool_key, result
+        return tool_variable, result
 
     @classmethod
     def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
@@ -260,6 +266,7 @@ class Completion:
     @classmethod
     def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
                       inputs: dict,
+                      files: List[PromptMessageFile],
                       agent_execute_result: Optional[AgentExecuteResult],
                       conversation_message_task: ConversationMessageTask,
                       memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -269,10 +276,11 @@ class Completion:
         # get llm prompt
         if app_model_config.prompt_type == 'simple':
             prompt_messages, stop_words = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 model_instance=model_instance
@@ -283,6 +291,7 @@ class Completion:
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=agent_execute_result.output if agent_execute_result else None,
                 memory=memory,
                 model_instance=model_instance
@@ -340,7 +349,7 @@ class Completion:
 
     @classmethod
     def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
-                                 query: str, inputs: dict) -> int:
+                                 query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
         model_limited_tokens = model_instance.model_rules.max_tokens.max
         max_tokens = model_instance.get_model_kwargs().max_tokens
 
@@ -351,15 +360,15 @@ class Completion:
             max_tokens = 0
 
         prompt_transform = PromptTransform()
-        prompt_messages = []
 
         # get prompt without memory and context
         if app_model_config.prompt_type == 'simple':
             prompt_messages, _ = prompt_transform.get_prompt(
-                mode=mode,
+                app_mode=mode,
                 pre_prompt=app_model_config.pre_prompt,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=None,
                 memory=None,
                 model_instance=model_instance
@@ -370,6 +379,7 @@ class Completion:
                 app_model_config=app_model_config,
                 inputs=inputs,
                 query=query,
+                files=files,
                 context=None,
                 memory=None,
                 model_instance=model_instance
```
```diff
@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
 from core.callback_handler.entity.dataset_query import DatasetQueryObj
 from core.callback_handler.entity.llm_message import LLMMessage
 from core.callback_handler.entity.chain_result import ChainResult
+from core.file.file_obj import FileObj
 from core.model_providers.model_factory import ModelFactory
-from core.model_providers.models.entity.message import to_prompt_messages, MessageType
+from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
 from core.model_providers.models.llm.base import BaseLLM
 from core.prompt.prompt_builder import PromptBuilder
 from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,14 @@ from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from models.dataset import DatasetQuery
 from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
-    MessageChain, DatasetRetrieverResource
+    MessageChain, DatasetRetrieverResource, MessageFile
 
 
 class ConversationMessageTask:
     def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
-                 inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
-                 conversation: Optional[Conversation] = None, is_override: bool = False):
+                 inputs: dict, query: str, files: List[FileObj], streaming: bool,
+                 model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
+                 auto_generate_name: bool = True):
         self.start_at = time.perf_counter()
 
         self.task_id = task_id
@@ -35,6 +37,7 @@ class ConversationMessageTask:
         self.user = user
         self.inputs = inputs
         self.query = query
+        self.files = files
         self.streaming = streaming
 
         self.conversation = conversation
@@ -45,6 +48,7 @@ class ConversationMessageTask:
         self.message = None
 
         self.retriever_resource = None
+        self.auto_generate_name = auto_generate_name
 
         self.model_dict = self.app_model_config.model_dict
         self.provider_name = self.model_dict.get('provider')
@@ -100,7 +104,7 @@ class ConversationMessageTask:
                 model_id=self.model_name,
                 override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
                 mode=self.mode,
-                name='',
+                name='New conversation',
                 inputs=self.inputs,
                 introduction=introduction,
                 system_instruction=system_instruction,
@@ -142,6 +146,19 @@ class ConversationMessageTask:
         db.session.add(self.message)
         db.session.commit()
 
+        for file in self.files:
+            message_file = MessageFile(
+                message_id=self.message.id,
+                type=file.type.value,
+                transfer_method=file.transfer_method.value,
+                url=file.url,
+                upload_file_id=file.upload_file_id,
+                created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
+                created_by=self.user.id
+            )
+            db.session.add(message_file)
+            db.session.commit()
+
     def append_message_text(self, text: str):
         if text is not None:
             self._pub_handler.pub_text(text)
@@ -176,7 +193,8 @@ class ConversationMessageTask:
         message_was_created.send(
             self.message,
            conversation=self.conversation,
-            is_first_message=self.is_new_conversation
+            is_first_message=self.is_new_conversation,
+            auto_generate_name=self.auto_generate_name
        )
 
         if not by_stopped:
```
0
api/core/file/__init__.py
Normal file
79
api/core/file/file_obj.py
Normal file
@@ -0,0 +1,79 @@
import enum
from typing import Optional

from pydantic import BaseModel

from core.file.upload_file_parser import UploadFileParser
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
from extensions.ext_database import db
from models.model import UploadFile


class FileType(enum.Enum):
    IMAGE = 'image'

    @staticmethod
    def value_of(value):
        for member in FileType:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileTransferMethod(enum.Enum):
    REMOTE_URL = 'remote_url'
    LOCAL_FILE = 'local_file'

    @staticmethod
    def value_of(value):
        for member in FileTransferMethod:
            if member.value == value:
                return member
        raise ValueError(f"No matching enum found for value '{value}'")


class FileObj(BaseModel):
    id: Optional[str]
    tenant_id: str
    type: FileType
    transfer_method: FileTransferMethod
    url: Optional[str]
    upload_file_id: Optional[str]
    file_config: dict

    @property
    def data(self) -> Optional[str]:
        return self._get_data()

    @property
    def preview_url(self) -> Optional[str]:
        return self._get_data(force_url=True)

    @property
    def prompt_message_file(self) -> PromptMessageFile:
        if self.type == FileType.IMAGE:
            image_config = self.file_config.get('image')

            return ImagePromptMessageFile(
                data=self.data,
                detail=ImagePromptMessageFile.DETAIL.HIGH
                if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
            )

    def _get_data(self, force_url: bool = False) -> Optional[str]:
        if self.type == FileType.IMAGE:
            if self.transfer_method == FileTransferMethod.REMOTE_URL:
                return self.url
            elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
                upload_file = (db.session.query(UploadFile)
                               .filter(
                                   UploadFile.id == self.upload_file_id,
                                   UploadFile.tenant_id == self.tenant_id
                               ).first())

                return UploadFileParser.get_image_data(
                    upload_file=upload_file,
                    force_url=force_url
                )

        return None
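For orientation, here is a minimal sketch of how a FileObj resolves into a prompt payload. The tenant id, URL, and file_config values are invented for the example, and only the remote-URL path is exercised (the local-file path would instead hit the uploads table via UploadFileParser):

from core.file.file_obj import FileObj, FileType, FileTransferMethod

file = FileObj(
    id=None,
    tenant_id='example-tenant-id',  # hypothetical tenant
    type=FileType.IMAGE,
    transfer_method=FileTransferMethod.REMOTE_URL,
    url='https://example.com/cat.png',
    upload_file_id=None,
    file_config={'image': {'detail': 'high'}},  # assumed file_upload config shape
)

# For remote URLs, _get_data simply returns the URL; for local files it
# resolves through UploadFileParser.get_image_data (base64 or signed URL).
prompt_file = file.prompt_message_file
assert prompt_file.data == 'https://example.com/cat.png'
assert prompt_file.detail.value == 'high'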
180
api/core/file/message_file_parser.py
Normal file
@@ -0,0 +1,180 @@
from typing import List, Union, Optional, Dict

import requests

from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.file.upload_file_parser import SUPPORT_EXTENSIONS
from extensions.ext_database import db
from models.account import Account
from models.model import MessageFile, EndUser, AppModelConfig, UploadFile


class MessageFileParser:

    def __init__(self, tenant_id: str, app_id: str) -> None:
        self.tenant_id = tenant_id
        self.app_id = app_id

    def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
                                         user: Union[Account, EndUser]) -> List[FileObj]:
        """
        validate and transform files arg

        :param files:
        :param app_model_config:
        :param user:
        :return:
        """
        file_upload_config = app_model_config.file_upload_dict

        for file in files:
            if not isinstance(file, dict):
                raise ValueError('Invalid file format, must be dict')
            if not file.get('type'):
                raise ValueError('Missing file type')
            FileType.value_of(file.get('type'))
            if not file.get('transfer_method'):
                raise ValueError('Missing file transfer method')
            FileTransferMethod.value_of(file.get('transfer_method'))
            if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
                if not file.get('url'):
                    raise ValueError('Missing file url')
                if not file.get('url').startswith('http'):
                    raise ValueError('Invalid file url')
            if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
                raise ValueError('Missing file upload_file_id')

        # transform files to file objs
        type_file_objs = self._to_file_objs(files, file_upload_config)

        # validate files
        new_files = []
        for file_type, file_objs in type_file_objs.items():
            if file_type == FileType.IMAGE:
                # parse and validate files
                image_config = file_upload_config.get('image')

                # check if image file feature is enabled
                if not image_config['enabled']:
                    continue

                # Validate number of files
                if len(files) > image_config['number_limits']:
                    raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")

                for file_obj in file_objs:
                    # Validate transfer method
                    if file_obj.transfer_method.value not in image_config['transfer_methods']:
                        raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')

                    # Validate file type
                    if file_obj.type != FileType.IMAGE:
                        raise ValueError(f'Invalid file type: {file_obj.type}')

                    if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
                        # check remote url valid and is image
                        result, error = self._check_image_remote_url(file_obj.url)
                        if result is False:
                            raise ValueError(error)
                    elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
                        # get upload file from upload_file_id
                        upload_file = (db.session.query(UploadFile)
                                       .filter(
                                           UploadFile.id == file_obj.upload_file_id,
                                           UploadFile.tenant_id == self.tenant_id,
                                           UploadFile.created_by == user.id,
                                           UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
                                           UploadFile.extension.in_(SUPPORT_EXTENSIONS)
                                       ).first())

                        # check upload file is belong to tenant and user
                        if not upload_file:
                            raise ValueError('Invalid upload file')

                    new_files.append(file_obj)

        # return all file objs
        return new_files

    def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
        """
        transform message files

        :param files:
        :param app_model_config:
        :return:
        """
        # transform files to file objs
        type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)

        # return all file objs
        return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]

    def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
                      file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
        """
        transform files to file objs

        :param files:
        :param file_upload_config:
        :return:
        """
        type_file_objs: Dict[FileType, List[FileObj]] = {
            # Currently only support image
            FileType.IMAGE: []
        }

        if not files:
            return type_file_objs

        # group by file type and convert file args or message files to FileObj
        for file in files:
            file_obj = self._to_file_obj(file, file_upload_config)
            if file_obj.type not in type_file_objs:
                continue

            type_file_objs[file_obj.type].append(file_obj)

        return type_file_objs

    def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
        """
        transform file to file obj

        :param file:
        :return:
        """
        if isinstance(file, dict):
            transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
            return FileObj(
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.get('type')),
                transfer_method=transfer_method,
                url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
                upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
                file_config=file_upload_config
            )
        else:
            return FileObj(
                id=file.id,
                tenant_id=self.tenant_id,
                type=FileType.value_of(file.type),
                transfer_method=FileTransferMethod.value_of(file.transfer_method),
                url=file.url,
                upload_file_id=file.upload_file_id or None,
                file_config=file_upload_config
            )

    def _check_image_remote_url(self, url):
        try:
            headers = {
                "User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
            }

            response = requests.head(url, headers=headers, allow_redirects=True)
            if response.status_code == 200:
                return True, ""
            else:
                return False, "URL does not exist."
        except requests.RequestException as e:
            return False, f"Error checking URL: {e}"
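A rough usage sketch of the parser above, assuming an `app_model_config` and `current_user` are already in scope (as they would be in a controller); the ids and URL are invented, and the `files` payload mirrors the dict shape the validation loop expects:

from core.file.message_file_parser import MessageFileParser

parser = MessageFileParser(tenant_id='example-tenant-id', app_id='example-app-id')  # hypothetical ids

files_arg = [{
    'type': 'image',
    'transfer_method': 'remote_url',
    'url': 'https://example.com/cat.png',
}]

# Raises ValueError on a bad type or transfer method, a dead URL, a foreign
# upload_file_id, or when more images are sent than number_limits allows.
file_objs = parser.validate_and_transform_files_arg(files_arg, app_model_config, current_user)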
79
api/core/file/upload_file_parser.py
Normal file
@@ -0,0 +1,79 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional

from flask import current_app

from extensions.ext_storage import storage


SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']


class UploadFileParser:
    @classmethod
    def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
        if not upload_file:
            return None

        if upload_file.extension not in SUPPORT_EXTENSIONS:
            return None

        if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
            return cls.get_signed_temp_image_url(upload_file)
        else:
            # get image file base64
            try:
                data = storage.load(upload_file.key)
            except FileNotFoundError:
                logging.error(f'File not found: {upload_file.key}')
                return None

            encoded_string = base64.b64encode(data).decode('utf-8')
            return f'data:{upload_file.mime_type};base64,{encoded_string}'

    @classmethod
    def get_signed_temp_image_url(cls, upload_file) -> str:
        """
        get signed url from upload file

        :param upload_file: UploadFile object
        :return:
        """
        base_url = current_app.config.get('FILES_URL')
        image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'

        timestamp = str(int(time.time()))
        nonce = os.urandom(16).hex()
        data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        encoded_sign = base64.urlsafe_b64encode(sign).decode()

        return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"

    @classmethod
    def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
        """
        verify signature

        :param upload_file_id: file id
        :param timestamp: timestamp
        :param nonce: nonce
        :param sign: signature
        :return:
        """
        data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
        secret_key = current_app.config['SECRET_KEY'].encode()
        recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
        recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()

        # verify signature
        if sign != recalculated_encoded_sign:
            return False

        current_time = int(time.time())
        return current_time - int(timestamp) <= 300  # expired after 5 minutes
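The preview-URL scheme above is plain HMAC-SHA256 over a pipe-delimited payload, so the sign/verify pair can be checked end to end in isolation. A minimal sketch with a stand-in secret (the real code reads SECRET_KEY from Flask config); note that hmac.compare_digest, used below, would also harden the comparison against timing attacks where the code above uses a plain `!=`:

import base64
import hashlib
import hmac
import os
import time

SECRET_KEY = b'dummy-secret'  # stands in for current_app.config['SECRET_KEY']
upload_file_id = 'example-file-id'

timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
payload = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
sign = base64.urlsafe_b64encode(
    hmac.new(SECRET_KEY, payload.encode(), hashlib.sha256).digest()).decode()

# verification recomputes the digest from the URL parameters
recalculated = base64.urlsafe_b64encode(
    hmac.new(SECRET_KEY, payload.encode(), hashlib.sha256).digest()).decode()
assert hmac.compare_digest(sign, recalculated)
assert int(time.time()) - int(timestamp) <= 300  # same 5-minute window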
@@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT

class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
def generate_conversation_name(cls, tenant_id: str, query):
prompt = CONVERSATION_TITLE_PROMPT

if len(query) > 2000:
@@ -40,8 +40,12 @@ class LLMGenerator:

result_dict = json.loads(answer)
answer = result_dict['Your Output']
name = answer.strip()

return answer.strip()
if len(name) > 75:
name = name[:75] + '...'

return name

@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):

@@ -89,22 +89,6 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()

def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)

result = []
for match in matches:
q = match[0]
a = match[1]
if q and a:
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})

return result

def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
@@ -647,21 +631,16 @@ class IndexingRunner:
return text

def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"  # regex to match Q and A pairs
matches = re.findall(regex, text, re.MULTILINE)  # collect all matches
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)

result = []  # store the final results
for match in matches:
q = match[0]
a = match[1]
if q and a:
# if both Q and A exist, add them to the results
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})

return result
return [
{
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
}
for q, a in matches if q and a
]

def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""

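The duplicate method removal and the list-comprehension rewrite above are behavior-equivalent. On an invented two-pair input, the Q/A regex splits out pairs like this:

import re

text = "Q1: What is Dify? A1: An LLM app platform.\nQ2: Is it open source? A2: Yes."
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)

result = [
    {"question": q, "answer": re.sub(r"\n\s*", "\n", a.strip())}
    for q, a in matches if q and a
]
# result[0] -> {'question': 'What is Dify?', 'answer': 'An LLM app platform.'}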
@@ -3,6 +3,7 @@ from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage

from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
@property
def buffer(self) -> List[BaseMessage]:
"""String buffer of memory."""
app_model = self.conversation.app

# fetch limited messages desc, and return reversed
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
).order_by(Message.created_at.desc()).limit(self.message_limit).all()

messages = list(reversed(messages))
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)

chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
files = message.message_files
if files:
file_objs = message_file_parser.transform_message_files(
files, message.app_model_config
)

prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
chat_messages.append(PromptMessage(
content=message.query,
type=MessageType.USER,
files=prompt_message_files
))
else:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))

chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))

if not chat_messages:

@@ -1,4 +1,5 @@
import enum
from typing import Any, cast, Union, List, Dict

from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
SYSTEM = 'system'


class PromptMessageFileType(enum.Enum):
IMAGE = 'image'

@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")


class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any


class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'

type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW


class PromptMessage(BaseModel):
type: MessageType = MessageType.USER
content: str = ''
files: list[PromptMessageFile] = []
function_call: dict = None


class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]


def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
if not message.files:
lc_messages.append(HumanMessage(content=message.content))
else:
lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
if message.function_call:
@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
if isinstance(message, LCHumanMessageWithFiles):
prompt_messages.append(PromptMessage(
content=message.content,
type=MessageType.USER,
files=message.files
))
else:
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,

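Downstream, the new `files` field changes how user turns convert to and from LangChain messages. A small sketch with an invented image URL (the detail flag defaults to LOW):

from core.model_providers.models.entity.message import (
    PromptMessage, MessageType, ImagePromptMessageFile, to_lc_messages)

message = PromptMessage(
    type=MessageType.USER,
    content='What is in this picture?',
    files=[ImagePromptMessageFile(data='https://example.com/cat.png')],
)

# A user message carrying files becomes LCHumanMessageWithFiles, so the
# OpenAI client wrapper can expand it into a multi-part content array later;
# file-less user messages stay plain HumanMessage.
lc_messages = to_lc_messages([message])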
@@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum):

class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought'
VISION = 'vision'

@@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'


FUNCTION_CALL_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k'
]

class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
@@ -157,3 +164,7 @@ class AzureOpenAIModel(BaseLLM):
@property
def support_streaming(self):
return True

@property
def support_function_call(self):
return self.base_model_name in FUNCTION_CALL_MODELS

@@ -310,6 +310,10 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self):
return False

@property
def support_function_call(self):
return False

def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:

@@ -1,11 +1,9 @@
import decimal
import logging
from typing import List, Optional, Any

import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor

from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
@@ -23,21 +21,36 @@ COMPLETION_MODELS = [
]

CHAT_MODELS = [
'gpt-4-1106-preview',  # 128,000 tokens
'gpt-4-vision-preview',  # 128,000 tokens
'gpt-4',  # 8,192 tokens
'gpt-4-32k',  # 32,768 tokens
'gpt-3.5-turbo-1106',  # 16,384 tokens
'gpt-3.5-turbo',  # 4,096 tokens
'gpt-3.5-turbo-16k',  # 16,384 tokens
]

MODEL_MAX_TOKENS = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}

FUNCTION_CALL_MODELS = [
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k'
]


class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
@@ -50,7 +63,6 @@ class OpenAIModel(BaseLLM):
else:
self.model_mode = ModelMode.CHAT

# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)

def _init_client(self) -> Any:
@@ -100,7 +112,7 @@ class OpenAIModel(BaseLLM):
:param callbacks:
:return:
"""
if self.name == 'gpt-4' \
if self.name.startswith('gpt-4') \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
@@ -175,6 +187,10 @@ class OpenAIModel(BaseLLM):
def support_streaming(self):
return True

@property
def support_function_call(self):
return self.name in FUNCTION_CALL_MODELS

# def is_model_valid_or_raise(self):
# """
# check is a valid model.

@@ -41,9 +41,17 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-1106',
'name': 'gpt-3.5-turbo-1106',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'name': 'gpt-3.5-turbo-instruct',
'mode': ModelMode.COMPLETION.value,
},
{
@@ -62,6 +70,22 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-1106-preview',
'name': 'gpt-4-1106-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-vision-preview',
'name': 'gpt-4-vision-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.VISION.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
@@ -79,7 +103,7 @@ class OpenAIProvider(BaseModelProvider):

if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
models = [item for item in models if not item['id'].startswith('gpt-4')]

return models
elif model_type == ModelType.EMBEDDINGS:
@@ -141,8 +165,11 @@ class OpenAIProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,

@@ -24,12 +24,30 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-4-1106-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-vision-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-1106": {
"prompt": "0.0010",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": {
"prompt": "0.0015",
"completion": "0.002",

@@ -73,8 +73,7 @@ class OrchestratorRuleParser:
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))

# only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if not agent_model_instance.support_function_call:
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:

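The net effect of this change is that function-call planning now depends on the model itself rather than on the provider name. A rough sketch of the fallback under that assumption (PlanningStrategy is the enum used above; the model instance is any BaseLLM):

def choose_strategy(model_instance, requested):
    # gpt-3.5-turbo-instruct, for example, sits under the 'openai' provider
    # but is absent from FUNCTION_CALL_MODELS, so it now falls back to ReACT
    # where the old provider-name check would have let it through.
    if requested == PlanningStrategy.FUNCTION_CALL and not model_instance.support_function_call:
        return PlanningStrategy.REACT
    return requested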
@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage

from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
@@ -16,32 +16,57 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from models.model import AppModelConfig


class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'


class PromptTransform:
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
def get_prompt(self,
app_mode: str,
pre_prompt: str,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
return [PromptMessage(content=prompt)], stops

def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:

app_mode_enum = AppMode(app_mode)
model_mode_enum = model_instance.model_mode

prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))

if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
stops = None

prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
query, context, memory,
model_instance, files)
else:
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None

prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
memory,
model_instance, files)
return prompt_messages, stops

def get_advanced_prompt(self,
app_mode: str,
app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:

model_mode = app_model_config.model_dict['mode']

app_mode_enum = AppMode(app_mode)
@@ -51,15 +76,20 @@ class PromptTransform:

if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
files, context, memory,
model_instance)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
context, memory, model_instance)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
files, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)

prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
files, context)

return prompt_messages

def _get_history_messages_from_memory(self, memory: BaseChatMemory,
@@ -71,7 +101,7 @@ class PromptTransform:
return external_context[memory_key]

def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
@@ -79,7 +109,7 @@ class PromptTransform:
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])


def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
# baichuan
if isinstance(model_instance, BaichuanModel):
@@ -94,13 +124,13 @@ class PromptTransform:
return 'common_completion'
else:
return 'common_chat'


def _prompt_file_name_for_baichuan(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'


def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
@@ -111,12 +141,53 @@ class PromptTransform:
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)

def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> Tuple[str, Optional[list]]:

def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM,
files: List[PromptMessageFile]) -> List[PromptMessage]:
prompt_messages = []

context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
{'context': context}
)

pre_prompt_content = ''
if pre_prompt:
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
prompt_inputs
)

prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content

prompt = re.sub(r'<\|.*?\|>', '', prompt)

prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))

self._append_chat_histories(memory, prompt_messages, model_instance)

prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

return prompt_messages

def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM,
files: List[PromptMessageFile]) -> List[PromptMessage]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -175,16 +246,12 @@ class PromptTransform:

prompt = re.sub(r'<\|.*?\|>', '', prompt)

stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return [PromptMessage(content=prompt, files=files)]

return prompt, stops

def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''

@@ -195,17 +262,18 @@ class PromptTransform:
else:
prompt_inputs['#query#'] = ''

def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict,
model_instance: BaseLLM) -> None:
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
inputs={'#histories#': '', **prompt_inputs}
)

rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)


memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
@@ -213,7 +281,8 @@ class PromptTransform:
else:
prompt_inputs['#histories#'] = ''

def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
model_instance: BaseLLM) -> None:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)

@@ -242,19 +311,19 @@ class PromptTransform:
return prompt

def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:

app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:

raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']

prompt_messages = []
prompt = ''


prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -262,28 +331,29 @@ class PromptTransform:

self._set_query_variable(query, prompt_template, prompt_inputs)

self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
model_instance)

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))

return prompt_messages

def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -292,23 +362,23 @@ class PromptTransform:

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

self._append_chat_histories(memory, prompt_messages, model_instance)

prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))

return prompt_messages

def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']

prompt_messages = []
prompt = ''


prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}

@@ -316,21 +386,21 @@ class PromptTransform:

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))

return prompt_messages

def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']

prompt_messages = []

for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''

prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -339,6 +409,11 @@ class PromptTransform:

prompt = self._format_prompt(prompt_template, prompt_inputs)

prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))

return prompt_messages
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))

for prompt_message in prompt_messages[::-1]:
if prompt_message.type == MessageType.USER:
prompt_message.files = files
break

return prompt_messages

104
api/core/third_party/langchain/llms/chat_open_ai.py
vendored
@@ -1,10 +1,13 @@
import os

from typing import Dict, Any, Optional, Union, Tuple
from typing import Dict, Any, Optional, Union, Tuple, List, cast

from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from pydantic import root_validator

from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile


class EnhanceChatOpenAI(ChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}

def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params

def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.

Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']

value = text
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens

def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, LCHumanMessageWithFiles):
content = [
{
"type": "text",
"text": message.content
}
]

for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
content.append({
"type": "image_url",
"image_url": {
"url": file.data,
"detail": file.detail.value
}
})

message_dict = {"role": "user", "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict

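Concretely, a LCHumanMessageWithFiles carrying one image converts into the multi-part content array the OpenAI vision endpoint expects; the values below are invented. Note that get_num_tokens_from_messages above concatenates only the 'text' items of such arrays, which is why image tokens are not counted yet (the TODO in the diff):

message_dict = {
    "role": "user",
    "content": [
        {"type": "text", "text": "What is in this picture?"},
        {
            "type": "image_url",
            "image_url": {
                "url": "data:image/png;base64,iVBORw0KGgo...",  # or a signed preview URL
                "detail": "low",
            },
        },
    ],
}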
10
api/core/third_party/spark/spark_llm.py
vendored
@@ -17,8 +17,12 @@ import websocket

class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):

domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'

model_api_configs = {
'spark': {
@@ -38,7 +42,7 @@ class SparkLLMClient:
api_version = model_api_configs[model_name]['version']

self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/chat"
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,

@@ -1,5 +1,3 @@
import logging

from core.generator.llm_generator import LLMGenerator
from events.message_event import message_was_created
from extensions.ext_database import db
@@ -10,8 +8,9 @@ def handle(sender, **kwargs):
message = sender
conversation = kwargs.get('conversation')
is_first_message = kwargs.get('is_first_message')
auto_generate_name = kwargs.get('auto_generate_name', True)

if is_first_message:
if auto_generate_name and is_first_message:
if conversation.mode == 'chat':
app_model = conversation.app
if not app_model:
@@ -19,14 +18,9 @@ def handle(sender, **kwargs):

# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)

if len(name) > 75:
name = name[:75] + '...'

name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
conversation.name = name
except:
conversation.name = 'New conversation'
pass

db.session.add(conversation)
db.session.commit()

@@ -1,6 +1,7 @@
import os
import shutil
from contextlib import closing
from typing import Union, Generator

import boto3
from botocore.exceptions import ClientError
@@ -45,7 +46,13 @@ class Storage:
with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)

def load(self, filename):
def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
if stream:
return self.load_stream(filename)
else:
return self.load_once(filename)

def load_once(self, filename: str) -> bytes:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
@@ -69,6 +76,34 @@ class Storage:

return data

def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
for chunk in response['Body'].iter_chunks():
yield chunk
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename

if not os.path.exists(filename):
raise FileNotFoundError("File not found")

with open(filename, "rb") as f:
while chunk := f.read(4096):  # Read in chunks of 4KB
yield chunk

return generate()

def download(self, filename, target_filepath):
if self.storage_type == 's3':
with closing(self.client) as client:

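With the new stream flag, callers can forward large files chunk by chunk instead of buffering them in memory. A minimal sketch, assuming a Flask application context and an invented storage key:

from extensions.ext_storage import storage

# load(..., stream=True) returns a generator yielding chunks (S3 iter_chunks
# or 4 KB local reads); a missing key surfaces as FileNotFoundError.
with open('/tmp/example-copy.png', 'wb') as out:
    for chunk in storage.load('upload_files/example-key.png', stream=True):
        out.write(chunk)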
@@ -5,7 +5,13 @@ from libs.helper import TimestampField

class HiddenAPIKey(fields.Raw):
def output(self, key, obj):
return obj.api_key[:3] + '***' + obj.api_key[-3:]
api_key = obj.api_key
# If the length of the api_key is less than 8 characters, show the first and last characters
if len(api_key) <= 8:
return api_key[0] + '******' + api_key[-1]
# If the api_key is greater than 8 characters, show the first three and the last three characters
else:
return api_key[:3] + '******' + api_key[-3:]


api_based_extension_fields = {

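The rewritten mask keeps short keys from being echoed almost verbatim: under the old [:3] + '***' + [-3:] rule, a six-character key leaked all six characters. As a standalone restatement of the new behavior:

def mask(api_key: str) -> str:
    # 8 characters or fewer: reveal only the first and last character
    if len(api_key) <= 8:
        return api_key[0] + '******' + api_key[-1]
    # otherwise reveal the first three and last three
    return api_key[:3] + '******' + api_key[-3:]

assert mask('abc123XYZ789') == 'abc******789'
assert mask('short1') == 's******1'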
@@ -32,7 +32,8 @@ model_config_fields = {
|
||||
'prompt_type': fields.String,
|
||||
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
|
||||
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
|
||||
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
|
||||
'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
|
||||
'file_upload': fields.Raw(attribute='file_upload_dict'),
|
||||
}
|
||||
|
||||
app_detail_fields = {
|
||||
@@ -140,4 +141,4 @@ app_site_fields = {
|
||||
'privacy_policy': fields.String,
|
||||
'customize_token_strategy': fields.String,
|
||||
'prompt_public': fields.Boolean
|
||||
}
|
||||
}
|
||||
|
||||
@@ -28,6 +28,12 @@ annotation_fields = {
    'created_at': TimestampField
}

message_file_fields = {
    'id': fields.String,
    'type': fields.String,
    'url': fields.String,
}

message_detail_fields = {
    'id': fields.String,
    'conversation_id': fields.String,
@@ -43,7 +49,8 @@ message_detail_fields = {
    'from_account_id': fields.String,
    'feedbacks': fields.List(fields.Nested(feedback_fields)),
    'annotation': fields.Nested(annotation_fields, allow_null=True),
    'created_at': TimestampField
    'created_at': TimestampField,
    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}

feedback_stat_fields = {
@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
    'message': fields.Nested(message_detail_fields, attribute='first_message'),
}

simple_model_config_fields = {
    'model': fields.Raw(attribute='model_dict'),
    'pre_prompt': fields.String,
}

conversation_with_summary_fields = {
    'id': fields.String,
    'status': fields.String,
@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
    'limit': fields.Integer,
    'has_more': fields.Boolean,
    'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}
}
@@ -4,7 +4,8 @@ from libs.helper import TimestampField

upload_config_fields = {
    'file_size_limit': fields.Integer,
    'batch_count_limit': fields.Integer
    'batch_count_limit': fields.Integer,
    'image_file_size_limit': fields.Integer,
}

file_fields = {
@@ -1,6 +1,7 @@
from flask_restful import fields

from libs.helper import TimestampField
from fields.conversation_fields import message_file_fields

feedback_fields = {
    'rating': fields.String
@@ -31,6 +32,7 @@ message_fields = {
    'inputs': fields.Raw,
    'query': fields.String,
    'answer': fields.String,
    'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
    'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
    'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
    'created_at': TimestampField
59
api/migrations/versions/8fe468ba0ca5_add_gpt4v_supports.py
Normal file
@@ -0,0 +1,59 @@
"""add gpt4v supports

Revision ID: 8fe468ba0ca5
Revises: a9836e3baeee
Create Date: 2023-11-09 11:39:00.006432

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = '8fe468ba0ca5'
down_revision = 'a9836e3baeee'
branch_labels = None
depends_on = None


def upgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    op.create_table('message_files',
        sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
        sa.Column('message_id', postgresql.UUID(), nullable=False),
        sa.Column('type', sa.String(length=255), nullable=False),
        sa.Column('transfer_method', sa.String(length=255), nullable=False),
        sa.Column('url', sa.Text(), nullable=True),
        sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
        sa.Column('created_by_role', sa.String(length=255), nullable=False),
        sa.Column('created_by', postgresql.UUID(), nullable=False),
        sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
        sa.PrimaryKeyConstraint('id', name='message_file_pkey')
    )
    with op.batch_alter_table('message_files', schema=None) as batch_op:
        batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
        batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))

    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))

    # ### end Alembic commands ###


def downgrade():
    # ### commands auto generated by Alembic - please adjust! ###
    with op.batch_alter_table('upload_files', schema=None) as batch_op:
        batch_op.drop_column('created_by_role')

    with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
        batch_op.drop_column('file_upload')

    with op.batch_alter_table('message_files', schema=None) as batch_op:
        batch_op.drop_index('message_file_message_idx')
        batch_op.drop_index('message_file_created_by_idx')

    op.drop_table('message_files')
    # ### end Alembic commands ###
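Self-hosted images apply this migration automatically at startup via MIGRATION_ENABLED; when running the API from source, flask db upgrade does the same. Programmatically, the equivalent through Alembic's command API looks roughly like this (a sketch only; the config path is an assumption, since Dify drives Alembic through Flask-Migrate rather than a standalone alembic.ini):

from alembic import command
from alembic.config import Config

cfg = Config('alembic.ini')                        # assumed location
cfg.set_main_option('script_location', 'api/migrations')

command.upgrade(cfg, '8fe468ba0ca5')    # apply this revision
command.downgrade(cfg, 'a9836e3baeee')  # or roll back to its parent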
@@ -1,10 +1,10 @@
import json
from json import JSONDecodeError

from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID

from core.file.upload_file_parser import UploadFileParser
from libs.helper import generate_string
from extensions.ext_database import db
from .account import Account, Tenant
@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
    completion_prompt_config = db.Column(db.Text)
    dataset_configs = db.Column(db.Text)
    external_data_tools = db.Column(db.Text)
    file_upload = db.Column(db.Text)

    @property
    def app(self):
@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
    def dataset_configs_dict(self) -> dict:
        return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}

    @property
    def file_upload_dict(self) -> dict:
        return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
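The fallback default above, pretty-printed for readability:

{
    "image": {
        "enabled": False,
        "number_limits": 3,
        "detail": "high",
        "transfer_methods": ["remote_url", "local_file"]
    }
}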
    def to_dict(self) -> dict:
        return {
            "provider": "",
@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
            "prompt_type": self.prompt_type,
            "chat_prompt_config": self.chat_prompt_config_dict,
            "completion_prompt_config": self.completion_prompt_config_dict,
            "dataset_configs": self.dataset_configs_dict
            "dataset_configs": self.dataset_configs_dict,
            "file_upload": self.file_upload_dict
        }

    def from_model_config_dict(self, model_config: dict):
@@ -197,7 +203,8 @@ class AppModelConfig(db.Model):
        self.more_like_this = json.dumps(model_config['more_like_this'])
        self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
            if model_config.get('sensitive_word_avoidance') else None
        self.external_data_tools = json.dumps(model_config['external_data_tools'])
        self.external_data_tools = json.dumps(model_config['external_data_tools']) \
            if model_config.get('external_data_tools') else None
        self.model = json.dumps(model_config['model'])
        self.user_input_form = json.dumps(model_config['user_input_form'])
        self.dataset_query_variable = model_config.get('dataset_query_variable')
@@ -212,6 +219,8 @@ class AppModelConfig(db.Model):
            if model_config.get('completion_prompt_config') else None
        self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
            if model_config.get('dataset_configs') else None
        self.file_upload = json.dumps(model_config.get('file_upload')) \
            if model_config.get('file_upload') else None
        return self

    def copy(self):
@@ -237,7 +246,8 @@ class AppModelConfig(db.Model):
            prompt_type=self.prompt_type,
            chat_prompt_config=self.chat_prompt_config,
            completion_prompt_config=self.completion_prompt_config,
            dataset_configs=self.dataset_configs
            dataset_configs=self.dataset_configs,
            file_upload=self.file_upload
        )

        return new_app_model_config
@@ -511,6 +521,37 @@ class Message(db.Model):
        return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
            .order_by(DatasetRetrieverResource.position.asc()).all()

    @property
    def message_files(self):
        return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()

    @property
    def files(self):
        message_files = self.message_files

        files = []
        for message_file in message_files:
            url = message_file.url
            if message_file.type == 'image':
                if message_file.transfer_method == 'local_file':
                    upload_file = (db.session.query(UploadFile)
                                   .filter(
                        UploadFile.id == message_file.upload_file_id
                    ).first())

                    url = UploadFileParser.get_image_data(
                        upload_file=upload_file,
                        force_url=True
                    )

            files.append({
                'id': message_file.id,
                'type': message_file.type,
                'url': url
            })

        return files
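Rendered through message_file_fields, each entry of files then serializes to something like the following. The ID and the signed URL are illustrative; the exact preview route is derived from FILES_URL and is not part of this hunk:

{
    'id': '5c0e3a2e-...',  # hypothetical MessageFile UUID
    'type': 'image',
    'url': 'https://files.example.com/...?timestamp=...&nonce=...&sign=...'
}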
class MessageFeedback(db.Model):
    __tablename__ = 'message_feedbacks'
@@ -539,6 +580,25 @@ class MessageFeedback(db.Model):
        return account


class MessageFile(db.Model):
    __tablename__ = 'message_files'
    __table_args__ = (
        db.PrimaryKeyConstraint('id', name='message_file_pkey'),
        db.Index('message_file_message_idx', 'message_id'),
        db.Index('message_file_created_by_idx', 'created_by')
    )

    id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
    message_id = db.Column(UUID, nullable=False)
    type = db.Column(db.String(255), nullable=False)
    transfer_method = db.Column(db.String(255), nullable=False)
    url = db.Column(db.Text, nullable=True)
    upload_file_id = db.Column(UUID, nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))


class MessageAnnotation(db.Model):
    __tablename__ = 'message_annotations'
    __table_args__ = (
@@ -682,6 +742,7 @@ class UploadFile(db.Model):
    size = db.Column(db.Integer, nullable=False)
    extension = db.Column(db.String(255), nullable=False)
    mime_type = db.Column(db.String(255), nullable=True)
    created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
    used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -782,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
    retriever_from = db.Column(db.Text, nullable=False)
    created_by = db.Column(UUID, nullable=False)
    created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())
@@ -315,6 +315,9 @@ class AppModelConfigService:
        # moderation validation
        cls.is_moderation_valid(tenant_id, config)

        # file upload validation
        cls.is_file_upload_valid(config)

        # Filter out extra parameters
        filtered_config = {
            "opening_statement": config["opening_statement"],
@@ -338,7 +341,8 @@ class AppModelConfigService:
            "prompt_type": config["prompt_type"],
            "chat_prompt_config": config["chat_prompt_config"],
            "completion_prompt_config": config["completion_prompt_config"],
            "dataset_configs": config["dataset_configs"]
            "dataset_configs": config["dataset_configs"],
            "file_upload": config["file_upload"]
        }

        return filtered_config
@@ -371,6 +375,34 @@ class AppModelConfigService:
            config=config
        )

    @classmethod
    def is_file_upload_valid(cls, config: dict):
        if 'file_upload' not in config or not config["file_upload"]:
            config["file_upload"] = {}

        if not isinstance(config["file_upload"], dict):
            raise ValueError("file_upload must be of dict type")

        # check image config
        if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
            config["file_upload"]["image"] = {"enabled": False}

        if config['file_upload']['image']['enabled']:
            number_limits = config['file_upload']['image']['number_limits']
            if number_limits < 1 or number_limits > 6:
                raise ValueError("number_limits must be in [1, 6]")

            detail = config['file_upload']['image']['detail']
            if detail not in ['high', 'low']:
                raise ValueError("detail must be in ['high', 'low']")

            transfer_methods = config['file_upload']['image']['transfer_methods']
            if not isinstance(transfer_methods, list):
                raise ValueError("transfer_methods must be of list type")
            for method in transfer_methods:
                if method not in ['remote_url', 'local_file']:
                    raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")

    @classmethod
    def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
        if 'external_data_tools' not in config or not config["external_data_tools"]:
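A file_upload fragment that passes this validator, with the constraints annotated:

config['file_upload'] = {
    'image': {
        'enabled': True,
        'number_limits': 3,                  # must fall within [1, 6]
        'detail': 'low',                     # 'high' or 'low'
        'transfer_methods': ['local_file'],  # subset of ['remote_url', 'local_file']
    }
}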
@@ -3,7 +3,7 @@ import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional
from typing import Generator, Union, Any, Optional, List

from flask import current_app, Flask
from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
    ConversationTaskInterruptException
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
    LLMRateLimitError, \
    LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
        # is streaming mode
        inputs = args['inputs']
        query = args['query']
        files = args['files'] if 'files' in args and args['files'] else []
        auto_generate_name = args['auto_generate_name'] \
            if 'auto_generate_name' in args else True

        if app_model.mode != 'completion' and not query:
            raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
        # clean input by app_model_config form rules
        inputs = cls.get_cleaned_inputs(inputs, app_model_config)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.validate_and_transform_files_arg(
            files,
            app_model_config,
            user
        )
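The files argument arrives from the client as a list of plain dicts; the web UI (see the chat component changes below) builds entries of this shape. The URL value here is illustrative:

args['files'] = [
    {
        'type': 'image',
        'transfer_method': 'remote_url',       # or 'local_file'
        'url': 'https://example.com/cat.png',  # set for remote_url entries
        'upload_file_id': None,                # set for local_file uploads
    },
]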
        generate_task_id = str(uuid.uuid4())

        pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
            'app_model_config': app_model_config.copy(),
            'query': query,
            'inputs': inputs,
            'files': file_objs,
            'detached_user': user,
            'detached_conversation': conversation,
            'streaming': streaming,
            'is_model_config_override': is_model_config_override,
            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
            'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
            'auto_generate_name': auto_generate_name
        })

        generate_worker_thread.start()

        # wait for 10 minutes to close the thread
        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
        cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
                                generate_task_id)

        return cls.compact_response(pubsub, streaming)

@@ -172,10 +188,12 @@ class CompletionService:
        return user

    @classmethod
    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
                        query: str, inputs: dict, detached_user: Union[Account, EndUser],
    def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
                        app_model_config: AppModelConfig,
                        query: str, inputs: dict, files: List[PromptMessageFile],
                        detached_user: Union[Account, EndUser],
                        detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
                        retriever_from: str = 'dev'):
                        retriever_from: str = 'dev', auto_generate_name: bool = True):
        with flask_app.app_context():
            # fixed the state of the model object when it detached from the original session
            user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
                query=query,
                inputs=inputs,
                user=user,
                files=files,
                conversation=conversation,
                streaming=streaming,
                is_override=is_model_config_override,
                retriever_from=retriever_from
                retriever_from=retriever_from,
                auto_generate_name=auto_generate_name
            )
        except (ConversationTaskInterruptException, ConversationTaskStoppedException):
            pass
@@ -215,7 +235,8 @@ class CompletionService:
            db.session.commit()

    @classmethod
    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
    def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
                            generate_task_id) -> threading.Thread:
        # wait for 10 minutes to close the thread
        timeout = 600

@@ -274,6 +295,12 @@ class CompletionService:
        model_dict['completion_params'] = completion_params
        app_model_config.model = json.dumps(model_dict)

        # parse files
        message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
        file_objs = message_file_parser.transform_message_files(
            message.files, app_model_config
        )

        generate_task_id = str(uuid.uuid4())

        pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
            'app_model_config': app_model_config.copy(),
            'query': message.query,
            'inputs': message.inputs,
            'files': file_objs,
            'detached_user': user,
            'detached_conversation': None,
            'streaming': streaming,
            'is_model_config_override': True,
            'retriever_from': retriever_from
            'retriever_from': retriever_from,
            'auto_generate_name': False
        })

        generate_worker_thread.start()
@@ -388,7 +417,8 @@ class CompletionService:
        if event == 'message':
            yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
        elif event == 'message_replace':
            yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
            yield "data: " + json.dumps(
                cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
        elif event == 'chain':
            yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
        elif event == 'agent_thought':
@@ -1,17 +1,20 @@
from typing import Union, Optional

from core.generator.llm_generator import LLMGenerator
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, App, EndUser
from models.model import Conversation, App, EndUser, Message
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.errors.message import MessageNotExistsError


class ConversationService:
    @classmethod
    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
                              last_id: Optional[str], limit: int,
                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
                              include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
        if not user:
            return InfiniteScrollPagination(data=[], limit=limit, has_more=False)

@@ -29,6 +32,9 @@ class ConversationService:
        if exclude_ids is not None:
            base_query = base_query.filter(~Conversation.id.in_(exclude_ids))

        if exclude_debug_conversation:
            base_query = base_query.filter(Conversation.override_model_configs == None)

        if last_id:
            last_conversation = base_query.filter(
                Conversation.id == last_id,
@@ -63,10 +69,36 @@ class ConversationService:

    @classmethod
    def rename(cls, app_model: App, conversation_id: str,
               user: Optional[Union[Account | EndUser]], name: str):
               user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
        conversation = cls.get_conversation(app_model, conversation_id, user)

        conversation.name = name
        if auto_generate:
            return cls.auto_generate_name(app_model, conversation)
        else:
            conversation.name = name
            db.session.commit()

            return conversation

    @classmethod
    def auto_generate_name(cls, app_model: App, conversation: Conversation):
        # get conversation first message
        message = db.session.query(Message) \
            .filter(
                Message.app_id == app_model.id,
                Message.conversation_id == conversation.id
            ).order_by(Message.created_at.asc()).first()

        if not message:
            raise MessageNotExistsError()

        # generate conversation name
        try:
            name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
            conversation.name = name
        except:
            pass

        db.session.commit()

        return conversation
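A short usage sketch of the reworked rename flow (caller variable names are assumed):

# auto_generate=True ignores the supplied name and lets the LLM pick one
# from the conversation's first message
conversation = ConversationService.rename(
    app_model=app_model,
    conversation_id=conversation_id,
    user=current_user,
    name='',              # unused when auto_generate is True
    auto_generate=True,
)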
@@ -1,46 +1,62 @@
import datetime
import hashlib
import time
import uuid
from typing import Generator, Tuple, Union

from cachetools import TTLCache
from flask import request, current_app
from flask import current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound

from core.data_loader.file_extractor import FileExtractor
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_storage import storage
from extensions.ext_database import db
from models.model import UploadFile
from models.account import Account
from models.model import UploadFile, EndUser
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError

ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
                      'jpg', 'jpeg', 'png', 'webp', 'gif']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
PREVIEW_WORDS_LIMIT = 3000
cache = TTLCache(maxsize=None, ttl=30)


class FileService:

    @staticmethod
    def upload_file(file: FileStorage) -> UploadFile:
    def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
        extension = file.filename.split('.')[-1]
        if extension.lower() not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()
        elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # read file content
        file_content = file.read()

        # get file size
        file_size = len(file_content)

        file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
        if extension.lower() in IMAGE_EXTENSIONS:
            file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
        else:
            file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024

        if file_size > file_size_limit:
            message = f'File size exceeded. {file_size} > {file_size_limit}'
            raise FileTooLargeError(message)

        extension = file.filename.split('.')[-1]
        if extension.lower() not in ALLOWED_EXTENSIONS:
            raise UnsupportedFileTypeError()

        # user uuid as file name
        file_uuid = str(uuid.uuid4())
        file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension

        if isinstance(user, Account):
            current_tenant_id = user.current_tenant_id
        else:
            # end_user
            current_tenant_id = user.tenant_id

        file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension

        # save file to storage
        storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
        # save file to db
        config = current_app.config
        upload_file = UploadFile(
            tenant_id=current_user.current_tenant_id,
            tenant_id=current_tenant_id,
            storage_type=config['STORAGE_TYPE'],
            key=file_key,
            name=file.filename,
            size=file_size,
            extension=extension,
            mime_type=file.mimetype,
            created_by=current_user.id,
            created_by_role=('account' if isinstance(user, Account) else 'end_user'),
            created_by=user.id,
            created_at=datetime.datetime.utcnow(),
            used=False,
            hash=hashlib.sha3_256(file_content).hexdigest()
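The branch above selects the limit before converting megabytes to bytes; assuming the illustrative values UPLOAD_FILE_SIZE_LIMIT=15 and UPLOAD_IMAGE_FILE_SIZE_LIMIT=10:

15 * 1024 * 1024  # = 15728640 bytes allowed for e.g. a .pdf
10 * 1024 * 1024  # = 10485760 bytes allowed for e.g. a .png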
@@ -99,12 +116,6 @@

    @staticmethod
    def get_file_preview(file_id: str) -> str:
        # get file storage key
        key = file_id + request.path
        cached_response = cache.get(key)
        if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
            return cached_response['response']

        upload_file = db.session.query(UploadFile) \
            .filter(UploadFile.id == file_id) \
            .first()
@@ -121,3 +132,25 @@
        text = text[0:PREVIEW_WORDS_LIMIT] if text else ''

        return text
    @staticmethod
    def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
        result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
        if not result:
            raise NotFound("File not found or signature is invalid")

        upload_file = db.session.query(UploadFile) \
            .filter(UploadFile.id == file_id) \
            .first()

        if not upload_file:
            raise NotFound("File not found or signature is invalid")

        # only image files may be previewed through this endpoint
        extension = upload_file.extension
        if extension.lower() not in IMAGE_EXTENSIONS:
            raise UnsupportedFileTypeError()

        generator = storage.load(upload_file.key, stream=True)

        return generator, upload_file.mime_type
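The signature scheme itself lives in UploadFileParser and is not shown in this diff; conceptually the server recomputes a keyed digest over the file id, timestamp and nonce and compares it to sign, rejecting stale timestamps. A hypothetical sketch under those assumptions (message layout, expiry window and encoding are all guesses, not the real implementation):

import base64
import hashlib
import hmac
import time


def verify_image_file_signature(secret: bytes, file_id: str,
                                timestamp: str, nonce: str, sign: str) -> bool:
    msg = f'image-preview|{file_id}|{timestamp}|{nonce}'.encode()  # assumed layout
    expected = base64.urlsafe_b64encode(
        hmac.new(secret, msg, hashlib.sha256).digest()).decode()
    not_expired = (time.time() - int(timestamp)) <= 300           # assumed 5-minute window
    return not_expired and hmac.compare_digest(expected, sign)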
@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
class WebConversationService:
    @classmethod
    def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
                              last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
                              exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
        include_ids = None
        exclude_ids = None
        if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
            last_id=last_id,
            limit=limit,
            include_ids=include_ids,
            exclude_ids=exclude_ids
            exclude_ids=exclude_ids,
            exclude_debug_conversation=exclude_debug_conversation
        )

    @classmethod
@@ -5,7 +5,7 @@ from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage

from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
    assert rst == 22


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_chat_get_num_tokens(mock_decrypt):
    openai_model = get_mock_openai_model('gpt-4-vision-preview')
    messages = [
        PromptMessage(content='What’s in first image?', files=[
            ImageMessageFile(
                data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
        ])
    ]
    rst = openai_model.get_num_tokens(messages)
    assert rst == 77


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
        messages,
        stop=['\nHuman:'],
    )
    assert (len(rst.content) > 0)


@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_run(mock_decrypt, mocker):
    mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)

    openai_model = get_mock_openai_model('gpt-4-vision-preview')
    messages = [
        PromptMessage(content='What’s in first image?', files=[
            ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
        ])
    ]
    rst = openai_model.run(
        messages,
    )
    assert len(rst.content) > 0
@@ -2,7 +2,7 @@ version: '3.1'
services:
  # API service
  api:
    image: langgenius/dify-api:0.3.29
    image: langgenius/dify-api:0.3.30
    restart: always
    environment:
      # Startup mode, 'api' starts the API server.
@@ -19,18 +19,22 @@ services:
      # different from api or web app domain.
      # example: http://cloud.dify.ai
      CONSOLE_API_URL: ''
      # The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
      # The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
      # different from console domain.
      # example: http://api.dify.ai
      SERVICE_API_URL: ''
      # The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
      # The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
      # console or api domain.
      # example: http://udify.app
      APP_API_URL: ''
      # The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
      # The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
      # console or api domain.
      # example: http://udify.app
      APP_WEB_URL: ''
      # File preview or download URL prefix.
      # Used to expose file preview or download URLs to the front-end, or as multi-modal model inputs;
      # the URL is signed and carries an expiration time.
      FILES_URL: ''
      # When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
      MIGRATION_ENABLED: 'true'
      # The configurations of postgres database connection.
@@ -124,7 +128,7 @@ services:
  # worker service
  # The Celery worker for processing the queue.
  worker:
    image: langgenius/dify-api:0.3.29
    image: langgenius/dify-api:0.3.30
    restart: always
    environment:
      # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -192,7 +196,7 @@ services:

  # Frontend web application.
  web:
    image: langgenius/dify-web:0.3.29
    image: langgenius/dify-web:0.3.30
    restart: always
    environment:
      EDITION: SELF_HOSTED
@@ -17,6 +17,11 @@ server {
        include proxy.conf;
    }

    location /files {
        proxy_pass http://api:5001;
        include proxy.conf;
    }

    location / {
        proxy_pass http://web:3000;
        include proxy.conf;
10
third-party/chrome plug-in/README_CN.md
vendored
@@ -20,16 +20,14 @@
- options.js: the plugin's configuration JS script

### After the plugin is imported, the remaining configuration is unchanged
- Initialize the Dify application settings by entering the Dify root domain and the application Token; the Token can be obtained from the Dify application's Embed page, as shown:
- Create the Dify application configuration: in the application Overview, click Embed, switch to the "Install Chrome browser extension" view, and click the copy button to get the ChatBot URL, as shown:



- Click Save; once the success prompt appears, the configuration is done



- To be safe, restart the browser to make sure every tab is refreshed
- Once configured, Chrome loads the Dify chatbot floating bar on any page; to switch to a different bot later, just change the Token
- Once configured, Chrome loads the Dify chatbot floating bar on any page; to switch to a different bot later, just change the ChatBot URL


6
third-party/chrome plug-in/README_CN.txt
vendored
Normal file
@@ -0,0 +1,6 @@
## Chrome Dify ChatBot plugin

1. Initialize the Dify application settings by entering the Dify root domain and the application Token; the Token can be obtained from the Dify application's Embed page;
2. Click Save; once the success prompt appears, the configuration is done;
3. To be safe, restart the browser to make sure every tab is refreshed;
4. Once configured, Chrome loads the Dify chatbot floating bar on any page; to switch to a different bot later, just change the Token;
14
third-party/chrome plug-in/content.js
vendored
@@ -1,8 +1,7 @@
var storage = chrome.storage.sync;
chrome.storage.sync.get(['baseUrl', 'token'], function(result) {
const storage = chrome.storage.sync;
chrome.storage.sync.get(['chatbotUrl'], function(result) {
  window.difyChatbotConfig = {
    baseUrl: result.baseUrl,
    token: result.token
    chatbotUrl: result.chatbotUrl,
  };
});

@@ -10,11 +9,10 @@ document.body.onload = embedChatbot;

async function embedChatbot() {
  const difyChatbotConfig = window.difyChatbotConfig;
  if (!difyChatbotConfig || !difyChatbotConfig.token) {
    console.warn('difyChatbotConfig is empty or token is not provided');
  if (!difyChatbotConfig) {
    console.warn('Dify Chatbot Url is empty or is not provided');
    return;
  }
  const baseUrl = difyChatbotConfig.baseUrl
  const openIcon = `<svg
    id="openIcon"
    width="24"
@@ -53,7 +51,7 @@ async function embedChatbot() {
  iframe.allow = "fullscreen;microphone"
  iframe.title = "dify chatbot bubble window"
  iframe.id = 'dify-chatbot-bubble-window'
  iframe.src = `${baseUrl}/chat/${difyChatbotConfig.token}`
  iframe.src = difyChatbotConfig.chatbotUrl
  iframe.style.cssText = 'border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 6.7rem; right: 1rem; width: 30rem; height: 48rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset; background-color: #F3F4F6;'
  document.body.appendChild(iframe);
}
BIN
third-party/chrome plug-in/favicon.png
vendored
Normal file
After Width: | Height: | Size: 2.8 KiB
BIN
third-party/chrome plug-in/images/favicon.ico
vendored
Normal file
After Width: | Height: | Size: 15 KiB
BIN
third-party/chrome plug-in/images/img-2.png
vendored
Before Width: | Height: | Size: 73 KiB  After Width: | Height: | Size: 124 KiB
BIN
third-party/chrome plug-in/images/img-3.png
vendored
Before Width: | Height: | Size: 96 KiB  After Width: | Height: | Size: 85 KiB
BIN
third-party/chrome plug-in/images/img-4.png
vendored
Before Width: | Height: | Size: 55 KiB  After Width: | Height: | Size: 85 KiB
BIN
third-party/chrome plug-in/images/img-5.png
vendored
Before Width: | Height: | Size: 85 KiB
4
third-party/chrome plug-in/manifest.json
vendored
@@ -1,7 +1,7 @@
{
  "manifest_version": 3,
  "name": "Dify Chatbot",
  "version": "1.3",
  "version": "1.5",
  "description": "This is a chrome extension to inject a dify chatbot on any pages",
  "content_scripts": [
    {
@@ -17,7 +17,7 @@
      "32": "images/32.png",
      "48": "images/48.png",
      "128": "images/128.png"

    }
  },
  "icons": {
19
third-party/chrome plug-in/options.css
vendored
Normal file
@@ -0,0 +1,19 @@
body {
  background-color: #f2f2f2;
  font-family: Arial, sans-serif;
}

h2 {
  color: #333;
}

label {
  display: block;
  margin-top: 10px;
  margin-bottom: 10px;
}

input[type="text"] {
  width: 280px;
  padding: 6px;
}
21
third-party/chrome plug-in/options.html
vendored
@@ -4,32 +4,21 @@
<head>
  <title>Dify Chatbot Extension</title>
  <meta name="viewport" content="width=device-width, initial-scale=1.0">
  <link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
  <link href="./tailwind.css" rel="stylesheet">
</head>

<body class="bg-gray-100 py-8 px-4 w-96">
<body class="bg-gray-100 py-4 px-4 w-128">
  <div class="max-w-md mx-auto bg-white shadow-md rounded-lg p-4">
    <h2 class="text-2xl font-semibold mb-4">Dify Chatbot Extension</h2>
    <form>
      <div class="mb-4 flex items-center">
        <div class="w-1/4">
          <label for="base-url" class="block font-semibold text-gray-700">Base URL</label>
          <label for="chatbot-url" class="block font-semibold text-gray-700">ChatBot URL</label>
        </div>
        <div class="w-3/4">
          <input type="text" id="base-url" name="base-url" value=""
          <input type="text" id="chatbot-url" name="base-url" value=""
            class="w-full border border-gray-300 rounded px-3 py-2 focus:outline-none focus:border-blue-400"
            placeholder="https://udify.app">
        </div>
      </div>

      <div class="mb-4 flex items-center">
        <div class="w-1/4">
          <label for="token" class="block font-semibold text-gray-700">Token</label>
        </div>
        <div class="w-3/4">
          <input type="text" id="token" name="token" value=""
            class="w-full border border-gray-300 rounded px-3 py-2 focus:outline-none focus:border-blue-400"
            placeholder="Application Embedded Token">
            placeholder="https://udify.app/chatbot/7CQBa5yyvYLSkZtx">
        </div>
      </div>

29
third-party/chrome plug-in/options.js
vendored
@@ -1,39 +1,28 @@

document.getElementById('save-button').addEventListener('click', function (e) {
  e.preventDefault();
  var baseUrl = document.getElementById('base-url').value;
  var token = document.getElementById('token').value;
  var errorTip = document.getElementById('error-tip');
  const chatbotUrl = document.getElementById('chatbot-url').value;
  const errorTip = document.getElementById('error-tip');

  if (baseUrl.trim() === "" || token.trim() === "") {
    if (baseUrl.trim() === "") {
      errorTip.textContent = "Base URL cannot be empty.";
    } else {
      errorTip.textContent = "Token cannot be empty.";
    }
  if (chatbotUrl.trim() === "") {
    errorTip.textContent = "Dify ChatBot URL cannot be empty.";
  } else {
    errorTip.textContent = "";

    chrome.storage.sync.set({
      'baseUrl': baseUrl,
      'token': token
      'chatbotUrl': chatbotUrl,
    }, function () {
      alert('Save Success!');
    });
  }

});

// Load parameters from chrome.storage when the page loads
chrome.storage.sync.get(['baseUrl', 'token'], function (result) {
  const baseUrlInput = document.getElementById('base-url');
  const tokenInput = document.getElementById('token');
chrome.storage.sync.get(['chatbotUrl'], function (result) {
  const chatbotUrlInput = document.getElementById('chatbot-url');

  if (result.baseUrl) {
    baseUrlInput.value = result.baseUrl;
  if (result.chatbotUrl) {
    chatbotUrlInput.value = result.chatbotUrl;
  }

  if (result.token) {
    tokenInput.value = result.token;
  }
});
176015
third-party/chrome plug-in/tailwind.css
vendored
Normal file
@@ -1,6 +1,7 @@
'use client'
import type { FC, ReactNode } from 'react'
import React, { useEffect, useLayoutEffect, useRef, useState } from 'react'
import Textarea from 'rc-textarea'
import { useContext } from 'use-context-selector'
import cn from 'classnames'
import Recorder from 'js-audio-recorder'
@@ -10,9 +11,8 @@ import type { DisplayScene, FeedbackFunc, IChatItem, SubmitAnnotationFunc } from
import { TryToAskIcon, stopIcon } from './icon-component'
import Answer from './answer'
import Question from './question'
import Tooltip from '@/app/components/base/tooltip'
import TooltipPlus from '@/app/components/base/tooltip-plus'
import { ToastContext } from '@/app/components/base/toast'
import AutoHeightTextarea from '@/app/components/base/auto-height-textarea'
import Button from '@/app/components/base/button'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import VoiceInput from '@/app/components/base/voice-input'
@@ -20,6 +20,10 @@ import { Microphone01 } from '@/app/components/base/icons/src/vender/line/mediaA
import { Microphone01 as Microphone01Solid } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
import type { DataSet } from '@/models/datasets'
import ChatImageUploader from '@/app/components/base/image-uploader/chat-image-uploader'
import ImageList from '@/app/components/base/image-uploader/image-list'
import { TransferMethod, type VisionFile, type VisionSettings } from '@/types/app'
import { useImageFiles } from '@/app/components/base/image-uploader/hooks'

export type IChatProps = {
  configElem?: React.ReactNode
@@ -37,7 +41,7 @@ export type IChatProps = {
  onFeedback?: FeedbackFunc
  onSubmitAnnotation?: SubmitAnnotationFunc
  checkCanSend?: () => boolean
  onSend?: (message: string) => void
  onSend?: (message: string, files: VisionFile[]) => void
  displayScene?: DisplayScene
  useCurrentUserAvatar?: boolean
  isResponsing?: boolean
@@ -54,6 +58,7 @@ export type IChatProps = {
  dataSets?: DataSet[]
  isShowCitationHitInfo?: boolean
  isShowPromptLog?: boolean
  visionConfig?: VisionSettings
}

const Chat: FC<IChatProps> = ({
@@ -83,9 +88,19 @@ const Chat: FC<IChatProps> = ({
  dataSets,
  isShowCitationHitInfo,
  isShowPromptLog,
  visionConfig,
}) => {
  const { t } = useTranslation()
  const { notify } = useContext(ToastContext)
  const {
    files,
    onUpload,
    onRemove,
    onReUpload,
    onImageLinkLoadError,
    onImageLinkLoadSuccess,
    onClear,
  } = useImageFiles()
  const isUseInputMethod = useRef(false)

  const [query, setQuery] = React.useState('')
@@ -114,9 +129,18 @@ const Chat: FC<IChatProps> = ({
  const handleSend = () => {
    if (!valid() || (checkCanSend && !checkCanSend()))
      return
    onSend(query)
    if (!isResponsing)
      setQuery('')
    onSend(query, files.filter(file => file.progress !== -1).map(fileItem => ({
      type: 'image',
      transfer_method: fileItem.type,
      url: fileItem.url,
      upload_file_id: fileItem.fileId,
    })))
    if (!files.find(item => item.type === TransferMethod.local_file && !item.fileId)) {
      if (files.length)
        onClear()
      if (!isResponsing)
        setQuery('')
    }
  }

  const handleKeyUp = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@@ -198,6 +222,8 @@ const Chat: FC<IChatProps> = ({
                item={item}
                isShowPromptLog={isShowPromptLog}
                isResponsing={isResponsing}
                // ['https://placekitten.com/360/360', 'https://placekitten.com/360/640']
                imgSrcs={(item.message_files && item.message_files?.length > 0) ? item.message_files.map(item => item.url) : []}
              />
            )
          })}
@@ -246,18 +272,42 @@ const Chat: FC<IChatProps> = ({
          </div>)
        }
        <div className="relative">
          <AutoHeightTextarea
          <div className='p-[5.5px] max-h-[150px] bg-white border-[1.5px] border-gray-200 rounded-xl overflow-y-auto'>
            {
              visionConfig?.enabled && (
                <>
                  <div className='absolute bottom-2 left-2 flex items-center'>
                    <ChatImageUploader
                      settings={visionConfig}
                      onUpload={onUpload}
                      disabled={files.length >= visionConfig.number_limits}
                    />
                    <div className='mx-1 w-[1px] h-4 bg-black/5' />
                  </div>
                  <div className='pl-[52px]'>
                    <ImageList
                      list={files}
                      onRemove={onRemove}
                      onReUpload={onReUpload}
                      onImageLinkLoadSuccess={onImageLinkLoadSuccess}
                      onImageLinkLoadError={onImageLinkLoadError}
                    />
                  </div>
                </>
              )
            }
            <Textarea
              className={`
                block w-full px-2 pr-[118px] py-[7px] leading-5 max-h-none text-sm text-gray-700 outline-none appearance-none resize-none
                ${visionConfig?.enabled && 'pl-12'}
              `}
              value={query}
              onChange={handleContentChange}
              onKeyUp={handleKeyUp}
              onKeyDown={handleKeyDown}
              minHeight={48}
              autoFocus
              controlFocus={controlFocus}
              className={`${cn(s.textArea)} resize-none block w-full pl-3 bg-gray-50 border border-gray-200 rounded-md focus:outline-none sm:text-sm text-gray-700`}
              autoSize
            />
          <div className="absolute top-0 right-2 flex items-center h-[48px]">
          <div className="absolute bottom-2 right-2 flex items-center h-8">
            <div className={`${s.count} mr-4 h-5 leading-5 text-sm bg-gray-50 text-gray-500`}>{query.trim().length}</div>
            {
              query
@@ -282,9 +332,8 @@ const Chat: FC<IChatProps> = ({
              {isMobile
                ? sendBtn
                : (
                  <Tooltip
                    selector='send-tip'
                    htmlContent={
                  <TooltipPlus
                    popupContent={
                      <div>
                        <div>{t('common.operation.send')} Enter</div>
                        <div>{t('common.operation.lineBreak')} Shift Enter</div>
@@ -292,7 +341,7 @@ const Chat: FC<IChatProps> = ({
                    }
                  >
                    {sendBtn}
                  </Tooltip>
                  </TooltipPlus>
                )}
          </div>
          {
@@ -8,14 +8,16 @@ import Log from '../log'
import MoreInfo from '../more-info'
import AppContext from '@/context/app-context'
import { Markdown } from '@/app/components/base/markdown'
import ImageGallery from '@/app/components/base/image-gallery'

type IQuestionProps = Pick<IChatItem, 'id' | 'content' | 'more' | 'useCurrentUserAvatar'> & {
  isShowPromptLog?: boolean
  item: IChatItem
  imgSrcs?: string[]
  isResponsing?: boolean
}

const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar, isShowPromptLog, item, isResponsing }) => {
const Question: FC<IQuestionProps> = ({ id, content, imgSrcs, more, useCurrentUserAvatar, isShowPromptLog, item, isResponsing }) => {
  const { userProfile } = useContext(AppContext)
  const userName = userProfile?.name
  const ref = useRef(null)
@@ -23,6 +25,7 @@ const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar,
  return (
    <div className={`flex items-start justify-end ${isShowPromptLog && 'first-of-type:pt-[14px]'}`} key={id} ref={ref}>
      <div className={s.questionWrapWrap}>

        <div className={`${s.question} group relative text-sm text-gray-900`}>
          {
            isShowPromptLog && !isResponsing && (
@@ -32,6 +35,9 @@ const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar,
          <div
            className={'mr-2 py-3 px-4 bg-blue-500 rounded-tl-2xl rounded-b-2xl'}
          >
            {imgSrcs && imgSrcs.length > 0 && (
              <ImageGallery srcs={imgSrcs} />
            )}
            <Markdown content={content} />
          </div>
        </div>
@@ -1,5 +1,5 @@
import type { Annotation, MessageRating } from '@/models/log'

import type { VisionFile } from '@/types/app'
export type MessageMore = {
  time: string
  tokens: number
@@ -67,6 +67,7 @@ export type IChatItem = {
  useCurrentUserAvatar?: boolean
  isOpeningStatement?: boolean
  log?: { role: string; text: string }[]
  message_files?: VisionFile[]
}

export type MessageEnd = {
@@ -33,7 +33,7 @@ export type IConfigModelProps = {
  mode: string
  modelId: string
  provider: ProviderEnum
  setModel: (model: { id: string; provider: ProviderEnum; mode: ModelModeType }) => void
  setModel: (model: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => void
  completionParams: CompletionParams
  onCompletionParamsChange: (newParams: CompletionParams) => void
  disabled: boolean
@@ -121,7 +121,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
    return adjustedValue
  }

  const handleSelectModel = ({ id, provider: nextProvider, mode }: { id: string; provider: ProviderEnum; mode: ModelModeType }) => {
  const handleSelectModel = ({ id, provider: nextProvider, mode, features }: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => {
    return async () => {
      const prevParamsRule = getAllParams()[provider]?.[modelId]

@@ -129,6 +129,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
        id,
        provider: nextProvider || ProviderEnum.openai,
        mode,
        features,
      })

      await ensureModelParamLoaded(nextProvider, id)
@@ -320,6 +321,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
                  id: model.model_name,
                  provider: model.model_provider.provider_name as ProviderEnum,
                  mode: model.model_mode,
                  features: model.features,
                })()
              }}
            />
@@ -1,7 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'

export type IModelNameProps = {
  modelId: string
@@ -16,19 +15,11 @@ export const supportI18nModelName = [
]

const ModelName: FC<IModelNameProps> = ({
  modelId,
  modelDisplayName,
}) => {
  const { t } = useTranslation()
  let name = modelId
  if (supportI18nModelName.includes(modelId))
    name = t(`common.modelName.${modelId}`)
  else if (modelDisplayName)
    name = modelDisplayName

  return (
    <span title={name}>
      {name}
    <span title={modelDisplayName}>
      {modelDisplayName}
    </span>
  )
}
@@ -169,7 +169,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
        }
        title={
          <div className='flex items-center'>
            <div className='ml-1 mr-1'>{t('appDebug.variableTitle')}</div>
            <div className='mr-1'>{t('appDebug.variableTitle')}</div>
            {!readonly && (
              <Tooltip htmlContent={<div className='w-[180px]'>
                {t('appDebug.variableTip')}
60
web/app/components/app/configuration/config-vision/index.tsx
Normal file
@@ -0,0 +1,60 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Panel from '../base/feature-panel'
import ParamConfig from './param-config'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
import Tooltip from '@/app/components/base/tooltip'
import Switch from '@/app/components/base/switch'
import { Eye } from '@/app/components/base/icons/src/vender/solid/general'
import ConfigContext from '@/context/debug-configuration'

const ConfigVision: FC = () => {
  const { t } = useTranslation()
  const {
    isShowVisionConfig,
    visionConfig,
    setVisionConfig,
  } = useContext(ConfigContext)

  if (!isShowVisionConfig)
    return null

  return (<>
    <Panel
      className="mt-4"
      headerIcon={
        <Eye className='w-4 h-4 text-[#6938EF]'/>
      }
      title={
        <div className='flex items-center'>
          <div className='mr-1'>{t('appDebug.vision.name')}</div>
          <Tooltip htmlContent={<div className='w-[180px]' >
            {t('appDebug.vision.description')}
          </div>} selector='config-vision-tooltip'>
            <HelpCircle className='w-[14px] h-[14px] text-gray-400' />
          </Tooltip>
        </div>
      }
      headerRight={
        <div className='flex items-center'>
          <ParamConfig />
          <div className='ml-4 mr-3 w-[1px] h-3.5 bg-gray-200'></div>
          <Switch
            defaultValue={visionConfig.enabled}
            onChange={value => setVisionConfig({
              ...visionConfig,
              enabled: value,
            })}
            size='md'
          />
        </div>
      }
      noBodySpacing
    />
  </>
  )
}
export default React.memo(ConfigVision)
web/app/components/app/configuration/config-vision/param-config-content.tsx (new file, 132 lines; path inferred from the './param-config-content' import below)
@@ -0,0 +1,132 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import RadioGroup from './radio-group'
import ConfigContext from '@/context/debug-configuration'
import { Resolution, TransferMethod } from '@/types/app'
import ParamItem from '@/app/components/base/param-item'
import Tooltip from '@/app/components/base/tooltip'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'

const MIN = 1
const MAX = 6
const ParamConfigContent: FC = () => {
  const { t } = useTranslation()

  const {
    visionConfig,
    setVisionConfig,
  } = useContext(ConfigContext)

  const transferMethod = (() => {
    if (!visionConfig.transfer_methods || visionConfig.transfer_methods.length === 2)
      return TransferMethod.all

    return visionConfig.transfer_methods[0]
  })()

  return (
    <div>
      <div>
        <div className='leading-6 text-base font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.title')}</div>
        <div className='pt-3 space-y-6'>
          <div>
            <div className='mb-2 flex items-center space-x-1'>
              <div className='leading-[18px] text-[13px] font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.resolution')}</div>
              <Tooltip htmlContent={<div className='w-[180px]' >
                {t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => (
                  <div key={item}>{item}</div>
                ))}
              </div>} selector='config-resolution-tooltip'>
                <HelpCircle className='w-[14px] h-[14px] text-gray-400' />
              </Tooltip>
            </div>
            <RadioGroup
              className='space-x-3'
              options={[
                {
                  label: t('appDebug.vision.visionSettings.high'),
                  value: Resolution.high,
                },
                {
                  label: t('appDebug.vision.visionSettings.low'),
                  value: Resolution.low,
                },
              ]}
              value={visionConfig.detail}
              onChange={(value: Resolution) => {
                setVisionConfig({
                  ...visionConfig,
                  detail: value,
                })
              }}
            />
          </div>
          <div>
            <div className='mb-2 leading-[18px] text-[13px] font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.uploadMethod')}</div>
            <RadioGroup
              className='space-x-3'
              options={[
                {
                  label: t('appDebug.vision.visionSettings.both'),
                  value: TransferMethod.all,
                },
                {
                  label: t('appDebug.vision.visionSettings.localUpload'),
                  value: TransferMethod.local_file,
                },
                {
                  label: t('appDebug.vision.visionSettings.url'),
                  value: TransferMethod.remote_url,
                },
              ]}
              value={transferMethod}
              onChange={(value: TransferMethod) => {
                if (value === TransferMethod.all) {
                  setVisionConfig({
                    ...visionConfig,
                    transfer_methods: [TransferMethod.remote_url, TransferMethod.local_file],
                  })
                  return
                }
                setVisionConfig({
                  ...visionConfig,
                  transfer_methods: [value],
                })
              }}
            />
          </div>
          <div>
            <ParamItem
              id='upload_limit'
              className=''
              name={t('appDebug.vision.visionSettings.uploadLimit')}
              noTooltip
              {...{
                default: 2,
                step: 1,
                min: MIN,
                max: MAX,
              }}
              value={visionConfig.number_limits}
              enable={true}
              onChange={(_key: string, value: number) => {
                if (!value)
                  return

                setVisionConfig({
                  ...visionConfig,
                  number_limits: value,
                })
              }}
            />
          </div>
        </div>
      </div>
    </div>
  )
}

export default React.memo(ParamConfigContent)
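The Resolution and TransferMethod enums used above are imported from '@/types/app', whose definitions are not part of this diff. Reconstructed from how this component uses them, they and the vision settings object look roughly like the sketch below (an inference from usage, not the authoritative declarations):

// Sketch only: shapes inferred from usage in this diff; the real
// definitions live in web/types/app.ts and may carry more members.
enum Resolution {
  low = 'low',
  high = 'high',
}

enum TransferMethod {
  all = 'all',
  local_file = 'local_file',
  remote_url = 'remote_url',
}

type VisionSettings = {
  enabled: boolean
  number_limits: number              // clamped to 1..6 by the ParamItem above, default 2
  detail: Resolution
  transfer_methods: TransferMethod[] // persisted as an array; 'all' exists only in the UI
  image_file_size_limit?: number | string // merged in from the /files/upload config elsewhere
}

Note how the transferMethod IIFE above collapses the persisted array back to a single radio value: a missing or two-element array renders as 'all', and onChange expands 'all' back into [remote_url, local_file] before storing.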
web/app/components/app/configuration/config-vision/param-config.tsx (new file, 41 lines; path inferred from the './param-config' import above)
@@ -0,0 +1,41 @@
'use client'
import type { FC } from 'react'
import { memo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import ParamConfigContent from './param-config-content'
import { Settings01 } from '@/app/components/base/icons/src/vender/line/general'
import {
  PortalToFollowElem,
  PortalToFollowElemContent,
  PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'

const ParamsConfig: FC = () => {
  const { t } = useTranslation()
  const [open, setOpen] = useState(false)

  return (
    <PortalToFollowElem
      open={open}
      onOpenChange={setOpen}
      placement='bottom-end'
      offset={{
        mainAxis: 4,
      }}
    >
      <PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
        <div className={cn('flex items-center rounded-md h-7 px-3 space-x-1 text-gray-700 cursor-pointer hover:bg-gray-200', open && 'bg-gray-200')}>
          <Settings01 className='w-3.5 h-3.5 ' />
          <div className='ml-1 leading-[18px] text-xs font-medium '>{t('appDebug.vision.settings')}</div>
        </div>
      </PortalToFollowElemTrigger>
      <PortalToFollowElemContent style={{ zIndex: 50 }}>
        <div className='w-[412px] p-4 bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg space-y-3'>
          <ParamConfigContent />
        </div>
      </PortalToFollowElemContent>
    </PortalToFollowElem>
  )
}
export default memo(ParamsConfig)
web/app/components/app/configuration/config-vision/radio-group/index.tsx (new file, 40 lines; path inferred from the './radio-group' import above)
@@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from 'classnames'
import s from './style.module.css'

type OPTION = {
  label: string
  value: any
}

type Props = {
  className?: string
  options: OPTION[]
  value: any
  onChange: (value: any) => void
}

const RadioGroup: FC<Props> = ({
  className = '',
  options,
  value,
  onChange,
}) => {
  return (
    <div className={cn(className, 'flex')}>
      {options.map(item => (
        <div
          key={item.value}
          className={cn(s.item, item.value === value && s.checked)}
          onClick={() => onChange(item.value)}
        >
          <div className={s.radio}></div>
          <div className='text-[13px] font-medium text-gray-900'>{item.label}</div>
        </div>
      ))}
    </div>
  )
}
export default React.memo(RadioGroup)
web/app/components/app/configuration/config-vision/radio-group/style.module.css (new file, 24 lines; path inferred from the './style.module.css' import above)
@@ -0,0 +1,24 @@
.item {
  @apply grow flex items-center h-8 px-2.5 rounded-lg bg-gray-25 border border-gray-100 cursor-pointer space-x-2;
}

.item:hover {
  background-color: #ffffff;
  border-color: #B2CCFF;
  box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03);
}

.item.checked {
  background-color: #ffffff;
  border-color: #528BFF;
  box-shadow: 0px 1px 2px 0px rgba(16, 24, 40, 0.06), 0px 1px 3px 0px rgba(16, 24, 40, 0.10);
}

.radio {
  @apply w-4 h-4 border-[2px] border-gray-200 rounded-full;
}

.item.checked .radio {
  border-width: 5px;
  border-color: #155eef;
}
@@ -10,6 +10,7 @@ import ChatGroup from '../features/chat-group'
import ExperienceEnchanceGroup from '../features/experience-enchance-group'
import Toolbox from '../toolbox'
import HistoryPanel from '../config-prompt/conversation-histroy/history-panel'
import ConfigVision from '../config-vision'
import AddFeatureBtn from './feature/add-feature-btn'
import ChooseFeature from './feature/choose-feature'
import useFeature from './feature/use-feature'
@@ -193,6 +194,8 @@ const Config: FC = () => {

        <Tools />

        <ConfigVision />

        {/* Chat History */}
        {isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && (
          <HistoryPanel
@@ -81,7 +81,7 @@ const DatasetConfig: FC = () => {
      >
        {hasData
          ? (
            <div className='flex flex-wrap mt-1 px-3 justify-between'>
            <div className='flex flex-wrap mt-1 px-3 pb-3 justify-between'>
              {dataSet.map(item => (
                <CardItem
                  key={item.id}
@@ -1,5 +1,6 @@
'use client'
import type { FC } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import React, { useEffect, useRef, useState } from 'react'
import cn from 'classnames'
@@ -11,7 +12,7 @@ import HasNotSetAPIKEY from '../base/warning-mask/has-not-set-api'
import FormattingChanged from '../base/warning-mask/formatting-changed'
import GroupName from '../base/group-name'
import CannotQueryDataset from '../base/warning-mask/cannot-query-dataset'
import { AppType, ModelModeType } from '@/types/app'
import { AppType, ModelModeType, TransferMethod } from '@/types/app'
import PromptValuePanel, { replaceStringWithValues } from '@/app/components/app/configuration/prompt-value-panel'
import type { IChatItem } from '@/app/components/app/chat/type'
import Chat from '@/app/components/app/chat'
@@ -19,12 +20,13 @@ import ConfigContext from '@/context/debug-configuration'
import { ToastContext } from '@/app/components/base/toast'
import { fetchConvesationMessages, fetchSuggestedQuestions, sendChatMessage, sendCompletionMessage, stopChatMessageResponding } from '@/service/debug'
import Button from '@/app/components/base/button'
import type { ModelConfig as BackendModelConfig } from '@/types/app'
import type { ModelConfig as BackendModelConfig, VisionFile } from '@/types/app'
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
import TextGeneration from '@/app/components/app/text-generate/item'
import { IS_CE_EDITION } from '@/config'
import { useProviderContext } from '@/context/provider-context'
import type { Inputs } from '@/models/debug'
import { fetchFileUploadConfig } from '@/service/common'

type IDebug = {
  hasSetAPIKEY: boolean
@@ -64,10 +66,12 @@ const Debug: FC<IDebug> = ({
    hasSetContextVar,
    datasetConfigs,
    externalDataToolsConfig,
    visionConfig,
  } = useContext(ConfigContext)
  const { speech2textDefaultModel } = useProviderContext()
  const [chatList, setChatList, getChatList] = useGetState<IChatItem[]>([])
  const chatListDomRef = useRef<HTMLDivElement>(null)
  const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
  useEffect(() => {
    // scroll to bottom
    if (chatListDomRef.current)
@@ -161,17 +165,28 @@ const Debug: FC<IDebug> = ({
      logError(t('appDebug.errorMessage.valueOfVarRequired', { key: hasEmptyInput }))
      return false
    }

    // eslint-disable-next-line @typescript-eslint/no-use-before-define
    if (completionFiles.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
      notify({ type: 'info', message: t('appDebug.errorMessage.waitForImgUpload') })
      return false
    }
    return !hasEmptyInput
  }

  const doShowSuggestion = isShowSuggestion && !isResponsing
  const [suggestQuestions, setSuggestQuestions] = useState<string[]>([])
  const onSend = async (message: string) => {
  const onSend = async (message: string, files?: VisionFile[]) => {
    if (isResponsing) {
      notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
      return false
    }

    if (files?.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
      notify({ type: 'info', message: t('appDebug.errorMessage.waitForImgUpload') })
      return false
    }

    const postDatasets = dataSets.map(({ id }) => ({
      dataset: {
        enabled: true,
@@ -207,6 +222,9 @@ const Debug: FC<IDebug> = ({
        completion_params: completionParams as any,
      },
      dataset_configs: datasetConfigs,
      file_upload: {
        image: visionConfig,
      },
    }

    if (isAdvancedMode) {
@@ -214,19 +232,32 @@
      postModelConfig.completion_prompt_config = completionPromptConfig
    }

    const data = {
    const data: Record<string, any> = {
      conversation_id: conversationId,
      inputs,
      query: message,
      model_config: postModelConfig,
    }

    if (visionConfig.enabled && files && files?.length > 0) {
      data.files = files.map((item) => {
        if (item.transfer_method === TransferMethod.local_file) {
          return {
            ...item,
            url: '',
          }
        }
        return item
      })
    }

    // question
    const questionId = `question-${Date.now()}`
    const questionItem = {
      id: questionId,
      content: message,
      isAnswer: false,
      message_files: files,
    }

    const placeholderAnswerId = `answer-placeholder-${Date.now()}`
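Assembled from the pieces above, the chat request body that onSend now posts looks roughly like this sketch (values are illustrative; only the shapes come from the code, and files is attached only when vision is enabled and images were provided):

// Illustrative payload for sendChatMessage as built in this diff.
const data: Record<string, any> = {
  conversation_id: conversationId,
  inputs,
  query: 'What is in this picture?',   // hypothetical user message
  model_config: postModelConfig,       // now carries file_upload: { image: visionConfig }
  files: [{
    type: 'image',
    transfer_method: TransferMethod.local_file,
    url: '',                  // cleared for local uploads; the server resolves them by id
    upload_file_id: 'a1b2c3', // hypothetical id returned by the file upload endpoint
  }],
}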
@@ -347,6 +378,7 @@
  const [completionRes, setCompletionRes] = useState('')
  const [messageId, setMessageId] = useState<string | null>(null)

  const [completionFiles, setCompletionFiles] = useState<VisionFile[]>([])
  const sendTextCompletion = async () => {
    if (isResponsing) {
      notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
@@ -394,6 +426,9 @@
        completion_params: completionParams as any,
      },
      dataset_configs: datasetConfigs,
      file_upload: {
        image: visionConfig,
      },
    }

    if (isAdvancedMode) {
@@ -401,11 +436,23 @@
      postModelConfig.completion_prompt_config = completionPromptConfig
    }

    const data = {
    const data: Record<string, any> = {
      inputs,
      model_config: postModelConfig,
    }

    if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) {
      data.files = completionFiles.map((item) => {
        if (item.transfer_method === TransferMethod.local_file) {
          return {
            ...item,
            url: '',
          }
        }
        return item
      })
    }

    setCompletionRes('')
    setMessageId('')
    let res: string[] = []
@@ -448,6 +495,11 @@
            appType={mode as AppType}
            onSend={sendTextCompletion}
            inputs={inputs}
            visionConfig={{
              ...visionConfig,
              image_file_size_limit: fileUploadConfigResponse?.image_file_size_limit,
            }}
            onVisionFilesChange={setCompletionFiles}
          />
        </div>
        <div className="flex flex-col grow">
@@ -475,6 +527,10 @@
                isShowCitation={citationConfig.enabled}
                isShowCitationHitInfo
                isShowPromptLog
                visionConfig={{
                  ...visionConfig,
                  image_file_size_limit: fileUploadConfigResponse?.image_file_size_limit,
                }}
              />
            </div>
          </div>
@@ -25,19 +25,19 @@ import type {
} from '@/models/debug'
import type { ExternalDataTool } from '@/models/common'
import type { DataSet } from '@/models/datasets'
import type { ModelConfig as BackendModelConfig } from '@/types/app'
import type { ModelConfig as BackendModelConfig, VisionSettings } from '@/types/app'
import ConfigContext from '@/context/debug-configuration'
import ConfigModel from '@/app/components/app/configuration/config-model'
import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm'
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelFeature, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
import { fetchDatasets } from '@/service/datasets'
import { useProviderContext } from '@/context/provider-context'
import { AppType, ModelModeType } from '@/types/app'
import { AppType, ModelModeType, Resolution, TransferMethod } from '@/types/app'
import { FlipBackward } from '@/app/components/base/icons/src/vender/line/arrows'
import { PromptMode } from '@/models/debug'
import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
@@ -198,6 +198,7 @@ const Configuration: FC = () => {
  }

  const { textGenerationModelList } = useProviderContext()
  const currModel = textGenerationModelList.find(item => item.model_name === modelConfig.model_id)
  const hasSetCustomAPIKEY = !!textGenerationModelList?.find(({ model_provider: provider }) => {
    if (provider.provider_type === 'system' && provider.quota_type === 'paid')
      return true
@@ -271,7 +272,8 @@
    id: modelId,
    provider,
    mode: modeMode,
  }: { id: string; provider: ProviderEnum; mode: ModelModeType }) => {
    features,
  }: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => {
    if (isAdvancedMode) {
      const appMode = mode

@@ -297,10 +299,31 @@
    })

    setModelConfig(newModelConfig)
    const supportVision = features && features.includes(ModelFeature.vision)
    // eslint-disable-next-line @typescript-eslint/no-use-before-define
    setVisionConfig({
      // eslint-disable-next-line @typescript-eslint/no-use-before-define
      ...visionConfig,
      enabled: supportVision,
    }, true)
  }

  const isShowVisionConfig = !!currModel?.features.includes(ModelFeature.vision)
  const [visionConfig, doSetVisionConfig] = useState({
    enabled: false,
    number_limits: 2,
    detail: Resolution.low,
    transfer_methods: [TransferMethod.local_file],
  })

  const setVisionConfig = (config: VisionSettings, notNoticeFormattingChanged?: boolean) => {
    doSetVisionConfig(config)
    if (!notNoticeFormattingChanged)
      setFormattingChanged(true)
  }

  useEffect(() => {
    fetchAppDetail({ url: '/apps', id: appId }).then(async (res) => {
    fetchAppDetail({ url: '/apps', id: appId }).then(async (res: any) => {
      setMode(res.mode)
      const modelConfig = res.model_config
      const promptMode = modelConfig.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple
@@ -362,6 +385,10 @@
        },
        completionParams: model.completion_params,
      }

      if (modelConfig.file_upload)
        setVisionConfig(modelConfig.file_upload.image, true)

      syncToPublishedConfig(config)
      setPublishedConfig(config)
      setDatasetConfigs(modelConfig.dataset_configs)
@@ -459,6 +486,9 @@
        completion_params: completionParams as any,
      },
      dataset_configs: datasetConfigs,
      file_upload: {
        image: visionConfig,
      },
    }

    if (isAdvancedMode) {
@@ -557,6 +587,9 @@
        datasetConfigs,
        setDatasetConfigs,
        hasSetContextVar,
        isShowVisionConfig,
        visionConfig,
        setVisionConfig,
      }}
    >
      <>
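Two details in the block above are easy to miss: setVisionConfig wraps doSetVisionConfig so that any user-driven edit also raises the formatting-changed warning, while the optional second argument suppresses that warning for programmatic syncs, namely when a fetched app detail carries model_config.file_upload.image and when switching models toggles enabled based on ModelFeature.vision.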
@@ -14,17 +14,23 @@ import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import Button from '@/app/components/base/button'
import { ChevronDown, ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'
import Tooltip from '@/app/components/base/tooltip-plus'
import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader'
import type { VisionFile, VisionSettings } from '@/types/app'

export type IPromptValuePanelProps = {
  appType: AppType
  onSend?: () => void
  inputs: Inputs
  visionConfig: VisionSettings
  onVisionFilesChange: (files: VisionFile[]) => void
}

const PromptValuePanel: FC<IPromptValuePanelProps> = ({
  appType,
  onSend,
  inputs,
  visionConfig,
  onVisionFilesChange,
}) => {
  const { t } = useTranslation()
  const { modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext)
@@ -152,6 +158,24 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
            <div className='text-xs text-gray-500'>{t('appDebug.inputs.noVar')}</div>
          )
      }
      {
        appType === AppType.completion && visionConfig?.enabled && (
          <div className="mt-3 xl:flex justify-between">
            <div className="mr-1 py-2 shrink-0 w-[120px] text-sm text-gray-900">Image Upload</div>
            <div className='grow'>
              <TextGenerationImageUploader
                settings={visionConfig}
                onFilesChange={files => onVisionFilesChange(files.filter(file => file.progress !== -1).map(fileItem => ({
                  type: 'image',
                  transfer_method: fileItem.type,
                  url: fileItem.url,
                  upload_file_id: fileItem.fileId,
                })))}
              />
            </div>
          </div>
        )
      }
    </>
  )
}
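As with VisionSettings, the VisionFile type threaded through these props is imported from '@/types/app' and not shown in the diff. The onFilesChange mapping above (uploader item to VisionFile) implies roughly this shape, sketched here as an inference from usage:

// Sketch only: inferred from how VisionFile objects are built and
// consumed in this diff (see the upload_file_id checks in debug/index.tsx).
type VisionFile = {
  type: 'image'
  transfer_method: TransferMethod // local_file or remote_url
  url: string                     // remote image URL, or '' for local uploads
  upload_file_id: string          // set once a local upload finishes
}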