Compare commits

...

20 Commits

Author SHA1 Message Date
Garfield Dai
8835435558 fix: change model mode. (#1520) 2023-11-13 23:13:01 +08:00
takatost
a80d8286c2 feat: bump version to 0.3.30 (#1519) 2023-11-13 22:50:42 +08:00
zxhlyh
6b15827246 feat: [frontend] support vision (#1518)
Co-authored-by: Joel <iamjoel007@gmail.com>
2023-11-13 22:32:39 +08:00
takatost
41d0a8b295 feat: [backend] vision support (#1510)
Co-authored-by: Garfield Dai <dai.hai@foxmail.com>
2023-11-13 22:05:46 +08:00
crazywoola
d0e1ea8f06 1506 remove duplicated code (#1511) 2023-11-13 19:05:32 +08:00
zxhlyh
f3b9647bb4 feat: add spark 3.0 tip (#1516) 2023-11-13 18:01:37 +08:00
takatost
9de67c586f feat: update free plan rules of spark (#1515) 2023-11-13 17:00:36 +08:00
Charlie.Wei
92f594f5e7 Change Embedded chrome plugin Url (#1498)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-11-10 16:44:26 +08:00
Benjamin
06d5273217 Fixed missing i18n app-debug.zh.ts items. (#1503) 2023-11-10 16:43:10 +08:00
crazywoola
94d7babbf1 feat: update the docs in forking applications (#1491) 2023-11-08 19:44:15 +08:00
Charlie.Wei
306216dbe5 application embedded add chrome && ChatBot Chrome plugin update v1.5 (#1480)
Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
Co-authored-by: crazywoola <427733928@qq.com>
2023-11-08 17:59:53 +08:00
zxhlyh
ab2e20ee0a fix: rename api based extension (#1485) 2023-11-08 13:03:50 +08:00
zxhlyh
146e95d88f fix: api extension selector (#1486) 2023-11-08 13:03:42 +08:00
takatost
d7ae86799c feat: support basic feature of OpenAI new models (#1476) 2023-11-07 04:05:59 -06:00
zxhlyh
7b26c9e2ef fix: code-based extension (#1477) 2023-11-07 17:56:07 +08:00
zxhlyh
6bcafdbc87 fix: openai model name (#1474) 2023-11-07 17:41:43 +08:00
takatost
059c089f93 fix: external data tool batch retrieve bug (#1472) 2023-11-07 01:28:22 -06:00
Garfield Dai
c1e7193c4b feat: hidden api key enhancement. (#1468) 2023-11-06 23:07:30 +08:00
takatost
2423563d45 fix: external data tool parse error (#1469) 2023-11-06 08:40:01 -06:00
takatost
260672986e fix: universal chat external_data_tools NPE (#1467) 2023-11-06 08:08:53 -06:00
174 changed files with 181286 additions and 839 deletions

View File

@@ -18,6 +18,9 @@ SERVICE_API_URL=http://127.0.0.1:5001
APP_API_URL=http://127.0.0.1:5001
APP_WEB_URL=http://127.0.0.1:3000
# Files URL
FILES_URL=http://127.0.0.1:5001
# celery configuration
CELERY_BROKER_URL=redis://:difyai123456@localhost:6379/1
@@ -70,6 +73,14 @@ MILVUS_USER=root
MILVUS_PASSWORD=Milvus
MILVUS_SECURE=false
# Upload configuration
UPLOAD_FILE_SIZE_LIMIT=15
UPLOAD_FILE_BATCH_LIMIT=5
UPLOAD_IMAGE_FILE_SIZE_LIMIT=10
# Model Configuration
MULTIMODAL_SEND_IMAGE_FORMAT=base64
# Mail configuration, support: resend
MAIL_TYPE=
MAIL_DEFAULT_SEND_FROM=no-reply <no-reply@dify.ai>

View File

@@ -126,6 +126,7 @@ def register_blueprints(app):
from controllers.service_api import bp as service_api_bp
from controllers.web import bp as web_bp
from controllers.console import bp as console_app_bp
from controllers.files import bp as files_bp
CORS(service_api_bp,
allow_headers=['Content-Type', 'Authorization', 'X-App-Code'],
@@ -155,6 +156,12 @@ def register_blueprints(app):
app.register_blueprint(console_app_bp)
CORS(files_bp,
allow_headers=['Content-Type'],
methods=['GET', 'PUT', 'POST', 'DELETE', 'OPTIONS', 'PATCH']
)
app.register_blueprint(files_bp)
# create app
app = create_app()

View File

@@ -26,6 +26,7 @@ DEFAULTS = {
'SERVICE_API_URL': 'https://api.dify.ai',
'APP_WEB_URL': 'https://udify.app',
'APP_API_URL': 'https://udify.app',
'FILES_URL': '',
'STORAGE_TYPE': 'local',
'STORAGE_LOCAL_PATH': 'storage',
'CHECK_UPDATE_URL': 'https://updates.dify.ai',
@@ -57,7 +58,9 @@ DEFAULTS = {
'CLEAN_DAY_SETTING': 30,
'UPLOAD_FILE_SIZE_LIMIT': 15,
'UPLOAD_FILE_BATCH_LIMIT': 5,
'OUTPUT_MODERATION_BUFFER_SIZE': 300
'UPLOAD_IMAGE_FILE_SIZE_LIMIT': 10,
'OUTPUT_MODERATION_BUFFER_SIZE': 300,
'MULTIMODAL_SEND_IMAGE_FORMAT': 'base64'
}
@@ -84,86 +87,65 @@ class Config:
"""Application configuration class."""
def __init__(self):
# app settings
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.29"
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.3.30"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
self.TESTING = False
self.LOG_LEVEL = get_env('LOG_LEVEL')
# The backend URL prefix of the console API.
# used to concatenate the login authorization callback or notion integration callback.
self.CONSOLE_API_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_API_URL')
# The front-end URL prefix of the console web.
# used to concatenate some front-end addresses and for CORS configuration use.
self.CONSOLE_WEB_URL = get_env('CONSOLE_URL') if get_env('CONSOLE_URL') else get_env('CONSOLE_WEB_URL')
# WebApp API backend Url prefix.
# used to declare the back-end URL for the front-end API.
self.APP_API_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_API_URL')
# WebApp Url prefix.
# used to display WebAPP API Base Url to the front-end.
self.APP_WEB_URL = get_env('APP_URL') if get_env('APP_URL') else get_env('APP_WEB_URL')
# Service API Url prefix.
# used to display Service API Base Url to the front-end.
self.SERVICE_API_URL = get_env('API_URL') if get_env('API_URL') else get_env('SERVICE_API_URL')
# File preview or download Url prefix.
# used to display File preview or download Url to the front-end or as Multi-model inputs;
# Url is signed and has expiration time.
self.FILES_URL = get_env('FILES_URL') if get_env('FILES_URL') else self.CONSOLE_API_URL
# Fallback Url prefix.
# Will be deprecated in the future.
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
# Your App secret key will be used for securely signing the session cookie
# Make sure you are changing this key for your deployment with a strong key.
# You can generate a strong key using `openssl rand -base64 42`.
# Alternatively you can set it with `SECRET_KEY` environment variable.
self.SECRET_KEY = get_env('SECRET_KEY')
# redis settings
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# storage settings
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
# vector store settings, only support weaviate, qdrant
self.VECTOR_STORE = get_env('VECTOR_STORE')
# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# cors settings
self.CONSOLE_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'CONSOLE_CORS_ALLOW_ORIGINS', self.CONSOLE_WEB_URL)
self.WEB_API_CORS_ALLOW_ORIGINS = get_cors_allow_origins(
'WEB_API_CORS_ALLOW_ORIGINS', '*')
# mail settings
self.MAIL_TYPE = get_env('MAIL_TYPE')
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
# sentry settings
self.SENTRY_DSN = get_env('SENTRY_DSN')
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
# check update url
self.CHECK_UPDATE_URL = get_env('CHECK_UPDATE_URL')
# database settings
# ------------------------
# Database Configurations.
# ------------------------
db_credentials = {
key: get_env(key) for key in
['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT', 'DB_DATABASE']
@@ -177,14 +159,102 @@ class Config:
self.SQLALCHEMY_ECHO = get_bool_env('SQLALCHEMY_ECHO')
# celery settings
# ------------------------
# Redis Configurations.
# ------------------------
self.REDIS_HOST = get_env('REDIS_HOST')
self.REDIS_PORT = get_env('REDIS_PORT')
self.REDIS_USERNAME = get_env('REDIS_USERNAME')
self.REDIS_PASSWORD = get_env('REDIS_PASSWORD')
self.REDIS_DB = get_env('REDIS_DB')
self.REDIS_USE_SSL = get_bool_env('REDIS_USE_SSL')
# ------------------------
# Celery worker Configurations.
# ------------------------
self.CELERY_BROKER_URL = get_env('CELERY_BROKER_URL')
self.CELERY_BACKEND = get_env('CELERY_BACKEND')
self.CELERY_RESULT_BACKEND = 'db+{}'.format(self.SQLALCHEMY_DATABASE_URI) \
if self.CELERY_BACKEND == 'database' else self.CELERY_BROKER_URL
self.BROKER_USE_SSL = self.CELERY_BROKER_URL.startswith('rediss://')
# hosted provider credentials
# ------------------------
# File Storage Configurations.
# ------------------------
self.STORAGE_TYPE = get_env('STORAGE_TYPE')
self.STORAGE_LOCAL_PATH = get_env('STORAGE_LOCAL_PATH')
self.S3_ENDPOINT = get_env('S3_ENDPOINT')
self.S3_BUCKET_NAME = get_env('S3_BUCKET_NAME')
self.S3_ACCESS_KEY = get_env('S3_ACCESS_KEY')
self.S3_SECRET_KEY = get_env('S3_SECRET_KEY')
self.S3_REGION = get_env('S3_REGION')
# ------------------------
# Vector Store Configurations.
# Currently, only support: qdrant, milvus, zilliz, weaviate
# ------------------------
self.VECTOR_STORE = get_env('VECTOR_STORE')
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
# milvus / zilliz setting
self.MILVUS_HOST = get_env('MILVUS_HOST')
self.MILVUS_PORT = get_env('MILVUS_PORT')
self.MILVUS_USER = get_env('MILVUS_USER')
self.MILVUS_PASSWORD = get_env('MILVUS_PASSWORD')
self.MILVUS_SECURE = get_env('MILVUS_SECURE')
# weaviate settings
self.WEAVIATE_ENDPOINT = get_env('WEAVIATE_ENDPOINT')
self.WEAVIATE_API_KEY = get_env('WEAVIATE_API_KEY')
self.WEAVIATE_GRPC_ENABLED = get_bool_env('WEAVIATE_GRPC_ENABLED')
self.WEAVIATE_BATCH_SIZE = int(get_env('WEAVIATE_BATCH_SIZE'))
# ------------------------
# Mail Configurations.
# ------------------------
self.MAIL_TYPE = get_env('MAIL_TYPE')
self.MAIL_DEFAULT_SEND_FROM = get_env('MAIL_DEFAULT_SEND_FROM')
self.RESEND_API_KEY = get_env('RESEND_API_KEY')
# ------------------------
# Sentry Configurations.
# ------------------------
self.SENTRY_DSN = get_env('SENTRY_DSN')
self.SENTRY_TRACES_SAMPLE_RATE = float(get_env('SENTRY_TRACES_SAMPLE_RATE'))
self.SENTRY_PROFILES_SAMPLE_RATE = float(get_env('SENTRY_PROFILES_SAMPLE_RATE'))
# ------------------------
# Business Configurations.
# ------------------------
# multi model send image format, support base64, url, default is base64
self.MULTIMODAL_SEND_IMAGE_FORMAT = get_env('MULTIMODAL_SEND_IMAGE_FORMAT')
# Dataset Configurations.
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# File upload Configurations.
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
self.UPLOAD_IMAGE_FILE_SIZE_LIMIT = int(get_env('UPLOAD_IMAGE_FILE_SIZE_LIMIT'))
# Moderation in app Configurations.
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
# Notion integration setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
# ------------------------
# Platform Configurations.
# ------------------------
self.HOSTED_OPENAI_ENABLED = get_bool_env('HOSTED_OPENAI_ENABLED')
self.HOSTED_OPENAI_API_KEY = get_env('HOSTED_OPENAI_API_KEY')
self.HOSTED_OPENAI_API_BASE = get_env('HOSTED_OPENAI_API_BASE')
@@ -212,26 +282,6 @@ class Config:
self.HOSTED_MODERATION_ENABLED = get_bool_env('HOSTED_MODERATION_ENABLED')
self.HOSTED_MODERATION_PROVIDERS = get_env('HOSTED_MODERATION_PROVIDERS')
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')
# notion import setting
self.NOTION_CLIENT_ID = get_env('NOTION_CLIENT_ID')
self.NOTION_CLIENT_SECRET = get_env('NOTION_CLIENT_SECRET')
self.NOTION_INTEGRATION_TYPE = get_env('NOTION_INTEGRATION_TYPE')
self.NOTION_INTERNAL_SECRET = get_env('NOTION_INTERNAL_SECRET')
self.NOTION_INTEGRATION_TOKEN = get_env('NOTION_INTEGRATION_TOKEN')
self.TENANT_DOCUMENT_COUNT = get_env('TENANT_DOCUMENT_COUNT')
self.CLEAN_DAY_SETTING = get_env('CLEAN_DAY_SETTING')
# uploading settings
self.UPLOAD_FILE_SIZE_LIMIT = int(get_env('UPLOAD_FILE_SIZE_LIMIT'))
self.UPLOAD_FILE_BATCH_LIMIT = int(get_env('UPLOAD_FILE_BATCH_LIMIT'))
# moderation settings
self.OUTPUT_MODERATION_BUFFER_SIZE = int(get_env('OUTPUT_MODERATION_BUFFER_SIZE'))
class CloudEditionConfig(Config):
@@ -246,18 +296,5 @@ class CloudEditionConfig(Config):
self.GOOGLE_CLIENT_SECRET = get_env('GOOGLE_CLIENT_SECRET')
self.OAUTH_REDIRECT_PATH = get_env('OAUTH_REDIRECT_PATH')
class TestConfig(Config):
def __init__(self):
super().__init__()
self.EDITION = "SELF_HOSTED"
self.TESTING = True
db_credentials = {
key: get_env(key) for key in ['DB_USERNAME', 'DB_PASSWORD', 'DB_HOST', 'DB_PORT']
}
# use a different database for testing: dify_test
self.SQLALCHEMY_DATABASE_URI = f"postgresql://{db_credentials['DB_USERNAME']}:{db_credentials['DB_PASSWORD']}@{db_credentials['DB_HOST']}:{db_credentials['DB_PORT']}/dify_test"
self.STRIPE_API_KEY = get_env('STRIPE_API_KEY')
self.STRIPE_WEBHOOK_SECRET = get_env('STRIPE_WEBHOOK_SECRET')

View File

@@ -40,12 +40,14 @@ class CompletionMessageApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
account = flask_login.current_user
@@ -113,6 +115,7 @@ class ChatMessageApi(Resource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('model_config', type=dict, required=True, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
@@ -120,6 +123,7 @@ class ChatMessageApi(Resource):
args = parser.parse_args()
streaming = args['response_mode'] != 'blocking'
args['auto_generate_name'] = False
account = flask_login.current_user

View File

@@ -108,7 +108,7 @@ class CompletionConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'completion')
@setup_required
@login_required
@account_initialization_required
@@ -230,7 +230,7 @@ class ChatConversationDetailApi(Resource):
conversation_id = str(conversation_id)
return _get_conversation(app_id, conversation_id, 'chat')
@setup_required
@login_required
@account_initialization_required
@@ -253,8 +253,6 @@ class ChatConversationDetailApi(Resource):
return {'result': 'success'}, 204
api.add_resource(CompletionConversationApi, '/apps/<uuid:app_id>/completion-conversations')
api.add_resource(CompletionConversationDetailApi, '/apps/<uuid:app_id>/completion-conversations/<uuid:conversation_id>')
api.add_resource(ChatConversationApi, '/apps/<uuid:app_id>/chat-conversations')

View File

@@ -1,7 +1,6 @@
import datetime
import json
from cachetools import TTLCache
from flask import request
from flask_login import current_user
from libs.login import login_required
@@ -20,8 +19,6 @@ from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService
from tasks.document_indexing_sync_task import document_indexing_sync_task
cache = TTLCache(maxsize=None, ttl=30)
class DataSourceApi(Resource):

View File

@@ -1,5 +1,5 @@
from cachetools import TTLCache
from flask import request, current_app
from flask_login import current_user
import services
from libs.login import login_required
@@ -15,9 +15,6 @@ from fields.file_fields import upload_config_fields, file_fields
from services.file_service import FileService
cache = TTLCache(maxsize=None, ttl=30)
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
PREVIEW_WORDS_LIMIT = 3000
@@ -30,9 +27,11 @@ class FileApi(Resource):
def get(self):
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT")
batch_count_limit = current_app.config.get("UPLOAD_FILE_BATCH_LIMIT")
image_file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT")
return {
'file_size_limit': file_size_limit,
'batch_count_limit': batch_count_limit
'batch_count_limit': batch_count_limit,
'image_file_size_limit': image_file_size_limit
}, 200
@setup_required
@@ -51,7 +50,7 @@ class FileApi(Resource):
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file)
upload_file = FileService.upload_file(file, current_user)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:

View File

@@ -1,6 +1,7 @@
# -*- coding:utf-8 -*-
import json
import logging
from datetime import datetime
from typing import Generator, Union
from flask import Response, stream_with_context
@@ -17,6 +18,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from core.conversation_message_task import PubHandler
from core.model_providers.error import LLMBadRequestError, LLMAPIUnavailableError, LLMAuthorizationError, LLMAPIConnectionError, \
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from extensions.ext_database import db
from libs.helper import uuid_value
from services.completion_service import CompletionService
@@ -32,11 +34,16 @@ class CompletionApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
installed_app.last_used_at = datetime.utcnow()
db.session.commit()
try:
response = CompletionService.completion(
@@ -91,12 +98,17 @@ class ChatApi(InstalledAppResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='explore_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
installed_app.last_used_at = datetime.utcnow()
db.session.commit()
try:
response = CompletionService.completion(

View File

@@ -38,7 +38,8 @@ class ConversationListApi(InstalledAppResource):
user=current_user,
last_id=args['last_id'],
limit=args['limit'],
pinned=pinned
pinned=pinned,
exclude_debug_conversation=True
)
except LastConversationNotExistsError:
raise NotFound("Last Conversation Not Exists.")
@@ -71,11 +72,18 @@ class ConversationRenameApi(InstalledAppResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -39,8 +39,9 @@ class InstalledAppsListApi(Resource):
}
for installed_app in installed_apps
]
installed_apps.sort(key=lambda app: (-app['is_pinned'], app['last_used_at']
if app['last_used_at'] is not None else datetime.min))
installed_apps.sort(key=lambda app: (-app['is_pinned'],
app['last_used_at'] is None,
-app['last_used_at'].timestamp() if app['last_used_at'] is not None else 0))
return {'installed_apps': installed_apps}

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from flask import current_app
from controllers.console import api
from controllers.console.explore.wraps import InstalledAppResource
@@ -19,6 +20,10 @@ class AppParameterApi(InstalledAppResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(InstalledAppResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -9,6 +9,7 @@ from controllers.console.explore.wraps import InstalledAppResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -19,6 +20,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}

View File

@@ -25,6 +25,7 @@ class UniversalChatApi(UniversalChatResource):
parser = reqparse.RequestParser()
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('provider', type=str, required=True, location='json')
parser.add_argument('model', type=str, required=True, location='json')
@@ -60,6 +61,8 @@ class UniversalChatApi(UniversalChatResource):
del args['model']
del args['tools']
args['auto_generate_name'] = False
try:
response = CompletionService.completion(
app_model=app_model,

View File

@@ -65,11 +65,18 @@ class UniversalChatConversationRenameApi(UniversalChatResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, current_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
current_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -0,0 +1,10 @@
# -*- coding:utf-8 -*-
from flask import Blueprint
from libs.external_api import ExternalApi
bp = Blueprint('files', __name__)
api = ExternalApi(bp)
from . import image_preview

View File

@@ -0,0 +1,40 @@
from flask import request, Response
from flask_restful import Resource
import services
from controllers.files import api
from libs.exception import BaseHTTPException
from services.file_service import FileService
class ImagePreviewApi(Resource):
def get(self, file_id):
file_id = str(file_id)
timestamp = request.args.get('timestamp')
nonce = request.args.get('nonce')
sign = request.args.get('sign')
if not timestamp or not nonce or not sign:
return {'content': 'Invalid request.'}, 400
try:
generator, mimetype = FileService.get_image_preview(
file_id,
timestamp,
nonce,
sign
)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return Response(generator, mimetype=mimetype)
api.add_resource(ImagePreviewApi, '/files/<uuid:file_id>/image-preview')
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415

View File

@@ -7,6 +7,6 @@ bp = Blueprint('service_api', __name__, url_prefix='/v1')
api = ExternalApi(bp)
from .app import completion, app, conversation, message, audio
from .app import completion, app, conversation, message, audio, file
from .dataset import document, segment, dataset

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import fields, marshal_with
from flask import current_app
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
@@ -20,6 +21,10 @@ class AppParameterApi(AppApiResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -28,7 +33,9 @@ class AppParameterApi(AppApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -44,7 +51,11 @@ class AppParameterApi(AppApiResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -28,6 +28,7 @@ class CompletionApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
@@ -39,13 +40,15 @@ class CompletionApi(AppApiResource):
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
args['auto_generate_name'] = False
try:
response = CompletionService.completion(
app_model=app_model,
user=end_user,
args=args,
from_source='api',
streaming=streaming
streaming=streaming,
)
return compact_response(response)
@@ -90,10 +93,12 @@ class ChatApi(AppApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='dev', location='json')
parser.add_argument('auto_generate_name', type=bool, required=False, default='True', location='json')
args = parser.parse_args()

View File

@@ -65,15 +65,22 @@ class ConversationRenameApi(AppApiResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('user', type=str, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()
if end_user is None and args['user'] is not None:
end_user = create_or_update_end_user_for_user_id(app_model, args['user'])
try:
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
end_user,
args['name'],
args['auto_generate']
)
except services.errors.conversation.ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -75,3 +75,26 @@ class ProviderNotSupportSpeechToTextError(BaseHTTPException):
description = "Provider not support speech to text."
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415

View File

@@ -0,0 +1,42 @@
from flask import request
from flask_restful import marshal_with
from controllers.service_api import api
from controllers.service_api.wraps import AppApiResource
from controllers.service_api.app import create_or_update_end_user_for_user_id
from controllers.service_api.app.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
import services
from services.file_service import FileService
from fields.file_fields import file_fields
class FileApi(AppApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
file = request.files['file']
user_args = request.form.get('user')
if end_user is None and user_args is not None:
end_user = create_or_update_end_user_for_user_id(app_model, user_args)
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file, end_user)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(FileApi, '/files/upload')

View File

@@ -12,7 +12,7 @@ from libs.helper import TimestampField, uuid_value
from services.message_service import MessageService
from extensions.ext_database import db
from models.model import Message, EndUser
from fields.conversation_fields import message_file_fields
class MessageListApi(AppApiResource):
feedback_fields = {
@@ -43,6 +43,7 @@ class MessageListApi(AppApiResource):
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -2,6 +2,7 @@ import json
from flask import request
from flask_restful import reqparse, marshal
from flask_login import current_user
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
@@ -173,7 +174,7 @@ class DocumentAddByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
upload_file = FileService.upload_file(file, current_user)
data_source = {
'type': 'upload_file',
'info_list': {
@@ -235,7 +236,7 @@ class DocumentUpdateByFileApi(DatasetApiResource):
if len(request.files) > 1:
raise TooManyFilesError()
upload_file = FileService.upload_file(file)
upload_file = FileService.upload_file(file, current_user)
data_source = {
'type': 'upload_file',
'info_list': {

View File

@@ -7,4 +7,4 @@ bp = Blueprint('web', __name__, url_prefix='/api')
api = ExternalApi(bp)
from . import completion, app, conversation, message, site, saved_message, audio, passport
from . import completion, app, conversation, message, site, saved_message, audio, passport, file

View File

@@ -1,5 +1,6 @@
# -*- coding:utf-8 -*-
from flask_restful import marshal_with, fields
from flask import current_app
from controllers.web import api
from controllers.web.wraps import WebApiResource
@@ -19,6 +20,10 @@ class AppParameterApi(WebApiResource):
'options': fields.List(fields.String)
}
system_parameters_fields = {
'image_file_size_limit': fields.String
}
parameters_fields = {
'opening_statement': fields.String,
'suggested_questions': fields.Raw,
@@ -27,7 +32,9 @@ class AppParameterApi(WebApiResource):
'retriever_resource': fields.Raw,
'more_like_this': fields.Raw,
'user_input_form': fields.Raw,
'sensitive_word_avoidance': fields.Raw
'sensitive_word_avoidance': fields.Raw,
'file_upload': fields.Raw,
'system_parameters': fields.Nested(system_parameters_fields)
}
@marshal_with(parameters_fields)
@@ -43,7 +50,11 @@ class AppParameterApi(WebApiResource):
'retriever_resource': app_model_config.retriever_resource_dict,
'more_like_this': app_model_config.more_like_this_dict,
'user_input_form': app_model_config.user_input_form_list,
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict
'sensitive_word_avoidance': app_model_config.sensitive_word_avoidance_dict,
'file_upload': app_model_config.file_upload_dict,
'system_parameters': {
'image_file_size_limit': current_app.config.get('UPLOAD_IMAGE_FILE_SIZE_LIMIT')
}
}

View File

@@ -30,12 +30,14 @@ class CompletionApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, location='json', default='')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
try:
response = CompletionService.completion(
@@ -88,6 +90,7 @@ class ChatApi(WebApiResource):
parser = reqparse.RequestParser()
parser.add_argument('inputs', type=dict, required=True, location='json')
parser.add_argument('query', type=str, required=True, location='json')
parser.add_argument('files', type=list, required=False, location='json')
parser.add_argument('response_mode', type=str, choices=['blocking', 'streaming'], location='json')
parser.add_argument('conversation_id', type=uuid_value, location='json')
parser.add_argument('retriever_from', type=str, required=False, default='web_app', location='json')
@@ -95,6 +98,7 @@ class ChatApi(WebApiResource):
args = parser.parse_args()
streaming = args['response_mode'] == 'streaming'
args['auto_generate_name'] = False
try:
response = CompletionService.completion(

View File

@@ -67,11 +67,18 @@ class ConversationRenameApi(WebApiResource):
conversation_id = str(c_id)
parser = reqparse.RequestParser()
parser.add_argument('name', type=str, required=True, location='json')
parser.add_argument('name', type=str, required=False, location='json')
parser.add_argument('auto_generate', type=bool, required=False, default='False', location='json')
args = parser.parse_args()
try:
return ConversationService.rename(app_model, conversation_id, end_user, args['name'])
return ConversationService.rename(
app_model,
conversation_id,
end_user,
args['name'],
args['auto_generate']
)
except ConversationNotExistsError:
raise NotFound("Conversation Not Exists.")

View File

@@ -85,4 +85,28 @@ class UnsupportedAudioTypeError(BaseHTTPException):
class ProviderNotSupportSpeechToTextError(BaseHTTPException):
error_code = 'provider_not_support_speech_to_text'
description = "Provider not support speech to text."
code = 400
code = 400
class NoFileUploadedError(BaseHTTPException):
error_code = 'no_file_uploaded'
description = "Please upload your file."
code = 400
class TooManyFilesError(BaseHTTPException):
error_code = 'too_many_files'
description = "Only one file is allowed."
code = 400
class FileTooLargeError(BaseHTTPException):
error_code = 'file_too_large'
description = "File size exceeded. {message}"
code = 413
class UnsupportedFileTypeError(BaseHTTPException):
error_code = 'unsupported_file_type'
description = "File type not allowed."
code = 415

View File

@@ -0,0 +1,36 @@
from flask import request
from flask_restful import marshal_with
from controllers.web import api
from controllers.web.wraps import WebApiResource
from controllers.web.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
UnsupportedFileTypeError
import services
from services.file_service import FileService
from fields.file_fields import file_fields
class FileApi(WebApiResource):
@marshal_with(file_fields)
def post(self, app_model, end_user):
# get file from request
file = request.files['file']
# check file
if 'file' not in request.files:
raise NoFileUploadedError()
if len(request.files) > 1:
raise TooManyFilesError()
try:
upload_file = FileService.upload_file(file, end_user)
except services.errors.file.FileTooLargeError as file_too_large_error:
raise FileTooLargeError(file_too_large_error.description)
except services.errors.file.UnsupportedFileTypeError:
raise UnsupportedFileTypeError()
return upload_file, 201
api.add_resource(FileApi, '/files/upload')

View File

@@ -22,6 +22,7 @@ from services.errors.app import MoreLikeThisDisabledError
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from fields.conversation_fields import message_file_fields
class MessageListApi(WebApiResource):
@@ -54,6 +55,7 @@ class MessageListApi(WebApiResource):
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -8,6 +8,8 @@ from controllers.web.wraps import WebApiResource
from libs.helper import uuid_value, TimestampField
from services.errors.message import MessageNotExistsError
from services.saved_message_service import SavedMessageService
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -18,6 +20,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'created_at': TimestampField
}

View File

@@ -11,7 +11,8 @@ from pydantic import BaseModel
from core.callback_handler.entity.llm_message import LLMMessage
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage
from core.model_providers.models.entity.message import to_prompt_messages, PromptMessage, LCHumanMessageWithFiles, \
ImagePromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.moderation.base import ModerationOutputsResult, ModerationAction
from core.moderation.factory import ModerationFactory
@@ -72,7 +73,12 @@ class LLMCallbackHandler(BaseCallbackHandler):
real_prompts.append({
"role": role,
"text": message.content
"text": message.content,
"files": [{
"type": file.type.value,
"data": file.data[:10] + '...[TRUNCATED]...' + file.data[-10:],
"detail": file.detail.value if isinstance(file, ImagePromptMessageFile) else None,
} for file in (message.files if isinstance(message, LCHumanMessageWithFiles) else [])]
})
self.llm_message.prompt = real_prompts

View File

@@ -13,11 +13,12 @@ from core.callback_handler.llm_callback_handler import LLMCallbackHandler
from core.conversation_message_task import ConversationMessageTask, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.external_data_tool.factory import ExternalDataToolFactory
from core.file.file_obj import FileObj
from core.model_providers.error import LLMBadRequestError
from core.memory.read_only_conversation_token_db_buffer_shared_memory import \
ReadOnlyConversationTokenDBBufferSharedMemory
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.message import PromptMessage, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.orchestrator_rule_parser import OrchestratorRuleParser
from core.prompt.prompt_template import PromptTemplateParser
@@ -30,8 +31,9 @@ from core.moderation.factory import ModerationFactory
class Completion:
@classmethod
def generate(cls, task_id: str, app: App, app_model_config: AppModelConfig, query: str, inputs: dict,
user: Union[Account, EndUser], conversation: Optional[Conversation], streaming: bool,
is_override: bool = False, retriever_from: str = 'dev'):
files: List[FileObj], user: Union[Account, EndUser], conversation: Optional[Conversation],
streaming: bool, is_override: bool = False, retriever_from: str = 'dev',
auto_generate_name: bool = True):
"""
errors: ProviderTokenNotInitError
"""
@@ -64,16 +66,21 @@ class Completion:
is_override=is_override,
inputs=inputs,
query=query,
files=files,
streaming=streaming,
model_instance=final_model_instance
model_instance=final_model_instance,
auto_generate_name=auto_generate_name
)
prompt_message_files = [file.prompt_message_file for file in files]
rest_tokens_for_context_and_memory = cls.get_validate_rest_tokens(
mode=app.mode,
model_instance=final_model_instance,
app_model_config=app_model_config,
query=query,
inputs=inputs
inputs=inputs,
files=prompt_message_files
)
# init orchestrator rule parser
@@ -95,6 +102,7 @@ class Completion:
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=None,
conversation_message_task=conversation_message_task,
memory=memory,
@@ -146,6 +154,7 @@ class Completion:
app_model_config=app_model_config,
query=query,
inputs=inputs,
files=prompt_message_files,
agent_execute_result=agent_execute_result,
conversation_message_task=conversation_message_task,
memory=memory,
@@ -205,21 +214,20 @@ class Completion:
results = {}
with ThreadPoolExecutor() as executor:
futures = {}
for tools in grouped_tools.values():
# Only query the first tool in each group
first_tool = tools[0]
for tool in external_data_tools:
if not tool.get("enabled"):
continue
future = executor.submit(
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, first_tool,
cls.query_external_data_tool, current_app._get_current_object(), tenant_id, app_id, tool,
inputs, query
)
for tool in tools:
futures[future] = tool
futures[future] = tool
for future in concurrent.futures.as_completed(futures):
tool_key, result = future.result()
if tool_key in grouped_tools:
for tool in grouped_tools[tool_key]:
results[tool['variable']] = result
tool_variable, result = future.result()
results[tool_variable] = result
inputs.update(results)
return inputs
@@ -246,9 +254,7 @@ class Completion:
query=query
)
tool_key = (external_data_tool.get("type"), json.dumps(external_data_tool.get("config"), sort_keys=True))
return tool_key, result
return tool_variable, result
@classmethod
def get_query_for_agent(cls, app: App, app_model_config: AppModelConfig, query: str, inputs: dict) -> str:
@@ -260,6 +266,7 @@ class Completion:
@classmethod
def run_final_llm(cls, model_instance: BaseLLM, mode: str, app_model_config: AppModelConfig, query: str,
inputs: dict,
files: List[PromptMessageFile],
agent_execute_result: Optional[AgentExecuteResult],
conversation_message_task: ConversationMessageTask,
memory: Optional[ReadOnlyConversationTokenDBBufferSharedMemory],
@@ -269,10 +276,11 @@ class Completion:
# get llm prompt
if app_model_config.prompt_type == 'simple':
prompt_messages, stop_words = prompt_transform.get_prompt(
mode=mode,
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
@@ -283,6 +291,7 @@ class Completion:
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=agent_execute_result.output if agent_execute_result else None,
memory=memory,
model_instance=model_instance
@@ -340,7 +349,7 @@ class Completion:
@classmethod
def get_validate_rest_tokens(cls, mode: str, model_instance: BaseLLM, app_model_config: AppModelConfig,
query: str, inputs: dict) -> int:
query: str, inputs: dict, files: List[PromptMessageFile]) -> int:
model_limited_tokens = model_instance.model_rules.max_tokens.max
max_tokens = model_instance.get_model_kwargs().max_tokens
@@ -351,15 +360,15 @@ class Completion:
max_tokens = 0
prompt_transform = PromptTransform()
prompt_messages = []
# get prompt without memory and context
if app_model_config.prompt_type == 'simple':
prompt_messages, _ = prompt_transform.get_prompt(
mode=mode,
app_mode=mode,
pre_prompt=app_model_config.pre_prompt,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance
@@ -370,6 +379,7 @@ class Completion:
app_model_config=app_model_config,
inputs=inputs,
query=query,
files=files,
context=None,
memory=None,
model_instance=model_instance

View File

@@ -6,8 +6,9 @@ from core.callback_handler.entity.agent_loop import AgentLoop
from core.callback_handler.entity.dataset_query import DatasetQueryObj
from core.callback_handler.entity.llm_message import LLMMessage
from core.callback_handler.entity.chain_result import ChainResult
from core.file.file_obj import FileObj
from core.model_providers.model_factory import ModelFactory
from core.model_providers.models.entity.message import to_prompt_messages, MessageType
from core.model_providers.models.entity.message import to_prompt_messages, MessageType, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
@@ -16,13 +17,14 @@ from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.dataset import DatasetQuery
from models.model import AppModelConfig, Conversation, Account, Message, EndUser, App, MessageAgentThought, \
MessageChain, DatasetRetrieverResource
MessageChain, DatasetRetrieverResource, MessageFile
class ConversationMessageTask:
def __init__(self, task_id: str, app: App, app_model_config: AppModelConfig, user: Account,
inputs: dict, query: str, streaming: bool, model_instance: BaseLLM,
conversation: Optional[Conversation] = None, is_override: bool = False):
inputs: dict, query: str, files: List[FileObj], streaming: bool,
model_instance: BaseLLM, conversation: Optional[Conversation] = None, is_override: bool = False,
auto_generate_name: bool = True):
self.start_at = time.perf_counter()
self.task_id = task_id
@@ -35,6 +37,7 @@ class ConversationMessageTask:
self.user = user
self.inputs = inputs
self.query = query
self.files = files
self.streaming = streaming
self.conversation = conversation
@@ -45,6 +48,7 @@ class ConversationMessageTask:
self.message = None
self.retriever_resource = None
self.auto_generate_name = auto_generate_name
self.model_dict = self.app_model_config.model_dict
self.provider_name = self.model_dict.get('provider')
@@ -100,7 +104,7 @@ class ConversationMessageTask:
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
mode=self.mode,
name='',
name='New conversation',
inputs=self.inputs,
introduction=introduction,
system_instruction=system_instruction,
@@ -142,6 +146,19 @@ class ConversationMessageTask:
db.session.add(self.message)
db.session.commit()
for file in self.files:
message_file = MessageFile(
message_id=self.message.id,
type=file.type.value,
transfer_method=file.transfer_method.value,
url=file.url,
upload_file_id=file.upload_file_id,
created_by_role=('account' if isinstance(self.user, Account) else 'end_user'),
created_by=self.user.id
)
db.session.add(message_file)
db.session.commit()
def append_message_text(self, text: str):
if text is not None:
self._pub_handler.pub_text(text)
@@ -176,7 +193,8 @@ class ConversationMessageTask:
message_was_created.send(
self.message,
conversation=self.conversation,
is_first_message=self.is_new_conversation
is_first_message=self.is_new_conversation,
auto_generate_name=self.auto_generate_name
)
if not by_stopped:

View File

79
api/core/file/file_obj.py Normal file
View File

@@ -0,0 +1,79 @@
import enum
from typing import Optional
from pydantic import BaseModel
from core.file.upload_file_parser import UploadFileParser
from core.model_providers.models.entity.message import PromptMessageFile, ImagePromptMessageFile
from extensions.ext_database import db
from models.model import UploadFile
class FileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in FileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileTransferMethod(enum.Enum):
REMOTE_URL = 'remote_url'
LOCAL_FILE = 'local_file'
@staticmethod
def value_of(value):
for member in FileTransferMethod:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class FileObj(BaseModel):
id: Optional[str]
tenant_id: str
type: FileType
transfer_method: FileTransferMethod
url: Optional[str]
upload_file_id: Optional[str]
file_config: dict
@property
def data(self) -> Optional[str]:
return self._get_data()
@property
def preview_url(self) -> Optional[str]:
return self._get_data(force_url=True)
@property
def prompt_message_file(self) -> PromptMessageFile:
if self.type == FileType.IMAGE:
image_config = self.file_config.get('image')
return ImagePromptMessageFile(
data=self.data,
detail=ImagePromptMessageFile.DETAIL.HIGH
if image_config.get("detail") == "high" else ImagePromptMessageFile.DETAIL.LOW
)
def _get_data(self, force_url: bool = False) -> Optional[str]:
if self.type == FileType.IMAGE:
if self.transfer_method == FileTransferMethod.REMOTE_URL:
return self.url
elif self.transfer_method == FileTransferMethod.LOCAL_FILE:
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == self.upload_file_id,
UploadFile.tenant_id == self.tenant_id
).first())
return UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=force_url
)
return None

View File

@@ -0,0 +1,180 @@
from typing import List, Union, Optional, Dict
import requests
from core.file.file_obj import FileObj, FileType, FileTransferMethod
from core.file.upload_file_parser import SUPPORT_EXTENSIONS
from extensions.ext_database import db
from models.account import Account
from models.model import MessageFile, EndUser, AppModelConfig, UploadFile
class MessageFileParser:
def __init__(self, tenant_id: str, app_id: str) -> None:
self.tenant_id = tenant_id
self.app_id = app_id
def validate_and_transform_files_arg(self, files: List[dict], app_model_config: AppModelConfig,
user: Union[Account, EndUser]) -> List[FileObj]:
"""
validate and transform files arg
:param files:
:param app_model_config:
:param user:
:return:
"""
file_upload_config = app_model_config.file_upload_dict
for file in files:
if not isinstance(file, dict):
raise ValueError('Invalid file format, must be dict')
if not file.get('type'):
raise ValueError('Missing file type')
FileType.value_of(file.get('type'))
if not file.get('transfer_method'):
raise ValueError('Missing file transfer method')
FileTransferMethod.value_of(file.get('transfer_method'))
if file.get('transfer_method') == FileTransferMethod.REMOTE_URL.value:
if not file.get('url'):
raise ValueError('Missing file url')
if not file.get('url').startswith('http'):
raise ValueError('Invalid file url')
if file.get('transfer_method') == FileTransferMethod.LOCAL_FILE.value and not file.get('upload_file_id'):
raise ValueError('Missing file upload_file_id')
# transform files to file objs
type_file_objs = self._to_file_objs(files, file_upload_config)
# validate files
new_files = []
for file_type, file_objs in type_file_objs.items():
if file_type == FileType.IMAGE:
# parse and validate files
image_config = file_upload_config.get('image')
# check if image file feature is enabled
if not image_config['enabled']:
continue
# Validate number of files
if len(files) > image_config['number_limits']:
raise ValueError(f"Number of image files exceeds the maximum limit {image_config['number_limits']}")
for file_obj in file_objs:
# Validate transfer method
if file_obj.transfer_method.value not in image_config['transfer_methods']:
raise ValueError(f'Invalid transfer method: {file_obj.transfer_method.value}')
# Validate file type
if file_obj.type != FileType.IMAGE:
raise ValueError(f'Invalid file type: {file_obj.type}')
if file_obj.transfer_method == FileTransferMethod.REMOTE_URL:
# check remote url valid and is image
result, error = self._check_image_remote_url(file_obj.url)
if result is False:
raise ValueError(error)
elif file_obj.transfer_method == FileTransferMethod.LOCAL_FILE:
# get upload file from upload_file_id
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == file_obj.upload_file_id,
UploadFile.tenant_id == self.tenant_id,
UploadFile.created_by == user.id,
UploadFile.created_by_role == ('account' if isinstance(user, Account) else 'end_user'),
UploadFile.extension.in_(SUPPORT_EXTENSIONS)
).first())
# check upload file is belong to tenant and user
if not upload_file:
raise ValueError('Invalid upload file')
new_files.append(file_obj)
# return all file objs
return new_files
def transform_message_files(self, files: List[MessageFile], app_model_config: Optional[AppModelConfig]) -> List[FileObj]:
"""
transform message files
:param files:
:param app_model_config:
:return:
"""
# transform files to file objs
type_file_objs = self._to_file_objs(files, app_model_config.file_upload_dict)
# return all file objs
return [file_obj for file_objs in type_file_objs.values() for file_obj in file_objs]
def _to_file_objs(self, files: List[Union[Dict, MessageFile]],
file_upload_config: dict) -> Dict[FileType, List[FileObj]]:
"""
transform files to file objs
:param files:
:param file_upload_config:
:return:
"""
type_file_objs: Dict[FileType, List[FileObj]] = {
# Currently only support image
FileType.IMAGE: []
}
if not files:
return type_file_objs
# group by file type and convert file args or message files to FileObj
for file in files:
file_obj = self._to_file_obj(file, file_upload_config)
if file_obj.type not in type_file_objs:
continue
type_file_objs[file_obj.type].append(file_obj)
return type_file_objs
def _to_file_obj(self, file: Union[dict, MessageFile], file_upload_config: dict) -> FileObj:
"""
transform file to file obj
:param file:
:return:
"""
if isinstance(file, dict):
transfer_method = FileTransferMethod.value_of(file.get('transfer_method'))
return FileObj(
tenant_id=self.tenant_id,
type=FileType.value_of(file.get('type')),
transfer_method=transfer_method,
url=file.get('url') if transfer_method == FileTransferMethod.REMOTE_URL else None,
upload_file_id=file.get('upload_file_id') if transfer_method == FileTransferMethod.LOCAL_FILE else None,
file_config=file_upload_config
)
else:
return FileObj(
id=file.id,
tenant_id=self.tenant_id,
type=FileType.value_of(file.type),
transfer_method=FileTransferMethod.value_of(file.transfer_method),
url=file.url,
upload_file_id=file.upload_file_id or None,
file_config=file_upload_config
)
def _check_image_remote_url(self, url):
try:
headers = {
"User-Agent": "Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124 Safari/537.36"
}
response = requests.head(url, headers=headers, allow_redirects=True)
if response.status_code == 200:
return True, ""
else:
return False, "URL does not exist."
except requests.RequestException as e:
return False, f"Error checking URL: {e}"

View File

@@ -0,0 +1,79 @@
import base64
import hashlib
import hmac
import logging
import os
import time
from typing import Optional
from flask import current_app
from extensions.ext_storage import storage
SUPPORT_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
class UploadFileParser:
@classmethod
def get_image_data(cls, upload_file, force_url: bool = False) -> Optional[str]:
if not upload_file:
return None
if upload_file.extension not in SUPPORT_EXTENSIONS:
return None
if current_app.config['MULTIMODAL_SEND_IMAGE_FORMAT'] == 'url' or force_url:
return cls.get_signed_temp_image_url(upload_file)
else:
# get image file base64
try:
data = storage.load(upload_file.key)
except FileNotFoundError:
logging.error(f'File not found: {upload_file.key}')
return None
encoded_string = base64.b64encode(data).decode('utf-8')
return f'data:{upload_file.mime_type};base64,{encoded_string}'
@classmethod
def get_signed_temp_image_url(cls, upload_file) -> str:
"""
get signed url from upload file
:param upload_file: UploadFile object
:return:
"""
base_url = current_app.config.get('FILES_URL')
image_preview_url = f'{base_url}/files/{upload_file.id}/image-preview'
timestamp = str(int(time.time()))
nonce = os.urandom(16).hex()
data_to_sign = f"image-preview|{upload_file.id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
encoded_sign = base64.urlsafe_b64encode(sign).decode()
return f"{image_preview_url}?timestamp={timestamp}&nonce={nonce}&sign={encoded_sign}"
@classmethod
def verify_image_file_signature(cls, upload_file_id: str, timestamp: str, nonce: str, sign: str) -> bool:
"""
verify signature
:param upload_file_id: file id
:param timestamp: timestamp
:param nonce: nonce
:param sign: signature
:return:
"""
data_to_sign = f"image-preview|{upload_file_id}|{timestamp}|{nonce}"
secret_key = current_app.config['SECRET_KEY'].encode()
recalculated_sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
recalculated_encoded_sign = base64.urlsafe_b64encode(recalculated_sign).decode()
# verify signature
if sign != recalculated_encoded_sign:
return False
current_time = int(time.time())
return current_time - int(timestamp) <= 300 # expired after 5 minutes

View File

@@ -16,7 +16,7 @@ from core.prompt.prompts import CONVERSATION_TITLE_PROMPT, GENERATOR_QA_PROMPT
class LLMGenerator:
@classmethod
def generate_conversation_name(cls, tenant_id: str, query, answer):
def generate_conversation_name(cls, tenant_id: str, query):
prompt = CONVERSATION_TITLE_PROMPT
if len(query) > 2000:
@@ -40,8 +40,12 @@ class LLMGenerator:
result_dict = json.loads(answer)
answer = result_dict['Your Output']
name = answer.strip()
return answer.strip()
if len(name) > 75:
name = name[:75] + '...'
return name
@classmethod
def generate_suggested_questions_after_answer(cls, tenant_id: str, histories: str):

View File

@@ -89,22 +89,6 @@ class IndexingRunner:
dataset_document.stopped_at = datetime.datetime.utcnow()
db.session.commit()
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
result = []
for match in matches:
q = match[0]
a = match[1]
if q and a:
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})
return result
def run_in_splitting_status(self, dataset_document: DatasetDocument):
"""Run the indexing process when the index_status is splitting."""
try:
@@ -647,21 +631,16 @@ class IndexingRunner:
return text
def format_split_text(self, text):
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)" # 匹配Q和A的正则表达式
matches = re.findall(regex, text, re.MULTILINE) # 获取所有匹配到的结果
regex = r"Q\d+:\s*(.*?)\s*A\d+:\s*([\s\S]*?)(?=Q|$)"
matches = re.findall(regex, text, re.MULTILINE)
result = [] # 存储最终的结果
for match in matches:
q = match[0]
a = match[1]
if q and a:
# 如果Q和A都存在就将其添加到结果中
result.append({
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
})
return result
return [
{
"question": q,
"answer": re.sub(r"\n\s*", "\n", a.strip())
}
for q, a in matches if q and a
]
def _build_index(self, dataset: Dataset, dataset_document: DatasetDocument, documents: List[Document]) -> None:
"""

View File

@@ -3,6 +3,7 @@ from typing import Any, List, Dict
from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import get_buffer_string, BaseMessage
from core.file.message_file_parser import MessageFileParser
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_lc_messages
from core.model_providers.models.llm.base import BaseLLM
from extensions.ext_database import db
@@ -21,6 +22,8 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
@property
def buffer(self) -> List[BaseMessage]:
"""String buffer of memory."""
app_model = self.conversation.app
# fetch limited messages desc, and return reversed
messages = db.session.query(Message).filter(
Message.conversation_id == self.conversation.id,
@@ -28,10 +31,25 @@ class ReadOnlyConversationTokenDBBufferSharedMemory(BaseChatMemory):
).order_by(Message.created_at.desc()).limit(self.message_limit).all()
messages = list(reversed(messages))
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=self.conversation.app_id)
chat_messages: List[PromptMessage] = []
for message in messages:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
files = message.message_files
if files:
file_objs = message_file_parser.transform_message_files(
files, message.app_model_config
)
prompt_message_files = [file_obj.prompt_message_file for file_obj in file_objs]
chat_messages.append(PromptMessage(
content=message.query,
type=MessageType.USER,
files=prompt_message_files
))
else:
chat_messages.append(PromptMessage(content=message.query, type=MessageType.USER))
chat_messages.append(PromptMessage(content=message.answer, type=MessageType.ASSISTANT))
if not chat_messages:

View File

@@ -1,4 +1,5 @@
import enum
from typing import Any, cast, Union, List, Dict
from langchain.schema import HumanMessage, AIMessage, SystemMessage, BaseMessage, FunctionMessage
from pydantic import BaseModel
@@ -18,17 +19,53 @@ class MessageType(enum.Enum):
SYSTEM = 'system'
class PromptMessageFileType(enum.Enum):
IMAGE = 'image'
@staticmethod
def value_of(value):
for member in PromptMessageFileType:
if member.value == value:
return member
raise ValueError(f"No matching enum found for value '{value}'")
class PromptMessageFile(BaseModel):
type: PromptMessageFileType
data: Any
class ImagePromptMessageFile(PromptMessageFile):
class DETAIL(enum.Enum):
LOW = 'low'
HIGH = 'high'
type: PromptMessageFileType = PromptMessageFileType.IMAGE
detail: DETAIL = DETAIL.LOW
class PromptMessage(BaseModel):
type: MessageType = MessageType.USER
content: str = ''
files: list[PromptMessageFile] = []
function_call: dict = None
class LCHumanMessageWithFiles(HumanMessage):
# content: Union[str, List[Union[str, Dict]]]
content: str
files: list[PromptMessageFile]
def to_lc_messages(messages: list[PromptMessage]):
lc_messages = []
for message in messages:
if message.type == MessageType.USER:
lc_messages.append(HumanMessage(content=message.content))
if not message.files:
lc_messages.append(HumanMessage(content=message.content))
else:
lc_messages.append(LCHumanMessageWithFiles(content=message.content, files=message.files))
elif message.type == MessageType.ASSISTANT:
additional_kwargs = {}
if message.function_call:
@@ -44,7 +81,14 @@ def to_prompt_messages(messages: list[BaseMessage]):
prompt_messages = []
for message in messages:
if isinstance(message, HumanMessage):
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
if isinstance(message, LCHumanMessageWithFiles):
prompt_messages.append(PromptMessage(
content=message.content,
type=MessageType.USER,
files=message.files
))
else:
prompt_messages.append(PromptMessage(content=message.content, type=MessageType.USER))
elif isinstance(message, AIMessage):
message_kwargs = {
'content': message.content,

View File

@@ -8,3 +8,4 @@ class ProviderQuotaUnit(Enum):
class ModelFeature(Enum):
AGENT_THOUGHT = 'agent_thought'
VISION = 'vision'

View File

@@ -19,6 +19,13 @@ from core.model_providers.models.entity.model_params import ModelMode, ModelKwar
AZURE_OPENAI_API_VERSION = '2023-07-01-preview'
FUNCTION_CALL_MODELS = [
'gpt-4',
'gpt-4-32k',
'gpt-35-turbo',
'gpt-35-turbo-16k'
]
class AzureOpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
name: str,
@@ -157,3 +164,7 @@ class AzureOpenAIModel(BaseLLM):
@property
def support_streaming(self):
return True
@property
def support_function_call(self):
return self.base_model_name in FUNCTION_CALL_MODELS

View File

@@ -310,6 +310,10 @@ class BaseLLM(BaseProviderModel):
def support_streaming(self):
return False
@property
def support_function_call(self):
return False
def _get_prompt_from_messages(self, messages: List[PromptMessage],
model_mode: Optional[ModelMode] = None) -> Union[str , List[BaseMessage]]:
if not model_mode:

View File

@@ -1,11 +1,9 @@
import decimal
import logging
from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
@@ -23,21 +21,36 @@ COMPLETION_MODELS = [
]
CHAT_MODELS = [
'gpt-4-1106-preview', # 128,000 tokens
'gpt-4-vision-preview', # 128,000 tokens
'gpt-4', # 8,192 tokens
'gpt-4-32k', # 32,768 tokens
'gpt-3.5-turbo-1106', # 16,384 tokens
'gpt-3.5-turbo', # 4,096 tokens
'gpt-3.5-turbo-16k', # 16,384 tokens
]
MODEL_MAX_TOKENS = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,
'text-davinci-003': 4097,
}
FUNCTION_CALL_MODELS = [
'gpt-4-1106-preview',
'gpt-4',
'gpt-4-32k',
'gpt-3.5-turbo-1106',
'gpt-3.5-turbo',
'gpt-3.5-turbo-16k'
]
class OpenAIModel(BaseLLM):
def __init__(self, model_provider: BaseModelProvider,
@@ -50,7 +63,6 @@ class OpenAIModel(BaseLLM):
else:
self.model_mode = ModelMode.CHAT
# TODO load price config from configs(db)
super().__init__(model_provider, name, model_kwargs, streaming, callbacks)
def _init_client(self) -> Any:
@@ -100,7 +112,7 @@ class OpenAIModel(BaseLLM):
:param callbacks:
:return:
"""
if self.name == 'gpt-4' \
if self.name.startswith('gpt-4') \
and self.model_provider.provider.provider_type == ProviderType.SYSTEM.value \
and self.model_provider.provider.quota_type == ProviderQuotaType.TRIAL.value:
raise ModelCurrentlyNotSupportError("Dify Hosted OpenAI GPT-4 currently not support.")
@@ -175,6 +187,10 @@ class OpenAIModel(BaseLLM):
def support_streaming(self):
return True
@property
def support_function_call(self):
return self.name in FUNCTION_CALL_MODELS
# def is_model_valid_or_raise(self):
# """
# check is a valid model.

View File

@@ -41,9 +41,17 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-1106',
'name': 'gpt-3.5-turbo-1106',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-3.5-turbo-instruct',
'name': 'GPT-3.5-Turbo-Instruct',
'name': 'gpt-3.5-turbo-instruct',
'mode': ModelMode.COMPLETION.value,
},
{
@@ -62,6 +70,22 @@ class OpenAIProvider(BaseModelProvider):
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-1106-preview',
'name': 'gpt-4-1106-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.AGENT_THOUGHT.value
]
},
{
'id': 'gpt-4-vision-preview',
'name': 'gpt-4-vision-preview',
'mode': ModelMode.CHAT.value,
'features': [
ModelFeature.VISION.value
]
},
{
'id': 'gpt-4-32k',
'name': 'gpt-4-32k',
@@ -79,7 +103,7 @@ class OpenAIProvider(BaseModelProvider):
if self.provider.provider_type == ProviderType.SYSTEM.value \
and self.provider.quota_type == ProviderQuotaType.TRIAL.value:
models = [item for item in models if item['id'] not in ['gpt-4', 'gpt-4-32k']]
models = [item for item in models if not item['id'].startswith('gpt-4')]
return models
elif model_type == ModelType.EMBEDDINGS:
@@ -141,8 +165,11 @@ class OpenAIProvider(BaseModelProvider):
:return:
"""
model_max_tokens = {
'gpt-4-1106-preview': 128000,
'gpt-4-vision-preview': 128000,
'gpt-4': 8192,
'gpt-4-32k': 32768,
'gpt-3.5-turbo-1106': 16384,
'gpt-3.5-turbo': 4096,
'gpt-3.5-turbo-instruct': 4097,
'gpt-3.5-turbo-16k': 16384,

View File

@@ -24,12 +24,30 @@
"unit": "0.001",
"currency": "USD"
},
"gpt-4-1106-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-4-vision-preview": {
"prompt": "0.01",
"completion": "0.03",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo": {
"prompt": "0.0015",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-1106": {
"prompt": "0.0010",
"completion": "0.002",
"unit": "0.001",
"currency": "USD"
},
"gpt-3.5-turbo-instruct": {
"prompt": "0.0015",
"completion": "0.002",

View File

@@ -73,8 +73,7 @@ class OrchestratorRuleParser:
planning_strategy = PlanningStrategy(agent_mode_config.get('strategy', 'router'))
# only OpenAI chat model (include Azure) support function call, use ReACT instead
if agent_model_instance.model_mode != ModelMode.CHAT \
or agent_model_instance.model_provider.provider_name not in ['openai', 'azure_openai']:
if not agent_model_instance.support_function_call:
if planning_strategy == PlanningStrategy.FUNCTION_CALL:
planning_strategy = PlanningStrategy.REACT
elif planning_strategy == PlanningStrategy.ROUTER:

View File

@@ -8,7 +8,7 @@ from langchain.memory.chat_memory import BaseChatMemory
from langchain.schema import BaseMessage
from core.model_providers.models.entity.model_params import ModelMode
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages
from core.model_providers.models.entity.message import PromptMessage, MessageType, to_prompt_messages, PromptMessageFile
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.models.llm.huggingface_hub_model import HuggingfaceHubModel
@@ -16,32 +16,57 @@ from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.models.llm.xinference_model import XinferenceModel
from core.prompt.prompt_builder import PromptBuilder
from core.prompt.prompt_template import PromptTemplateParser
from models.model import AppModelConfig
class AppMode(enum.Enum):
COMPLETION = 'completion'
CHAT = 'chat'
class PromptTransform:
def get_prompt(self, mode: str,
pre_prompt: str, inputs: dict,
def get_prompt(self,
app_mode: str,
pre_prompt: str,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> \
Tuple[List[PromptMessage], Optional[List[str]]]:
prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(mode, model_instance))
prompt, stops = self._get_prompt_and_stop(prompt_rules, pre_prompt, inputs, query, context, memory, model_instance)
return [PromptMessage(content=prompt)], stops
def get_advanced_prompt(self,
app_mode: str,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_mode_enum = AppMode(app_mode)
model_mode_enum = model_instance.model_mode
prompt_rules = self._read_prompt_rules_from_file(self._prompt_file_name(app_mode, model_instance))
if app_mode_enum == AppMode.CHAT and model_mode_enum == ModelMode.CHAT:
stops = None
prompt_messages = self._get_simple_chat_app_chat_model_prompt_messages(prompt_rules, pre_prompt, inputs,
query, context, memory,
model_instance, files)
else:
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
prompt_messages = self._get_simple_others_prompt_messages(prompt_rules, pre_prompt, inputs, query, context,
memory,
model_instance, files)
return prompt_messages, stops
def get_advanced_prompt(self,
app_mode: str,
app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
model_mode = app_model_config.model_dict['mode']
app_mode_enum = AppMode(app_mode)
@@ -51,15 +76,20 @@ class PromptTransform:
if app_mode_enum == AppMode.CHAT:
if model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_completion_model_prompt_messages(app_model_config, inputs, query,
files, context, memory,
model_instance)
elif model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, context, memory, model_instance)
prompt_messages = self._get_chat_app_chat_model_prompt_messages(app_model_config, inputs, query, files,
context, memory, model_instance)
elif app_mode_enum == AppMode.COMPLETION:
if model_mode_enum == ModelMode.CHAT:
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs, context)
prompt_messages = self._get_completion_app_chat_model_prompt_messages(app_model_config, inputs,
files, context)
elif model_mode_enum == ModelMode.COMPLETION:
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs, context)
prompt_messages = self._get_completion_app_completion_model_prompt_messages(app_model_config, inputs,
files, context)
return prompt_messages
def _get_history_messages_from_memory(self, memory: BaseChatMemory,
@@ -71,7 +101,7 @@ class PromptTransform:
return external_context[memory_key]
def _get_history_messages_list_from_memory(self, memory: BaseChatMemory,
max_token_limit: int) -> List[PromptMessage]:
max_token_limit: int) -> List[PromptMessage]:
"""Get memory messages."""
memory.max_token_limit = max_token_limit
memory.return_messages = True
@@ -79,7 +109,7 @@ class PromptTransform:
external_context = memory.load_memory_variables({})
memory.return_messages = False
return to_prompt_messages(external_context[memory_key])
def _prompt_file_name(self, mode: str, model_instance: BaseLLM) -> str:
# baichuan
if isinstance(model_instance, BaichuanModel):
@@ -94,13 +124,13 @@ class PromptTransform:
return 'common_completion'
else:
return 'common_chat'
def _prompt_file_name_for_baichuan(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def _read_prompt_rules_from_file(self, prompt_name: str) -> dict:
# Get the absolute path of the subdirectory
prompt_path = os.path.join(
@@ -111,12 +141,53 @@ class PromptTransform:
# Open the JSON file and read its content
with open(json_file_path, 'r') as json_file:
return json.load(json_file)
def _get_prompt_and_stop(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> Tuple[str, Optional[list]]:
def _get_simple_chat_app_chat_model_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM,
files: List[PromptMessageFile]) -> List[PromptMessage]:
prompt_messages = []
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
context_prompt_content = prompt_template.format(
{'context': context}
)
pre_prompt_content = ''
if pre_prompt:
prompt_template = PromptTemplateParser(template=pre_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
pre_prompt_content = prompt_template.format(
prompt_inputs
)
prompt = ''
for order in prompt_rules['system_prompt_orders']:
if order == 'context_prompt':
prompt += context_prompt_content
elif order == 'pre_prompt':
prompt += pre_prompt_content
prompt = re.sub(r'<\|.*?\|>', '', prompt)
prompt_messages.append(PromptMessage(type=MessageType.SYSTEM, content=prompt))
self._append_chat_histories(memory, prompt_messages, model_instance)
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
return prompt_messages
def _get_simple_others_prompt_messages(self, prompt_rules: dict, pre_prompt: str, inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM,
files: List[PromptMessageFile]) -> List[PromptMessage]:
context_prompt_content = ''
if context and 'context_prompt' in prompt_rules:
prompt_template = PromptTemplateParser(template=prompt_rules['context_prompt'])
@@ -175,16 +246,12 @@ class PromptTransform:
prompt = re.sub(r'<\|.*?\|>', '', prompt)
stops = prompt_rules.get('stops')
if stops is not None and len(stops) == 0:
stops = None
return [PromptMessage(content=prompt, files=files)]
return prompt, stops
def _set_context_variable(self, context: str, prompt_template: PromptTemplateParser, prompt_inputs: dict) -> None:
if '#context#' in prompt_template.variable_keys:
if context:
prompt_inputs['#context#'] = context
prompt_inputs['#context#'] = context
else:
prompt_inputs['#context#'] = ''
@@ -195,17 +262,18 @@ class PromptTransform:
else:
prompt_inputs['#query#'] = ''
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict, model_instance: BaseLLM) -> None:
def _set_histories_variable(self, memory: BaseChatMemory, raw_prompt: str, conversation_histories_role: dict,
prompt_template: PromptTemplateParser, prompt_inputs: dict,
model_instance: BaseLLM) -> None:
if '#histories#' in prompt_template.variable_keys:
if memory:
tmp_human_message = PromptBuilder.to_human_message(
prompt_content=raw_prompt,
inputs={ '#histories#': '', **prompt_inputs }
inputs={'#histories#': '', **prompt_inputs}
)
rest_tokens = self._calculate_rest_token(tmp_human_message, model_instance)
memory.human_prefix = conversation_histories_role['user_prefix']
memory.ai_prefix = conversation_histories_role['assistant_prefix']
histories = self._get_history_messages_from_memory(memory, rest_tokens)
@@ -213,7 +281,8 @@ class PromptTransform:
else:
prompt_inputs['#histories#'] = ''
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage], model_instance: BaseLLM) -> None:
def _append_chat_histories(self, memory: BaseChatMemory, prompt_messages: list[PromptMessage],
model_instance: BaseLLM) -> None:
if memory:
rest_tokens = self._calculate_rest_token(prompt_messages, model_instance)
@@ -242,19 +311,19 @@ class PromptTransform:
return prompt
def _get_chat_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
conversation_histories_role = app_model_config.completion_prompt_config_dict['conversation_histories_role']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -262,28 +331,29 @@ class PromptTransform:
self._set_query_variable(query, prompt_template, prompt_inputs)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs, model_instance)
self._set_histories_variable(memory, raw_prompt, conversation_histories_role, prompt_template, prompt_inputs,
model_instance)
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=prompt, files=files))
return prompt_messages
def _get_chat_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
query: str,
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
query: str,
files: List[PromptMessageFile],
context: Optional[str],
memory: Optional[BaseChatMemory],
model_instance: BaseLLM) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -292,23 +362,23 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
self._append_chat_histories(memory, prompt_messages, model_instance)
prompt_messages.append(PromptMessage(type = MessageType.USER ,content=query))
prompt_messages.append(PromptMessage(type=MessageType.USER, content=query, files=files))
return prompt_messages
def _get_completion_app_completion_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt = app_model_config.completion_prompt_config_dict['prompt']['text']
prompt_messages = []
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -316,21 +386,21 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(MessageType.USER) ,content=prompt))
prompt_messages.append(PromptMessage(type=MessageType(MessageType.USER), content=prompt, files=files))
return prompt_messages
def _get_completion_app_chat_model_prompt_messages(self,
app_model_config: str,
inputs: dict,
context: Optional[str]) -> List[PromptMessage]:
app_model_config: AppModelConfig,
inputs: dict,
files: List[PromptMessageFile],
context: Optional[str]) -> List[PromptMessage]:
raw_prompt_list = app_model_config.chat_prompt_config_dict['prompt']
prompt_messages = []
for prompt_item in raw_prompt_list:
raw_prompt = prompt_item['text']
prompt = ''
prompt_template = PromptTemplateParser(template=raw_prompt)
prompt_inputs = {k: inputs[k] for k in prompt_template.variable_keys if k in inputs}
@@ -339,6 +409,11 @@ class PromptTransform:
prompt = self._format_prompt(prompt_template, prompt_inputs)
prompt_messages.append(PromptMessage(type = MessageType(prompt_item['role']) ,content=prompt))
return prompt_messages
prompt_messages.append(PromptMessage(type=MessageType(prompt_item['role']), content=prompt))
for prompt_message in prompt_messages[::-1]:
if prompt_message.type == MessageType.USER:
prompt_message.files = files
break
return prompt_messages

View File

@@ -1,10 +1,13 @@
import os
from typing import Dict, Any, Optional, Union, Tuple
from typing import Dict, Any, Optional, Union, Tuple, List, cast
from langchain.chat_models import ChatOpenAI
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage, FunctionMessage
from pydantic import root_validator
from core.model_providers.models.entity.message import LCHumanMessageWithFiles, PromptMessageFileType, ImagePromptMessageFile
class EnhanceChatOpenAI(ChatOpenAI):
request_timeout: Optional[Union[float, Tuple[float, float]]] = (5.0, 300.0)
@@ -48,3 +51,102 @@ class EnhanceChatOpenAI(ChatOpenAI):
"api_key": self.openai_api_key,
"organization": self.openai_organization if self.openai_organization else None,
}
def _create_message_dicts(
self, messages: List[BaseMessage], stop: Optional[List[str]]
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._client_params
if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
message_dicts = [self._convert_message_to_dict(m) for m in messages]
return message_dicts, params
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Calculate num tokens for gpt-3.5-turbo and gpt-4 with tiktoken package.
Official documentation: https://github.com/openai/openai-cookbook/blob/
main/examples/How_to_format_inputs_to_ChatGPT_models.ipynb"""
model, encoding = self._get_encoding_model()
if model.startswith("gpt-3.5-turbo-0301"):
# every message follows <im_start>{role/name}\n{content}<im_end>\n
tokens_per_message = 4
# if there's a name, the role is omitted
tokens_per_name = -1
elif model.startswith("gpt-3.5-turbo") or model.startswith("gpt-4"):
tokens_per_message = 3
tokens_per_name = 1
else:
raise NotImplementedError(
f"get_num_tokens_from_messages() is not presently implemented "
f"for model {model}."
"See https://github.com/openai/openai-python/blob/main/chatml.md for "
"information on how messages are converted to tokens."
)
num_tokens = 0
messages_dict = [self._convert_message_to_dict(m) for m in messages]
for message in messages_dict:
num_tokens += tokens_per_message
for key, value in message.items():
# Cast str(value) in case the message value is not a string
# This occurs with function messages
# TODO: The current token calculation method for the image type is not implemented,
# which need to download the image and then get the resolution for calculation,
# and will increase the request delay
if isinstance(value, list):
text = ''
for item in value:
if isinstance(item, dict) and item['type'] == 'text':
text += item['text']
value = text
num_tokens += len(encoding.encode(str(value)))
if key == "name":
num_tokens += tokens_per_name
# every reply is primed with <im_start>assistant
num_tokens += 3
return num_tokens
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, LCHumanMessageWithFiles):
content = [
{
"type": "text",
"text": message.content
}
]
for file in message.files:
if file.type == PromptMessageFileType.IMAGE:
file = cast(ImagePromptMessageFile, file)
content.append({
"type": "image_url",
"image_url": {
"url": file.data,
"detail": file.detail.value
}
})
message_dict = {"role": "user", "content": content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
if "function_call" in message.additional_kwargs:
message_dict["function_call"] = message.additional_kwargs["function_call"]
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {
"role": "function",
"content": message.content,
"name": message.name,
}
else:
raise ValueError(f"Got unknown type {message}")
if "name" in message.additional_kwargs:
message_dict["name"] = message.additional_kwargs["name"]
return message_dict

View File

@@ -17,8 +17,12 @@ import websocket
class SparkLLMClient:
def __init__(self, model_name: str, app_id: str, api_key: str, api_secret: str, api_domain: Optional[str] = None):
domain = 'spark-api.xf-yun.com' if not api_domain else api_domain
domain = 'spark-api.xf-yun.com'
endpoint = 'chat'
if api_domain:
domain = api_domain
if model_name == 'spark-v3':
endpoint = 'multimodal'
model_api_configs = {
'spark': {
@@ -38,7 +42,7 @@ class SparkLLMClient:
api_version = model_api_configs[model_name]['version']
self.chat_domain = model_api_configs[model_name]['chat_domain']
self.api_base = f"wss://{domain}/{api_version}/chat"
self.api_base = f"wss://{domain}/{api_version}/{endpoint}"
self.app_id = app_id
self.ws_url = self.create_url(
urlparse(self.api_base).netloc,

View File

@@ -1,5 +1,3 @@
import logging
from core.generator.llm_generator import LLMGenerator
from events.message_event import message_was_created
from extensions.ext_database import db
@@ -10,8 +8,9 @@ def handle(sender, **kwargs):
message = sender
conversation = kwargs.get('conversation')
is_first_message = kwargs.get('is_first_message')
auto_generate_name = kwargs.get('auto_generate_name', True)
if is_first_message:
if auto_generate_name and is_first_message:
if conversation.mode == 'chat':
app_model = conversation.app
if not app_model:
@@ -19,14 +18,9 @@ def handle(sender, **kwargs):
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query, message.answer)
if len(name) > 75:
name = name[:75] + '...'
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
conversation.name = name
except:
conversation.name = 'New conversation'
pass
db.session.add(conversation)
db.session.commit()

View File

@@ -1,6 +1,7 @@
import os
import shutil
from contextlib import closing
from typing import Union, Generator
import boto3
from botocore.exceptions import ClientError
@@ -45,7 +46,13 @@ class Storage:
with open(os.path.join(os.getcwd(), filename), "wb") as f:
f.write(data)
def load(self, filename):
def load(self, filename: str, stream: bool = False) -> Union[bytes, Generator]:
if stream:
return self.load_stream(filename)
else:
return self.load_once(filename)
def load_once(self, filename: str) -> bytes:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
@@ -69,6 +76,34 @@ class Storage:
return data
def load_stream(self, filename: str) -> Generator:
def generate(filename: str = filename) -> Generator:
if self.storage_type == 's3':
try:
with closing(self.client) as client:
response = client.get_object(Bucket=self.bucket_name, Key=filename)
for chunk in response['Body'].iter_chunks():
yield chunk
except ClientError as ex:
if ex.response['Error']['Code'] == 'NoSuchKey':
raise FileNotFoundError("File not found")
else:
raise
else:
if not self.folder or self.folder.endswith('/'):
filename = self.folder + filename
else:
filename = self.folder + '/' + filename
if not os.path.exists(filename):
raise FileNotFoundError("File not found")
with open(filename, "rb") as f:
while chunk := f.read(4096): # Read in chunks of 4KB
yield chunk
return generate()
def download(self, filename, target_filepath):
if self.storage_type == 's3':
with closing(self.client) as client:

View File

@@ -5,7 +5,13 @@ from libs.helper import TimestampField
class HiddenAPIKey(fields.Raw):
def output(self, key, obj):
return obj.api_key[:3] + '***' + obj.api_key[-3:]
api_key = obj.api_key
# If the length of the api_key is less than 8 characters, show the first and last characters
if len(api_key) <= 8:
return api_key[0] + '******' + api_key[-1]
# If the api_key is greater than 8 characters, show the first three and the last three characters
else:
return api_key[:3] + '******' + api_key[-3:]
api_based_extension_fields = {

View File

@@ -32,7 +32,8 @@ model_config_fields = {
'prompt_type': fields.String,
'chat_prompt_config': fields.Raw(attribute='chat_prompt_config_dict'),
'completion_prompt_config': fields.Raw(attribute='completion_prompt_config_dict'),
'dataset_configs': fields.Raw(attribute='dataset_configs_dict')
'dataset_configs': fields.Raw(attribute='dataset_configs_dict'),
'file_upload': fields.Raw(attribute='file_upload_dict'),
}
app_detail_fields = {
@@ -140,4 +141,4 @@ app_site_fields = {
'privacy_policy': fields.String,
'customize_token_strategy': fields.String,
'prompt_public': fields.Boolean
}
}

View File

@@ -28,6 +28,12 @@ annotation_fields = {
'created_at': TimestampField
}
message_file_fields = {
'id': fields.String,
'type': fields.String,
'url': fields.String,
}
message_detail_fields = {
'id': fields.String,
'conversation_id': fields.String,
@@ -43,7 +49,8 @@ message_detail_fields = {
'from_account_id': fields.String,
'feedbacks': fields.List(fields.Nested(feedback_fields)),
'annotation': fields.Nested(annotation_fields, allow_null=True),
'created_at': TimestampField
'created_at': TimestampField,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
}
feedback_stat_fields = {
@@ -111,11 +118,6 @@ conversation_message_detail_fields = {
'message': fields.Nested(message_detail_fields, attribute='first_message'),
}
simple_model_config_fields = {
'model': fields.Raw(attribute='model_dict'),
'pre_prompt': fields.String,
}
conversation_with_summary_fields = {
'id': fields.String,
'status': fields.String,
@@ -180,4 +182,4 @@ conversation_with_model_config_infinite_scroll_pagination_fields = {
'limit': fields.Integer,
'has_more': fields.Boolean,
'data': fields.List(fields.Nested(conversation_with_model_config_fields))
}
}

View File

@@ -4,7 +4,8 @@ from libs.helper import TimestampField
upload_config_fields = {
'file_size_limit': fields.Integer,
'batch_count_limit': fields.Integer
'batch_count_limit': fields.Integer,
'image_file_size_limit': fields.Integer,
}
file_fields = {

View File

@@ -1,6 +1,7 @@
from flask_restful import fields
from libs.helper import TimestampField
from fields.conversation_fields import message_file_fields
feedback_fields = {
'rating': fields.String
@@ -31,6 +32,7 @@ message_fields = {
'inputs': fields.Raw,
'query': fields.String,
'answer': fields.String,
'message_files': fields.List(fields.Nested(message_file_fields), attribute='files'),
'feedback': fields.Nested(feedback_fields, attribute='user_feedback', allow_null=True),
'retriever_resources': fields.List(fields.Nested(retriever_resource_fields)),
'created_at': TimestampField

View File

@@ -0,0 +1,59 @@
"""add gpt4v supports
Revision ID: 8fe468ba0ca5
Revises: a9836e3baeee
Create Date: 2023-11-09 11:39:00.006432
"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '8fe468ba0ca5'
down_revision = 'a9836e3baeee'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('message_files',
sa.Column('id', postgresql.UUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('message_id', postgresql.UUID(), nullable=False),
sa.Column('type', sa.String(length=255), nullable=False),
sa.Column('transfer_method', sa.String(length=255), nullable=False),
sa.Column('url', sa.Text(), nullable=True),
sa.Column('upload_file_id', postgresql.UUID(), nullable=True),
sa.Column('created_by_role', sa.String(length=255), nullable=False),
sa.Column('created_by', postgresql.UUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP(0)'), nullable=False),
sa.PrimaryKeyConstraint('id', name='message_file_pkey')
)
with op.batch_alter_table('message_files', schema=None) as batch_op:
batch_op.create_index('message_file_created_by_idx', ['created_by'], unique=False)
batch_op.create_index('message_file_message_idx', ['message_id'], unique=False)
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.add_column(sa.Column('file_upload', sa.Text(), nullable=True))
with op.batch_alter_table('upload_files', schema=None) as batch_op:
batch_op.add_column(sa.Column('created_by_role', sa.String(length=255), server_default=sa.text("'account'::character varying"), nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
with op.batch_alter_table('upload_files', schema=None) as batch_op:
batch_op.drop_column('created_by_role')
with op.batch_alter_table('app_model_configs', schema=None) as batch_op:
batch_op.drop_column('file_upload')
with op.batch_alter_table('message_files', schema=None) as batch_op:
batch_op.drop_index('message_file_message_idx')
batch_op.drop_index('message_file_created_by_idx')
op.drop_table('message_files')
# ### end Alembic commands ###

View File

@@ -1,10 +1,10 @@
import json
from json import JSONDecodeError
from flask import current_app, request
from flask_login import UserMixin
from sqlalchemy.dialects.postgresql import UUID
from core.file.upload_file_parser import UploadFileParser
from libs.helper import generate_string
from extensions.ext_database import db
from .account import Account, Tenant
@@ -98,6 +98,7 @@ class AppModelConfig(db.Model):
completion_prompt_config = db.Column(db.Text)
dataset_configs = db.Column(db.Text)
external_data_tools = db.Column(db.Text)
file_upload = db.Column(db.Text)
@property
def app(self):
@@ -161,6 +162,10 @@ class AppModelConfig(db.Model):
def dataset_configs_dict(self) -> dict:
return json.loads(self.dataset_configs) if self.dataset_configs else {"top_k": 2, "score_threshold": {"enable": False}}
@property
def file_upload_dict(self) -> dict:
return json.loads(self.file_upload) if self.file_upload else {"image": {"enabled": False, "number_limits": 3, "detail": "high", "transfer_methods": ["remote_url", "local_file"]}}
def to_dict(self) -> dict:
return {
"provider": "",
@@ -182,7 +187,8 @@ class AppModelConfig(db.Model):
"prompt_type": self.prompt_type,
"chat_prompt_config": self.chat_prompt_config_dict,
"completion_prompt_config": self.completion_prompt_config_dict,
"dataset_configs": self.dataset_configs_dict
"dataset_configs": self.dataset_configs_dict,
"file_upload": self.file_upload_dict
}
def from_model_config_dict(self, model_config: dict):
@@ -197,7 +203,8 @@ class AppModelConfig(db.Model):
self.more_like_this = json.dumps(model_config['more_like_this'])
self.sensitive_word_avoidance = json.dumps(model_config['sensitive_word_avoidance']) \
if model_config.get('sensitive_word_avoidance') else None
self.external_data_tools = json.dumps(model_config['external_data_tools'])
self.external_data_tools = json.dumps(model_config['external_data_tools']) \
if model_config.get('external_data_tools') else None
self.model = json.dumps(model_config['model'])
self.user_input_form = json.dumps(model_config['user_input_form'])
self.dataset_query_variable = model_config.get('dataset_query_variable')
@@ -212,6 +219,8 @@ class AppModelConfig(db.Model):
if model_config.get('completion_prompt_config') else None
self.dataset_configs = json.dumps(model_config.get('dataset_configs')) \
if model_config.get('dataset_configs') else None
self.file_upload = json.dumps(model_config.get('file_upload')) \
if model_config.get('file_upload') else None
return self
def copy(self):
@@ -237,7 +246,8 @@ class AppModelConfig(db.Model):
prompt_type=self.prompt_type,
chat_prompt_config=self.chat_prompt_config,
completion_prompt_config=self.completion_prompt_config,
dataset_configs=self.dataset_configs
dataset_configs=self.dataset_configs,
file_upload=self.file_upload
)
return new_app_model_config
@@ -511,6 +521,37 @@ class Message(db.Model):
return db.session.query(DatasetRetrieverResource).filter(DatasetRetrieverResource.message_id == self.id) \
.order_by(DatasetRetrieverResource.position.asc()).all()
@property
def message_files(self):
return db.session.query(MessageFile).filter(MessageFile.message_id == self.id).all()
@property
def files(self):
message_files = self.message_files
files = []
for message_file in message_files:
url = message_file.url
if message_file.type == 'image':
if message_file.transfer_method == 'local_file':
upload_file = (db.session.query(UploadFile)
.filter(
UploadFile.id == message_file.upload_file_id
).first())
url = UploadFileParser.get_image_data(
upload_file=upload_file,
force_url=True
)
files.append({
'id': message_file.id,
'type': message_file.type,
'url': url
})
return files
class MessageFeedback(db.Model):
__tablename__ = 'message_feedbacks'
@@ -539,6 +580,25 @@ class MessageFeedback(db.Model):
return account
class MessageFile(db.Model):
__tablename__ = 'message_files'
__table_args__ = (
db.PrimaryKeyConstraint('id', name='message_file_pkey'),
db.Index('message_file_message_idx', 'message_id'),
db.Index('message_file_created_by_idx', 'created_by')
)
id = db.Column(UUID, server_default=db.text('uuid_generate_v4()'))
message_id = db.Column(UUID, nullable=False)
type = db.Column(db.String(255), nullable=False)
transfer_method = db.Column(db.String(255), nullable=False)
url = db.Column(db.Text, nullable=True)
upload_file_id = db.Column(UUID, nullable=True)
created_by_role = db.Column(db.String(255), nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
class MessageAnnotation(db.Model):
__tablename__ = 'message_annotations'
__table_args__ = (
@@ -682,6 +742,7 @@ class UploadFile(db.Model):
size = db.Column(db.Integer, nullable=False)
extension = db.Column(db.String(255), nullable=False)
mime_type = db.Column(db.String(255), nullable=True)
created_by_role = db.Column(db.String(255), nullable=False, server_default=db.text("'account'::character varying"))
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.text('CURRENT_TIMESTAMP(0)'))
used = db.Column(db.Boolean, nullable=False, server_default=db.text('false'))
@@ -782,4 +843,3 @@ class DatasetRetrieverResource(db.Model):
retriever_from = db.Column(db.Text, nullable=False)
created_by = db.Column(UUID, nullable=False)
created_at = db.Column(db.DateTime, nullable=False, server_default=db.func.current_timestamp())

View File

@@ -315,6 +315,9 @@ class AppModelConfigService:
# moderation validation
cls.is_moderation_valid(tenant_id, config)
# file upload validation
cls.is_file_upload_valid(config)
# Filter out extra parameters
filtered_config = {
"opening_statement": config["opening_statement"],
@@ -338,7 +341,8 @@ class AppModelConfigService:
"prompt_type": config["prompt_type"],
"chat_prompt_config": config["chat_prompt_config"],
"completion_prompt_config": config["completion_prompt_config"],
"dataset_configs": config["dataset_configs"]
"dataset_configs": config["dataset_configs"],
"file_upload": config["file_upload"]
}
return filtered_config
@@ -371,6 +375,34 @@ class AppModelConfigService:
config=config
)
@classmethod
def is_file_upload_valid(cls, config: dict):
if 'file_upload' not in config or not config["file_upload"]:
config["file_upload"] = {}
if not isinstance(config["file_upload"], dict):
raise ValueError("file_upload must be of dict type")
# check image config
if 'image' not in config["file_upload"] or not config["file_upload"]["image"]:
config["file_upload"]["image"] = {"enabled": False}
if config['file_upload']['image']['enabled']:
number_limits = config['file_upload']['image']['number_limits']
if number_limits < 1 or number_limits > 6:
raise ValueError("number_limits must be in [1, 6]")
detail = config['file_upload']['image']['detail']
if detail not in ['high', 'low']:
raise ValueError("detail must be in ['high', 'low']")
transfer_methods = config['file_upload']['image']['transfer_methods']
if not isinstance(transfer_methods, list):
raise ValueError("transfer_methods must be of list type")
for method in transfer_methods:
if method not in ['remote_url', 'local_file']:
raise ValueError("transfer_methods must be in ['remote_url', 'local_file']")
@classmethod
def is_external_data_tools_valid(cls, tenant_id: str, config: dict):
if 'external_data_tools' not in config or not config["external_data_tools"]:

View File

@@ -3,7 +3,7 @@ import logging
import threading
import time
import uuid
from typing import Generator, Union, Any, Optional
from typing import Generator, Union, Any, Optional, List
from flask import current_app, Flask
from redis.client import PubSub
@@ -12,9 +12,11 @@ from sqlalchemy import and_
from core.completion import Completion
from core.conversation_message_task import PubHandler, ConversationTaskStoppedException, \
ConversationTaskInterruptException
from core.file.message_file_parser import MessageFileParser
from core.model_providers.error import LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError, \
LLMRateLimitError, \
LLMAuthorizationError, ProviderTokenNotInitError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.model_providers.models.entity.message import PromptMessageFile
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.model import Conversation, AppModelConfig, App, Account, EndUser, Message
@@ -35,6 +37,9 @@ class CompletionService:
# is streaming mode
inputs = args['inputs']
query = args['query']
files = args['files'] if 'files' in args and args['files'] else []
auto_generate_name = args['auto_generate_name'] \
if 'auto_generate_name' in args else True
if app_model.mode != 'completion' and not query:
raise ValueError('query is required')
@@ -132,6 +137,14 @@ class CompletionService:
# clean input by app_model_config form rules
inputs = cls.get_cleaned_inputs(inputs, app_model_config)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.validate_and_transform_files_arg(
files,
app_model_config,
user
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
@@ -146,17 +159,20 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': query,
'inputs': inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev',
'auto_generate_name': auto_generate_name
})
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user,
generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -172,10 +188,12 @@ class CompletionService:
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, detached_user: Union[Account, EndUser],
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig,
query: str, inputs: dict, files: List[PromptMessageFile],
detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
retriever_from: str = 'dev', auto_generate_name: bool = True):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
@@ -195,10 +213,12 @@ class CompletionService:
query=query,
inputs=inputs,
user=user,
files=files,
conversation=conversation,
streaming=streaming,
is_override=is_model_config_override,
retriever_from=retriever_from
retriever_from=retriever_from,
auto_generate_name=auto_generate_name
)
except (ConversationTaskInterruptException, ConversationTaskStoppedException):
pass
@@ -215,7 +235,8 @@ class CompletionService:
db.session.commit()
@classmethod
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user,
generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
@@ -274,6 +295,12 @@ class CompletionService:
model_dict['completion_params'] = completion_params
app_model_config.model = json.dumps(model_dict)
# parse files
message_file_parser = MessageFileParser(tenant_id=app_model.tenant_id, app_id=app_model.id)
file_objs = message_file_parser.transform_message_files(
message.files, app_model_config
)
generate_task_id = str(uuid.uuid4())
pubsub = redis_client.pubsub()
@@ -288,11 +315,13 @@ class CompletionService:
'app_model_config': app_model_config.copy(),
'query': message.query,
'inputs': message.inputs,
'files': file_objs,
'detached_user': user,
'detached_conversation': None,
'streaming': streaming,
'is_model_config_override': True,
'retriever_from': retriever_from
'retriever_from': retriever_from,
'auto_generate_name': False
})
generate_worker_thread.start()
@@ -388,7 +417,8 @@ class CompletionService:
if event == 'message':
yield "data: " + json.dumps(cls.get_message_response_data(result.get('data'))) + "\n\n"
elif event == 'message_replace':
yield "data: " + json.dumps(cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
yield "data: " + json.dumps(
cls.get_message_replace_response_data(result.get('data'))) + "\n\n"
elif event == 'chain':
yield "data: " + json.dumps(cls.get_chain_response_data(result.get('data'))) + "\n\n"
elif event == 'agent_thought':

View File

@@ -1,17 +1,20 @@
from typing import Union, Optional
from core.generator.llm_generator import LLMGenerator
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.account import Account
from models.model import Conversation, App, EndUser
from models.model import Conversation, App, EndUser, Message
from services.errors.conversation import ConversationNotExistsError, LastConversationNotExistsError
from services.errors.message import MessageNotExistsError
class ConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
last_id: Optional[str], limit: int,
include_ids: Optional[list] = None, exclude_ids: Optional[list] = None) -> InfiniteScrollPagination:
include_ids: Optional[list] = None, exclude_ids: Optional[list] = None,
exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
if not user:
return InfiniteScrollPagination(data=[], limit=limit, has_more=False)
@@ -29,6 +32,9 @@ class ConversationService:
if exclude_ids is not None:
base_query = base_query.filter(~Conversation.id.in_(exclude_ids))
if exclude_debug_conversation:
base_query = base_query.filter(Conversation.override_model_configs == None)
if last_id:
last_conversation = base_query.filter(
Conversation.id == last_id,
@@ -63,10 +69,36 @@ class ConversationService:
@classmethod
def rename(cls, app_model: App, conversation_id: str,
user: Optional[Union[Account | EndUser]], name: str):
user: Optional[Union[Account | EndUser]], name: str, auto_generate: bool):
conversation = cls.get_conversation(app_model, conversation_id, user)
conversation.name = name
if auto_generate:
return cls.auto_generate_name(app_model, conversation)
else:
conversation.name = name
db.session.commit()
return conversation
@classmethod
def auto_generate_name(cls, app_model: App, conversation: Conversation):
# get conversation first message
message = db.session.query(Message) \
.filter(
Message.app_id == app_model.id,
Message.conversation_id == conversation.id
).order_by(Message.created_at.asc()).first()
if not message:
raise MessageNotExistsError()
# generate conversation name
try:
name = LLMGenerator.generate_conversation_name(app_model.tenant_id, message.query)
conversation.name = name
except:
pass
db.session.commit()
return conversation

View File

@@ -1,46 +1,62 @@
import datetime
import hashlib
import time
import uuid
from typing import Generator, Tuple, Union
from cachetools import TTLCache
from flask import request, current_app
from flask import current_app
from flask_login import current_user
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
from core.data_loader.file_extractor import FileExtractor
from core.file.upload_file_parser import UploadFileParser
from extensions.ext_storage import storage
from extensions.ext_database import db
from models.model import UploadFile
from models.account import Account
from models.model import UploadFile, EndUser
from services.errors.file import FileTooLargeError, UnsupportedFileTypeError
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv']
ALLOWED_EXTENSIONS = ['txt', 'markdown', 'md', 'pdf', 'html', 'htm', 'xlsx', 'docx', 'csv',
'jpg', 'jpeg', 'png', 'webp', 'gif']
IMAGE_EXTENSIONS = ['jpg', 'jpeg', 'png', 'webp', 'gif']
PREVIEW_WORDS_LIMIT = 3000
cache = TTLCache(maxsize=None, ttl=30)
class FileService:
@staticmethod
def upload_file(file: FileStorage) -> UploadFile:
def upload_file(file: FileStorage, user: Union[Account, EndUser], only_image: bool = False) -> UploadFile:
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
elif only_image and extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
# read file content
file_content = file.read()
# get file size
file_size = len(file_content)
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if extension.lower() in IMAGE_EXTENSIONS:
file_size_limit = current_app.config.get("UPLOAD_IMAGE_FILE_SIZE_LIMIT") * 1024 * 1024
else:
file_size_limit = current_app.config.get("UPLOAD_FILE_SIZE_LIMIT") * 1024 * 1024
if file_size > file_size_limit:
message = f'File size exceeded. {file_size} > {file_size_limit}'
raise FileTooLargeError(message)
extension = file.filename.split('.')[-1]
if extension.lower() not in ALLOWED_EXTENSIONS:
raise UnsupportedFileTypeError()
# user uuid as file name
file_uuid = str(uuid.uuid4())
file_key = 'upload_files/' + current_user.current_tenant_id + '/' + file_uuid + '.' + extension
if isinstance(user, Account):
current_tenant_id = user.current_tenant_id
else:
# end_user
current_tenant_id = user.tenant_id
file_key = 'upload_files/' + current_tenant_id + '/' + file_uuid + '.' + extension
# save file to storage
storage.save(file_key, file_content)
@@ -48,14 +64,15 @@ class FileService:
# save file to db
config = current_app.config
upload_file = UploadFile(
tenant_id=current_user.current_tenant_id,
tenant_id=current_tenant_id,
storage_type=config['STORAGE_TYPE'],
key=file_key,
name=file.filename,
size=file_size,
extension=extension,
mime_type=file.mimetype,
created_by=current_user.id,
created_by_role=('account' if isinstance(user, Account) else 'end_user'),
created_by=user.id,
created_at=datetime.datetime.utcnow(),
used=False,
hash=hashlib.sha3_256(file_content).hexdigest()
@@ -99,12 +116,6 @@ class FileService:
@staticmethod
def get_file_preview(file_id: str) -> str:
# get file storage key
key = file_id + request.path
cached_response = cache.get(key)
if cached_response and time.time() - cached_response['timestamp'] < cache.ttl:
return cached_response['response']
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
@@ -121,3 +132,25 @@ class FileService:
text = text[0:PREVIEW_WORDS_LIMIT] if text else ''
return text
@staticmethod
def get_image_preview(file_id: str, timestamp: str, nonce: str, sign: str) -> Tuple[Generator, str]:
result = UploadFileParser.verify_image_file_signature(file_id, timestamp, nonce, sign)
if not result:
raise NotFound("File not found or signature is invalid")
upload_file = db.session.query(UploadFile) \
.filter(UploadFile.id == file_id) \
.first()
if not upload_file:
raise NotFound("File not found or signature is invalid")
# extract text from file
extension = upload_file.extension
if extension.lower() not in IMAGE_EXTENSIONS:
raise UnsupportedFileTypeError()
generator = storage.load(upload_file.key, stream=True)
return generator, upload_file.mime_type

View File

@@ -11,7 +11,8 @@ from services.conversation_service import ConversationService
class WebConversationService:
@classmethod
def pagination_by_last_id(cls, app_model: App, user: Optional[Union[Account | EndUser]],
last_id: Optional[str], limit: int, pinned: Optional[bool] = None) -> InfiniteScrollPagination:
last_id: Optional[str], limit: int, pinned: Optional[bool] = None,
exclude_debug_conversation: bool = False) -> InfiniteScrollPagination:
include_ids = None
exclude_ids = None
if pinned is not None:
@@ -32,7 +33,8 @@ class WebConversationService:
last_id=last_id,
limit=limit,
include_ids=include_ids,
exclude_ids=exclude_ids
exclude_ids=exclude_ids,
exclude_debug_conversation=exclude_debug_conversation
)
@classmethod

View File

@@ -5,7 +5,7 @@ from unittest.mock import patch
from langchain.schema import Generation, ChatGeneration, AIMessage
from core.model_providers.providers.openai_provider import OpenAIProvider
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.message import PromptMessage, MessageType, ImageMessageFile
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.openai_model import OpenAIModel
from models.provider import Provider, ProviderType
@@ -57,6 +57,18 @@ def test_chat_get_num_tokens(mock_decrypt):
assert rst == 22
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_chat_get_num_tokens(mock_decrypt):
openai_model = get_mock_openai_model('gpt-4-vision-preview')
messages = [
PromptMessage(content='Whats in first image?', files=[
ImageMessageFile(
data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
])
]
rst = openai_model.get_num_tokens(messages)
assert rst == 77
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
@@ -80,4 +92,20 @@ def test_chat_run(mock_decrypt, mocker):
messages,
stop=['\nHuman:'],
)
assert (len(rst.content) > 0)
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_vision_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
openai_model = get_mock_openai_model('gpt-4-vision-preview')
messages = [
PromptMessage(content='Whats in first image?', files=[
ImageMessageFile(data='https://upload.wikimedia.org/wikipedia/commons/0/00/1890s_Carlisle_Boarding_School_Graduates_PA.jpg')
])
]
rst = openai_model.run(
messages,
)
assert len(rst.content) > 0

View File

@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.29
image: langgenius/dify-api:0.3.30
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -19,18 +19,22 @@ services:
# different from api or web app domain.
# example: http://cloud.dify.ai
CONSOLE_API_URL: ''
# The URL for Service API endpoints, refers to the base URL of the current API service if api domain is
# The URL prefix for Service API endpoints, refers to the base URL of the current API service if api domain is
# different from console domain.
# example: http://api.dify.ai
SERVICE_API_URL: ''
# The URL for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
# The URL prefix for Web APP api server, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app
APP_API_URL: ''
# The URL for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
# The URL prefix for Web APP frontend, refers to the Web App base URL of WEB service if web app domain is different from
# console or api domain.
# example: http://udify.app
APP_WEB_URL: ''
# File preview or download Url prefix.
# used to display File preview or download Url to the front-end or as Multi-model inputs;
# Url is signed and has expiration time.
FILES_URL: ''
# When enabled, migrations will be executed prior to application startup and the application will start after the migrations have completed.
MIGRATION_ENABLED: 'true'
# The configurations of postgres database connection.
@@ -124,7 +128,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.29
image: langgenius/dify-api:0.3.30
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -192,7 +196,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.3.29
image: langgenius/dify-web:0.3.30
restart: always
environment:
EDITION: SELF_HOSTED

View File

@@ -17,6 +17,11 @@ server {
include proxy.conf;
}
location /files {
proxy_pass http://api:5001;
include proxy.conf;
}
location / {
proxy_pass http://web:3000;
include proxy.conf;

View File

@@ -20,16 +20,14 @@
- options.js 插件配置JS脚本
### 插件导入完成后,后续配置无差异
- 初始化设置Dify 应用配置,分别输入Dify根域名和应用TokenToken可以在Dify应用嵌入中获取,如图:
- 创建Dify应用配置在应用概览中点击嵌入切换到安装Chrome浏览器扩展视图点击copy按钮获取ChatBot Url,如图:
![img-2.png](images/img-2.png)
![img-3.png](images/img-3.png)
- 点击保存,确认提示配置成功即可
![img-4.png](images/img-4.png)
![img-3.png](images/img-3.png)
- 保险起见重启浏览器确保所有分页刷新成功
- Chrome打开任意页面均可正常加载DIfy机器人浮动栏后续如需更换机器人只需要变更Token即可
- Chrome打开任意页面均可正常加载DIfy机器人浮动栏后续如需更换机器人只需要变更ChatBot Url即可
![img-5.png](images/img-5.png)
![img-4.png](images/img-4.png)

View File

@@ -0,0 +1,6 @@
## Chrome Dify ChatBot插件
1、初始化设置Dify 应用配置分别输入Dify根域名和应用TokenToken可以在Dify应用嵌入中获取
2、点击保存确认提示配置成功即可
3、保险起见重启浏览器确保所有分页刷新成功
4、Chrome打开任意页面均可正常加载DIfy机器人浮动栏后续如需更换机器人只需要变更Token即可

View File

@@ -1,8 +1,7 @@
var storage = chrome.storage.sync;
chrome.storage.sync.get(['baseUrl', 'token'], function(result) {
const storage = chrome.storage.sync;
chrome.storage.sync.get(['chatbotUrl'], function(result) {
window.difyChatbotConfig = {
baseUrl: result.baseUrl,
token: result.token
chatbotUrl: result.chatbotUrl,
};
});
@@ -10,11 +9,10 @@ document.body.onload = embedChatbot;
async function embedChatbot() {
const difyChatbotConfig = window.difyChatbotConfig;
if (!difyChatbotConfig || !difyChatbotConfig.token) {
console.warn('difyChatbotConfig is empty or token is not provided');
if (!difyChatbotConfig) {
console.warn('Dify Chatbot Url is empty or is not provided');
return;
}
const baseUrl = difyChatbotConfig.baseUrl
const openIcon = `<svg
id="openIcon"
width="24"
@@ -53,7 +51,7 @@ async function embedChatbot() {
iframe.allow = "fullscreen;microphone"
iframe.title = "dify chatbot bubble window"
iframe.id = 'dify-chatbot-bubble-window'
iframe.src = `${baseUrl}/chat/${difyChatbotConfig.token}`
iframe.src = difyChatbotConfig.chatbotUrl
iframe.style.cssText = 'border: none; position: fixed; flex-direction: column; justify-content: space-between; box-shadow: rgba(150, 150, 150, 0.2) 0px 10px 30px 0px, rgba(150, 150, 150, 0.2) 0px 0px 0px 1px; bottom: 6.7rem; right: 1rem; width: 30rem; height: 48rem; border-radius: 0.75rem; display: flex; z-index: 2147483647; overflow: hidden; left: unset; background-color: #F3F4F6;'
document.body.appendChild(iframe);
}

BIN
third-party/chrome plug-in/favicon.png vendored Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 2.8 KiB

Binary file not shown.

After

Width:  |  Height:  |  Size: 15 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 73 KiB

After

Width:  |  Height:  |  Size: 124 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 96 KiB

After

Width:  |  Height:  |  Size: 85 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 55 KiB

After

Width:  |  Height:  |  Size: 85 KiB

Binary file not shown.

Before

Width:  |  Height:  |  Size: 85 KiB

View File

@@ -1,7 +1,7 @@
{
"manifest_version": 3,
"name": "Dify Chatbot",
"version": "1.3",
"version": "1.5",
"description": "This is a chrome extension to inject a dify chatbot on any pages",
"content_scripts": [
{
@@ -17,7 +17,7 @@
"32": "images/32.png",
"48": "images/48.png",
"128": "images/128.png"
}
},
"icons": {

19
third-party/chrome plug-in/options.css vendored Normal file
View File

@@ -0,0 +1,19 @@
body {
background-color: #f2f2f2;
font-family: Arial, sans-serif;
}
h2 {
color: #333;
}
label {
display: block;
margin-top: 10px;
margin-bottom: 10px;
}
input[type="text"] {
width: 280px;
padding: 6px;
}

View File

@@ -4,32 +4,21 @@
<head>
<title>Dify Chatbot Extension</title>
<meta name="viewport" content="width=device-width, initial-scale=1.0">
<link href="https://cdn.jsdelivr.net/npm/tailwindcss@2.2.19/dist/tailwind.min.css" rel="stylesheet">
<link href="./tailwind.css" rel="stylesheet">
</head>
<body class="bg-gray-100 py-8 px-4 w-96">
<body class="bg-gray-100 py-4 px-4 w-128">
<div class="max-w-md mx-auto bg-white shadow-md rounded-lg p-4">
<h2 class="text-2xl font-semibold mb-4">Dify Chatbot Extension</h2>
<form>
<div class="mb-4 flex items-center">
<div class="w-1/4">
<label for="base-url" class="block font-semibold text-gray-700">Base URL</label>
<label for="chatbot-url" class="block font-semibold text-gray-700">ChatBot URL</label>
</div>
<div class="w-3/4">
<input type="text" id="base-url" name="base-url" value=""
<input type="text" id="chatbot-url" name="base-url" value=""
class="w-full border border-gray-300 rounded px-3 py-2 focus:outline-none focus:border-blue-400"
placeholder="https://udify.app">
</div>
</div>
<div class="mb-4 flex items-center">
<div class="w-1/4">
<label for="token" class="block font-semibold text-gray-700">Token</label>
</div>
<div class="w-3/4">
<input type="text" id="token" name="token" value=""
class="w-full border border-gray-300 rounded px-3 py-2 focus:outline-none focus:border-blue-400"
placeholder="Application Embedded Token">
placeholder="https://udify.app/chatbot/7CQBa5yyvYLSkZtx">
</div>
</div>

View File

@@ -1,39 +1,28 @@
document.getElementById('save-button').addEventListener('click', function (e) {
e.preventDefault();
var baseUrl = document.getElementById('base-url').value;
var token = document.getElementById('token').value;
var errorTip = document.getElementById('error-tip');
const chatbotUrl = document.getElementById('chatbot-url').value;
const errorTip = document.getElementById('error-tip');
if (baseUrl.trim() === "" || token.trim() === "") {
if (baseUrl.trim() === "") {
errorTip.textContent = "Base URL cannot be empty.";
} else {
errorTip.textContent = "Token cannot be empty.";
}
if (chatbotUrl.trim() === "") {
errorTip.textContent = "Dify ChatBot URL cannot be empty.";
} else {
errorTip.textContent = "";
chrome.storage.sync.set({
'baseUrl': baseUrl,
'token': token
'chatbotUrl': chatbotUrl,
}, function () {
alert('Save Success!');
});
}
});
// Load parameters from chrome.storage when the page loads
chrome.storage.sync.get(['baseUrl', 'token'], function (result) {
const baseUrlInput = document.getElementById('base-url');
const tokenInput = document.getElementById('token');
chrome.storage.sync.get(['chatbotUrl'], function (result) {
const chatbotUrlInput = document.getElementById('chatbot-url');
if (result.baseUrl) {
baseUrlInput.value = result.baseUrl;
if (result.chatbotUrl) {
chatbotUrlInput.value = result.chatbotUrl;
}
if (result.token) {
tokenInput.value = result.token;
}
});

176015
third-party/chrome plug-in/tailwind.css vendored Normal file

File diff suppressed because it is too large Load Diff

View File

@@ -1,6 +1,7 @@
'use client'
import type { FC, ReactNode } from 'react'
import React, { useEffect, useLayoutEffect, useRef, useState } from 'react'
import Textarea from 'rc-textarea'
import { useContext } from 'use-context-selector'
import cn from 'classnames'
import Recorder from 'js-audio-recorder'
@@ -10,9 +11,8 @@ import type { DisplayScene, FeedbackFunc, IChatItem, SubmitAnnotationFunc } from
import { TryToAskIcon, stopIcon } from './icon-component'
import Answer from './answer'
import Question from './question'
import Tooltip from '@/app/components/base/tooltip'
import TooltipPlus from '@/app/components/base/tooltip-plus'
import { ToastContext } from '@/app/components/base/toast'
import AutoHeightTextarea from '@/app/components/base/auto-height-textarea'
import Button from '@/app/components/base/button'
import useBreakpoints, { MediaType } from '@/hooks/use-breakpoints'
import VoiceInput from '@/app/components/base/voice-input'
@@ -20,6 +20,10 @@ import { Microphone01 } from '@/app/components/base/icons/src/vender/line/mediaA
import { Microphone01 as Microphone01Solid } from '@/app/components/base/icons/src/vender/solid/mediaAndDevices'
import { XCircle } from '@/app/components/base/icons/src/vender/solid/general'
import type { DataSet } from '@/models/datasets'
import ChatImageUploader from '@/app/components/base/image-uploader/chat-image-uploader'
import ImageList from '@/app/components/base/image-uploader/image-list'
import { TransferMethod, type VisionFile, type VisionSettings } from '@/types/app'
import { useImageFiles } from '@/app/components/base/image-uploader/hooks'
export type IChatProps = {
configElem?: React.ReactNode
@@ -37,7 +41,7 @@ export type IChatProps = {
onFeedback?: FeedbackFunc
onSubmitAnnotation?: SubmitAnnotationFunc
checkCanSend?: () => boolean
onSend?: (message: string) => void
onSend?: (message: string, files: VisionFile[]) => void
displayScene?: DisplayScene
useCurrentUserAvatar?: boolean
isResponsing?: boolean
@@ -54,6 +58,7 @@ export type IChatProps = {
dataSets?: DataSet[]
isShowCitationHitInfo?: boolean
isShowPromptLog?: boolean
visionConfig?: VisionSettings
}
const Chat: FC<IChatProps> = ({
@@ -83,9 +88,19 @@ const Chat: FC<IChatProps> = ({
dataSets,
isShowCitationHitInfo,
isShowPromptLog,
visionConfig,
}) => {
const { t } = useTranslation()
const { notify } = useContext(ToastContext)
const {
files,
onUpload,
onRemove,
onReUpload,
onImageLinkLoadError,
onImageLinkLoadSuccess,
onClear,
} = useImageFiles()
const isUseInputMethod = useRef(false)
const [query, setQuery] = React.useState('')
@@ -114,9 +129,18 @@ const Chat: FC<IChatProps> = ({
const handleSend = () => {
if (!valid() || (checkCanSend && !checkCanSend()))
return
onSend(query)
if (!isResponsing)
setQuery('')
onSend(query, files.filter(file => file.progress !== -1).map(fileItem => ({
type: 'image',
transfer_method: fileItem.type,
url: fileItem.url,
upload_file_id: fileItem.fileId,
})))
if (!files.find(item => item.type === TransferMethod.local_file && !item.fileId)) {
if (files.length)
onClear()
if (!isResponsing)
setQuery('')
}
}
const handleKeyUp = (e: React.KeyboardEvent<HTMLTextAreaElement>) => {
@@ -198,6 +222,8 @@ const Chat: FC<IChatProps> = ({
item={item}
isShowPromptLog={isShowPromptLog}
isResponsing={isResponsing}
// ['https://placekitten.com/360/360', 'https://placekitten.com/360/640']
imgSrcs={(item.message_files && item.message_files?.length > 0) ? item.message_files.map(item => item.url) : []}
/>
)
})}
@@ -246,18 +272,42 @@ const Chat: FC<IChatProps> = ({
</div>
</div>)
}
<div className="relative">
<AutoHeightTextarea
<div className='p-[5.5px] max-h-[150px] bg-white border-[1.5px] border-gray-200 rounded-xl overflow-y-auto'>
{
visionConfig?.enabled && (
<>
<div className='absolute bottom-2 left-2 flex items-center'>
<ChatImageUploader
settings={visionConfig}
onUpload={onUpload}
disabled={files.length >= visionConfig.number_limits}
/>
<div className='mx-1 w-[1px] h-4 bg-black/5' />
</div>
<div className='pl-[52px]'>
<ImageList
list={files}
onRemove={onRemove}
onReUpload={onReUpload}
onImageLinkLoadSuccess={onImageLinkLoadSuccess}
onImageLinkLoadError={onImageLinkLoadError}
/>
</div>
</>
)
}
<Textarea
className={`
block w-full px-2 pr-[118px] py-[7px] leading-5 max-h-none text-sm text-gray-700 outline-none appearance-none resize-none
${visionConfig?.enabled && 'pl-12'}
`}
value={query}
onChange={handleContentChange}
onKeyUp={handleKeyUp}
onKeyDown={handleKeyDown}
minHeight={48}
autoFocus
controlFocus={controlFocus}
className={`${cn(s.textArea)} resize-none block w-full pl-3 bg-gray-50 border border-gray-200 rounded-md focus:outline-none sm:text-sm text-gray-700`}
autoSize
/>
<div className="absolute top-0 right-2 flex items-center h-[48px]">
<div className="absolute bottom-2 right-2 flex items-center h-8">
<div className={`${s.count} mr-4 h-5 leading-5 text-sm bg-gray-50 text-gray-500`}>{query.trim().length}</div>
{
query
@@ -282,9 +332,8 @@ const Chat: FC<IChatProps> = ({
{isMobile
? sendBtn
: (
<Tooltip
selector='send-tip'
htmlContent={
<TooltipPlus
popupContent={
<div>
<div>{t('common.operation.send')} Enter</div>
<div>{t('common.operation.lineBreak')} Shift Enter</div>
@@ -292,7 +341,7 @@ const Chat: FC<IChatProps> = ({
}
>
{sendBtn}
</Tooltip>
</TooltipPlus>
)}
</div>
{

View File

@@ -8,14 +8,16 @@ import Log from '../log'
import MoreInfo from '../more-info'
import AppContext from '@/context/app-context'
import { Markdown } from '@/app/components/base/markdown'
import ImageGallery from '@/app/components/base/image-gallery'
type IQuestionProps = Pick<IChatItem, 'id' | 'content' | 'more' | 'useCurrentUserAvatar'> & {
isShowPromptLog?: boolean
item: IChatItem
imgSrcs?: string[]
isResponsing?: boolean
}
const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar, isShowPromptLog, item, isResponsing }) => {
const Question: FC<IQuestionProps> = ({ id, content, imgSrcs, more, useCurrentUserAvatar, isShowPromptLog, item, isResponsing }) => {
const { userProfile } = useContext(AppContext)
const userName = userProfile?.name
const ref = useRef(null)
@@ -23,6 +25,7 @@ const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar,
return (
<div className={`flex items-start justify-end ${isShowPromptLog && 'first-of-type:pt-[14px]'}`} key={id} ref={ref}>
<div className={s.questionWrapWrap}>
<div className={`${s.question} group relative text-sm text-gray-900`}>
{
isShowPromptLog && !isResponsing && (
@@ -32,6 +35,9 @@ const Question: FC<IQuestionProps> = ({ id, content, more, useCurrentUserAvatar,
<div
className={'mr-2 py-3 px-4 bg-blue-500 rounded-tl-2xl rounded-b-2xl'}
>
{imgSrcs && imgSrcs.length > 0 && (
<ImageGallery srcs={imgSrcs} />
)}
<Markdown content={content} />
</div>
</div>

View File

@@ -1,5 +1,5 @@
import type { Annotation, MessageRating } from '@/models/log'
import type { VisionFile } from '@/types/app'
export type MessageMore = {
time: string
tokens: number
@@ -67,6 +67,7 @@ export type IChatItem = {
useCurrentUserAvatar?: boolean
isOpeningStatement?: boolean
log?: { role: string; text: string }[]
message_files?: VisionFile[]
}
export type MessageEnd = {

View File

@@ -33,7 +33,7 @@ export type IConfigModelProps = {
mode: string
modelId: string
provider: ProviderEnum
setModel: (model: { id: string; provider: ProviderEnum; mode: ModelModeType }) => void
setModel: (model: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => void
completionParams: CompletionParams
onCompletionParamsChange: (newParams: CompletionParams) => void
disabled: boolean
@@ -121,7 +121,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
return adjustedValue
}
const handleSelectModel = ({ id, provider: nextProvider, mode }: { id: string; provider: ProviderEnum; mode: ModelModeType }) => {
const handleSelectModel = ({ id, provider: nextProvider, mode, features }: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => {
return async () => {
const prevParamsRule = getAllParams()[provider]?.[modelId]
@@ -129,6 +129,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
id,
provider: nextProvider || ProviderEnum.openai,
mode,
features,
})
await ensureModelParamLoaded(nextProvider, id)
@@ -320,6 +321,7 @@ const ConfigModel: FC<IConfigModelProps> = ({
id: model.model_name,
provider: model.model_provider.provider_name as ProviderEnum,
mode: model.model_mode,
features: model.features,
})()
}}
/>

View File

@@ -1,7 +1,6 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
export type IModelNameProps = {
modelId: string
@@ -16,19 +15,11 @@ export const supportI18nModelName = [
]
const ModelName: FC<IModelNameProps> = ({
modelId,
modelDisplayName,
}) => {
const { t } = useTranslation()
let name = modelId
if (supportI18nModelName.includes(modelId))
name = t(`common.modelName.${modelId}`)
else if (modelDisplayName)
name = modelDisplayName
return (
<span title={name}>
{name}
<span title={modelDisplayName}>
{modelDisplayName}
</span>
)
}

View File

@@ -169,7 +169,7 @@ const ConfigVar: FC<IConfigVarProps> = ({ promptVariables, readonly, onPromptVar
}
title={
<div className='flex items-center'>
<div className='ml-1 mr-1'>{t('appDebug.variableTitle')}</div>
<div className='mr-1'>{t('appDebug.variableTitle')}</div>
{!readonly && (
<Tooltip htmlContent={<div className='w-[180px]'>
{t('appDebug.variableTip')}

View File

@@ -0,0 +1,60 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useTranslation } from 'react-i18next'
import { useContext } from 'use-context-selector'
import Panel from '../base/feature-panel'
import ParamConfig from './param-config'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
import Tooltip from '@/app/components/base/tooltip'
import Switch from '@/app/components/base/switch'
import { Eye } from '@/app/components/base/icons/src/vender/solid/general'
import ConfigContext from '@/context/debug-configuration'
const ConfigVision: FC = () => {
const { t } = useTranslation()
const {
isShowVisionConfig,
visionConfig,
setVisionConfig,
} = useContext(ConfigContext)
if (!isShowVisionConfig)
return null
return (<>
<Panel
className="mt-4"
headerIcon={
<Eye className='w-4 h-4 text-[#6938EF]'/>
}
title={
<div className='flex items-center'>
<div className='mr-1'>{t('appDebug.vision.name')}</div>
<Tooltip htmlContent={<div className='w-[180px]' >
{t('appDebug.vision.description')}
</div>} selector='config-vision-tooltip'>
<HelpCircle className='w-[14px] h-[14px] text-gray-400' />
</Tooltip>
</div>
}
headerRight={
<div className='flex items-center'>
<ParamConfig />
<div className='ml-4 mr-3 w-[1px] h-3.5 bg-gray-200'></div>
<Switch
defaultValue={visionConfig.enabled}
onChange={value => setVisionConfig({
...visionConfig,
enabled: value,
})}
size='md'
/>
</div>
}
noBodySpacing
/>
</>
)
}
export default React.memo(ConfigVision)

View File

@@ -0,0 +1,132 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import { useContext } from 'use-context-selector'
import { useTranslation } from 'react-i18next'
import RadioGroup from './radio-group'
import ConfigContext from '@/context/debug-configuration'
import { Resolution, TransferMethod } from '@/types/app'
import ParamItem from '@/app/components/base/param-item'
import Tooltip from '@/app/components/base/tooltip'
import { HelpCircle } from '@/app/components/base/icons/src/vender/line/general'
const MIN = 1
const MAX = 6
const ParamConfigContent: FC = () => {
const { t } = useTranslation()
const {
visionConfig,
setVisionConfig,
} = useContext(ConfigContext)
const transferMethod = (() => {
if (!visionConfig.transfer_methods || visionConfig.transfer_methods.length === 2)
return TransferMethod.all
return visionConfig.transfer_methods[0]
})()
return (
<div>
<div>
<div className='leading-6 text-base font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.title')}</div>
<div className='pt-3 space-y-6'>
<div>
<div className='mb-2 flex items-center space-x-1'>
<div className='leading-[18px] text-[13px] font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.resolution')}</div>
<Tooltip htmlContent={<div className='w-[180px]' >
{t('appDebug.vision.visionSettings.resolutionTooltip').split('\n').map(item => (
<div key={item}>{item}</div>
))}
</div>} selector='config-resolution-tooltip'>
<HelpCircle className='w-[14px] h-[14px] text-gray-400' />
</Tooltip>
</div>
<RadioGroup
className='space-x-3'
options={[
{
label: t('appDebug.vision.visionSettings.high'),
value: Resolution.high,
},
{
label: t('appDebug.vision.visionSettings.low'),
value: Resolution.low,
},
]}
value={visionConfig.detail}
onChange={(value: Resolution) => {
setVisionConfig({
...visionConfig,
detail: value,
})
}}
/>
</div>
<div>
<div className='mb-2 leading-[18px] text-[13px] font-semibold text-gray-800'>{t('appDebug.vision.visionSettings.uploadMethod')}</div>
<RadioGroup
className='space-x-3'
options={[
{
label: t('appDebug.vision.visionSettings.both'),
value: TransferMethod.all,
},
{
label: t('appDebug.vision.visionSettings.localUpload'),
value: TransferMethod.local_file,
},
{
label: t('appDebug.vision.visionSettings.url'),
value: TransferMethod.remote_url,
},
]}
value={transferMethod}
onChange={(value: TransferMethod) => {
if (value === TransferMethod.all) {
setVisionConfig({
...visionConfig,
transfer_methods: [TransferMethod.remote_url, TransferMethod.local_file],
})
return
}
setVisionConfig({
...visionConfig,
transfer_methods: [value],
})
}}
/>
</div>
<div>
<ParamItem
id='upload_limit'
className=''
name={t('appDebug.vision.visionSettings.uploadLimit')}
noTooltip
{...{
default: 2,
step: 1,
min: MIN,
max: MAX,
}}
value={visionConfig.number_limits}
enable={true}
onChange={(_key: string, value: number) => {
if (!value)
return
setVisionConfig({
...visionConfig,
number_limits: value,
})
}}
/>
</div>
</div>
</div>
</div>
)
}
export default React.memo(ParamConfigContent)

View File

@@ -0,0 +1,41 @@
'use client'
import type { FC } from 'react'
import { memo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import cn from 'classnames'
import ParamConfigContent from './param-config-content'
import { Settings01 } from '@/app/components/base/icons/src/vender/line/general'
import {
PortalToFollowElem,
PortalToFollowElemContent,
PortalToFollowElemTrigger,
} from '@/app/components/base/portal-to-follow-elem'
const ParamsConfig: FC = () => {
const { t } = useTranslation()
const [open, setOpen] = useState(false)
return (
<PortalToFollowElem
open={open}
onOpenChange={setOpen}
placement='bottom-end'
offset={{
mainAxis: 4,
}}
>
<PortalToFollowElemTrigger onClick={() => setOpen(v => !v)}>
<div className={cn('flex items-center rounded-md h-7 px-3 space-x-1 text-gray-700 cursor-pointer hover:bg-gray-200', open && 'bg-gray-200')}>
<Settings01 className='w-3.5 h-3.5 ' />
<div className='ml-1 leading-[18px] text-xs font-medium '>{t('appDebug.vision.settings')}</div>
</div>
</PortalToFollowElemTrigger>
<PortalToFollowElemContent style={{ zIndex: 50 }}>
<div className='w-[412px] p-4 bg-white rounded-lg border-[0.5px] border-gray-200 shadow-lg space-y-3'>
<ParamConfigContent />
</div>
</PortalToFollowElemContent>
</PortalToFollowElem>
)
}
export default memo(ParamsConfig)

View File

@@ -0,0 +1,40 @@
'use client'
import type { FC } from 'react'
import React from 'react'
import cn from 'classnames'
import s from './style.module.css'
type OPTION = {
label: string
value: any
}
type Props = {
className?: string
options: OPTION[]
value: any
onChange: (value: any) => void
}
const RadioGroup: FC<Props> = ({
className = '',
options,
value,
onChange,
}) => {
return (
<div className={cn(className, 'flex')}>
{options.map(item => (
<div
key={item.value}
className={cn(s.item, item.value === value && s.checked)}
onClick={() => onChange(item.value)}
>
<div className={s.radio}></div>
<div className='text-[13px] font-medium text-gray-900'>{item.label}</div>
</div>
))}
</div>
)
}
export default React.memo(RadioGroup)

View File

@@ -0,0 +1,24 @@
.item {
@apply grow flex items-center h-8 px-2.5 rounded-lg bg-gray-25 border border-gray-100 cursor-pointer space-x-2;
}
.item:hover {
background-color: #ffffff;
border-color: #B2CCFF;
box-shadow: 0px 12px 16px -4px rgba(16, 24, 40, 0.08), 0px 4px 6px -2px rgba(16, 24, 40, 0.03);
}
.item.checked {
background-color: #ffffff;
border-color: #528BFF;
box-shadow: 0px 1px 2px 0px rgba(16, 24, 40, 0.06), 0px 1px 3px 0px rgba(16, 24, 40, 0.10);
}
.radio {
@apply w-4 h-4 border-[2px] border-gray-200 rounded-full;
}
.item.checked .radio {
border-width: 5px;
border-color: #155eef;
}

View File

@@ -10,6 +10,7 @@ import ChatGroup from '../features/chat-group'
import ExperienceEnchanceGroup from '../features/experience-enchance-group'
import Toolbox from '../toolbox'
import HistoryPanel from '../config-prompt/conversation-histroy/history-panel'
import ConfigVision from '../config-vision'
import AddFeatureBtn from './feature/add-feature-btn'
import ChooseFeature from './feature/choose-feature'
import useFeature from './feature/use-feature'
@@ -193,6 +194,8 @@ const Config: FC = () => {
<Tools />
<ConfigVision />
{/* Chat History */}
{isAdvancedMode && isChatApp && modelModeType === ModelModeType.completion && (
<HistoryPanel

View File

@@ -81,7 +81,7 @@ const DatasetConfig: FC = () => {
>
{hasData
? (
<div className='flex flex-wrap mt-1 px-3 justify-between'>
<div className='flex flex-wrap mt-1 px-3 pb-3 justify-between'>
{dataSet.map(item => (
<CardItem
key={item.id}

View File

@@ -1,5 +1,6 @@
'use client'
import type { FC } from 'react'
import useSWR from 'swr'
import { useTranslation } from 'react-i18next'
import React, { useEffect, useRef, useState } from 'react'
import cn from 'classnames'
@@ -11,7 +12,7 @@ import HasNotSetAPIKEY from '../base/warning-mask/has-not-set-api'
import FormattingChanged from '../base/warning-mask/formatting-changed'
import GroupName from '../base/group-name'
import CannotQueryDataset from '../base/warning-mask/cannot-query-dataset'
import { AppType, ModelModeType } from '@/types/app'
import { AppType, ModelModeType, TransferMethod } from '@/types/app'
import PromptValuePanel, { replaceStringWithValues } from '@/app/components/app/configuration/prompt-value-panel'
import type { IChatItem } from '@/app/components/app/chat/type'
import Chat from '@/app/components/app/chat'
@@ -19,12 +20,13 @@ import ConfigContext from '@/context/debug-configuration'
import { ToastContext } from '@/app/components/base/toast'
import { fetchConvesationMessages, fetchSuggestedQuestions, sendChatMessage, sendCompletionMessage, stopChatMessageResponding } from '@/service/debug'
import Button from '@/app/components/base/button'
import type { ModelConfig as BackendModelConfig } from '@/types/app'
import type { ModelConfig as BackendModelConfig, VisionFile } from '@/types/app'
import { promptVariablesToUserInputsForm } from '@/utils/model-config'
import TextGeneration from '@/app/components/app/text-generate/item'
import { IS_CE_EDITION } from '@/config'
import { useProviderContext } from '@/context/provider-context'
import type { Inputs } from '@/models/debug'
import { fetchFileUploadConfig } from '@/service/common'
type IDebug = {
hasSetAPIKEY: boolean
@@ -64,10 +66,12 @@ const Debug: FC<IDebug> = ({
hasSetContextVar,
datasetConfigs,
externalDataToolsConfig,
visionConfig,
} = useContext(ConfigContext)
const { speech2textDefaultModel } = useProviderContext()
const [chatList, setChatList, getChatList] = useGetState<IChatItem[]>([])
const chatListDomRef = useRef<HTMLDivElement>(null)
const { data: fileUploadConfigResponse } = useSWR({ url: '/files/upload' }, fetchFileUploadConfig)
useEffect(() => {
// scroll to bottom
if (chatListDomRef.current)
@@ -161,17 +165,28 @@ const Debug: FC<IDebug> = ({
logError(t('appDebug.errorMessage.valueOfVarRequired', { key: hasEmptyInput }))
return false
}
// eslint-disable-next-line @typescript-eslint/no-use-before-define
if (completionFiles.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForImgUpload') })
return false
}
return !hasEmptyInput
}
const doShowSuggestion = isShowSuggestion && !isResponsing
const [suggestQuestions, setSuggestQuestions] = useState<string[]>([])
const onSend = async (message: string) => {
const onSend = async (message: string, files?: VisionFile[]) => {
if (isResponsing) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
return false
}
if (files?.find(item => item.transfer_method === TransferMethod.local_file && !item.upload_file_id)) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForImgUpload') })
return false
}
const postDatasets = dataSets.map(({ id }) => ({
dataset: {
enabled: true,
@@ -207,6 +222,9 @@ const Debug: FC<IDebug> = ({
completion_params: completionParams as any,
},
dataset_configs: datasetConfigs,
file_upload: {
image: visionConfig,
},
}
if (isAdvancedMode) {
@@ -214,19 +232,32 @@ const Debug: FC<IDebug> = ({
postModelConfig.completion_prompt_config = completionPromptConfig
}
const data = {
const data: Record<string, any> = {
conversation_id: conversationId,
inputs,
query: message,
model_config: postModelConfig,
}
if (visionConfig.enabled && files && files?.length > 0) {
data.files = files.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
// qustion
const questionId = `question-${Date.now()}`
const questionItem = {
id: questionId,
content: message,
isAnswer: false,
message_files: files,
}
const placeholderAnswerId = `answer-placeholder-${Date.now()}`
@@ -347,6 +378,7 @@ const Debug: FC<IDebug> = ({
const [completionRes, setCompletionRes] = useState('')
const [messageId, setMessageId] = useState<string | null>(null)
const [completionFiles, setCompletionFiles] = useState<VisionFile[]>([])
const sendTextCompletion = async () => {
if (isResponsing) {
notify({ type: 'info', message: t('appDebug.errorMessage.waitForResponse') })
@@ -394,6 +426,9 @@ const Debug: FC<IDebug> = ({
completion_params: completionParams as any,
},
dataset_configs: datasetConfigs,
file_upload: {
image: visionConfig,
},
}
if (isAdvancedMode) {
@@ -401,11 +436,23 @@ const Debug: FC<IDebug> = ({
postModelConfig.completion_prompt_config = completionPromptConfig
}
const data = {
const data: Record<string, any> = {
inputs,
model_config: postModelConfig,
}
if (visionConfig.enabled && completionFiles && completionFiles?.length > 0) {
data.files = completionFiles.map((item) => {
if (item.transfer_method === TransferMethod.local_file) {
return {
...item,
url: '',
}
}
return item
})
}
setCompletionRes('')
setMessageId('')
let res: string[] = []
@@ -448,6 +495,11 @@ const Debug: FC<IDebug> = ({
appType={mode as AppType}
onSend={sendTextCompletion}
inputs={inputs}
visionConfig={{
...visionConfig,
image_file_size_limit: fileUploadConfigResponse?.image_file_size_limit,
}}
onVisionFilesChange={setCompletionFiles}
/>
</div>
<div className="flex flex-col grow">
@@ -475,6 +527,10 @@ const Debug: FC<IDebug> = ({
isShowCitation={citationConfig.enabled}
isShowCitationHitInfo
isShowPromptLog
visionConfig={{
...visionConfig,
image_file_size_limit: fileUploadConfigResponse?.image_file_size_limit,
}}
/>
</div>
</div>

View File

@@ -25,19 +25,19 @@ import type {
} from '@/models/debug'
import type { ExternalDataTool } from '@/models/common'
import type { DataSet } from '@/models/datasets'
import type { ModelConfig as BackendModelConfig } from '@/types/app'
import type { ModelConfig as BackendModelConfig, VisionSettings } from '@/types/app'
import ConfigContext from '@/context/debug-configuration'
import ConfigModel from '@/app/components/app/configuration/config-model'
import Config from '@/app/components/app/configuration/config'
import Debug from '@/app/components/app/configuration/debug'
import Confirm from '@/app/components/base/confirm'
import { ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ModelFeature, ProviderEnum } from '@/app/components/header/account-setting/model-page/declarations'
import { ToastContext } from '@/app/components/base/toast'
import { fetchAppDetail, updateAppModelConfig } from '@/service/apps'
import { promptVariablesToUserInputsForm, userInputsFormToPromptVariables } from '@/utils/model-config'
import { fetchDatasets } from '@/service/datasets'
import { useProviderContext } from '@/context/provider-context'
import { AppType, ModelModeType } from '@/types/app'
import { AppType, ModelModeType, Resolution, TransferMethod } from '@/types/app'
import { FlipBackward } from '@/app/components/base/icons/src/vender/line/arrows'
import { PromptMode } from '@/models/debug'
import { DEFAULT_CHAT_PROMPT_CONFIG, DEFAULT_COMPLETION_PROMPT_CONFIG } from '@/config'
@@ -198,6 +198,7 @@ const Configuration: FC = () => {
}
const { textGenerationModelList } = useProviderContext()
const currModel = textGenerationModelList.find(item => item.model_name === modelConfig.model_id)
const hasSetCustomAPIKEY = !!textGenerationModelList?.find(({ model_provider: provider }) => {
if (provider.provider_type === 'system' && provider.quota_type === 'paid')
return true
@@ -271,7 +272,8 @@ const Configuration: FC = () => {
id: modelId,
provider,
mode: modeMode,
}: { id: string; provider: ProviderEnum; mode: ModelModeType }) => {
features,
}: { id: string; provider: ProviderEnum; mode: ModelModeType; features: string[] }) => {
if (isAdvancedMode) {
const appMode = mode
@@ -297,10 +299,31 @@ const Configuration: FC = () => {
})
setModelConfig(newModelConfig)
const supportVision = features && features.includes(ModelFeature.vision)
// eslint-disable-next-line @typescript-eslint/no-use-before-define
setVisionConfig({
// eslint-disable-next-line @typescript-eslint/no-use-before-define
...visionConfig,
enabled: supportVision,
}, true)
}
const isShowVisionConfig = !!currModel?.features.includes(ModelFeature.vision)
const [visionConfig, doSetVisionConfig] = useState({
enabled: false,
number_limits: 2,
detail: Resolution.low,
transfer_methods: [TransferMethod.local_file],
})
const setVisionConfig = (config: VisionSettings, notNoticeFormattingChanged?: boolean) => {
doSetVisionConfig(config)
if (!notNoticeFormattingChanged)
setFormattingChanged(true)
}
useEffect(() => {
fetchAppDetail({ url: '/apps', id: appId }).then(async (res) => {
fetchAppDetail({ url: '/apps', id: appId }).then(async (res: any) => {
setMode(res.mode)
const modelConfig = res.model_config
const promptMode = modelConfig.prompt_type === PromptMode.advanced ? PromptMode.advanced : PromptMode.simple
@@ -362,6 +385,10 @@ const Configuration: FC = () => {
},
completionParams: model.completion_params,
}
if (modelConfig.file_upload)
setVisionConfig(modelConfig.file_upload.image, true)
syncToPublishedConfig(config)
setPublishedConfig(config)
setDatasetConfigs(modelConfig.dataset_configs)
@@ -459,6 +486,9 @@ const Configuration: FC = () => {
completion_params: completionParams as any,
},
dataset_configs: datasetConfigs,
file_upload: {
image: visionConfig,
},
}
if (isAdvancedMode) {
@@ -557,6 +587,9 @@ const Configuration: FC = () => {
datasetConfigs,
setDatasetConfigs,
hasSetContextVar,
isShowVisionConfig,
visionConfig,
setVisionConfig,
}}
>
<>

View File

@@ -14,17 +14,23 @@ import { DEFAULT_VALUE_MAX_LEN } from '@/config'
import Button from '@/app/components/base/button'
import { ChevronDown, ChevronRight } from '@/app/components/base/icons/src/vender/line/arrows'
import Tooltip from '@/app/components/base/tooltip-plus'
import TextGenerationImageUploader from '@/app/components/base/image-uploader/text-generation-image-uploader'
import type { VisionFile, VisionSettings } from '@/types/app'
export type IPromptValuePanelProps = {
appType: AppType
onSend?: () => void
inputs: Inputs
visionConfig: VisionSettings
onVisionFilesChange: (files: VisionFile[]) => void
}
const PromptValuePanel: FC<IPromptValuePanelProps> = ({
appType,
onSend,
inputs,
visionConfig,
onVisionFilesChange,
}) => {
const { t } = useTranslation()
const { modelModeType, modelConfig, setInputs, mode, isAdvancedMode, completionPromptConfig, chatPromptConfig } = useContext(ConfigContext)
@@ -152,6 +158,24 @@ const PromptValuePanel: FC<IPromptValuePanelProps> = ({
<div className='text-xs text-gray-500'>{t('appDebug.inputs.noVar')}</div>
)
}
{
appType === AppType.completion && visionConfig?.enabled && (
<div className="mt-3 xl:flex justify-between">
<div className="mr-1 py-2 shrink-0 w-[120px] text-sm text-gray-900">Image Upload</div>
<div className='grow'>
<TextGenerationImageUploader
settings={visionConfig}
onFilesChange={files => onVisionFilesChange(files.filter(file => file.progress !== -1).map(fileItem => ({
type: 'image',
transfer_method: fileItem.type,
url: fileItem.url,
upload_file_id: fileItem.fileId,
})))}
/>
</div>
</div>
)
}
</>
)
}

Some files were not shown because too many files have changed in this diff Show More