Compare commits

...

15 Commits

Author SHA1 Message Date
takatost     8bf892b306  feat: bump version to 0.3.25 (#1300)                        2023-10-10 13:03:49 +08:00
takatost     8480b0197b  fix: prompt for baichuan text generation models (#1299)     2023-10-10 13:01:18 +08:00
zxhlyh       df07fb5951  feat: provider add baichuan (#1298)                         2023-10-09 23:10:43 -05:00
takatost     4ab4bcc074  feat: support openllm embedding (#1293)                     2023-10-09 23:09:35 -05:00
takatost     1d4f019de4  feat: add baichuan llm support (#1294)                      2023-10-09 23:09:26 -05:00
             Co-authored-by: zxhlyh <jasonapring2015@outlook.com>
takatost     677aacc8e3  feat: upgrade xinference client to 0.5.2 (#1292)            2023-10-09 08:12:58 -05:00
takatost     fda937175d  feat: qdrant support in docker compose (#1286)              2023-10-08 12:04:04 -05:00
takatost     024250803a  feat: move login_required wrapper outside (#1281)           2023-10-08 05:21:32 -05:00
Charlie.Wei  b711ce33b7  Application share qrcode (#1277)                            2023-10-08 09:34:49 +08:00
             Co-authored-by: luowei <glpat-EjySCyNjWiLqAED-YmwM>
             Co-authored-by: crazywoola <427733928@qq.com>
Rhon Joe     52bec63275  chore(web): strong type (#1259)                             2023-10-07 04:42:16 -05:00
qiuqiua      657fa80f4d  fix devcontainer issue (#1273)                              2023-10-07 10:34:25 +08:00
takatost     373e90ee6d  fix: detached model in completion thread (#1269)            2023-10-02 22:27:25 +08:00
takatost     41d4c5b424  fix: count down thread in completion db not commit (#1267)  2023-10-02 10:19:26 +08:00
takatost     86a9dea428  fix: db not commit when streaming output (#1266)            2023-10-01 16:41:52 +08:00
takatost     8606d80c66  fix: request timeout when openai completion (#1265)         2023-10-01 16:00:23 +08:00
95 changed files with 1714 additions and 182 deletions

View File

@@ -1,11 +1,8 @@
FROM mcr.microsoft.com/devcontainers/anaconda:0-3
FROM mcr.microsoft.com/devcontainers/python:3.10
COPY . .
# Copy environment.yml (if found) to a temp location so we update the environment. Also
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
&& rm -rf /tmp/conda-tmp
# [Optional] Uncomment this section to install additional OS packages.
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
# && apt-get -y install --no-install-recommends <your-package-list-here>

View File

@@ -1,13 +1,12 @@
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
{
"name": "Anaconda (Python 3)",
"name": "Python 3.10",
"build": {
"context": "..",
"dockerfile": "Dockerfile"
},
"features": {
"ghcr.io/dhoeric/features/act:1": {},
"ghcr.io/devcontainers/features/node:1": {
"nodeGypDependencies": true,
"version": "lts"

.gitignore
View File

@@ -144,6 +144,7 @@ docker/volumes/app/storage/*
docker/volumes/db/data/*
docker/volumes/redis/data/*
docker/volumes/weaviate/*
docker/volumes/qdrant/*
sdks/python-client/build
sdks/python-client/dist

View File

@@ -59,9 +59,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
WEAVIATE_GRPC_ENABLED=false
WEAVIATE_BATCH_SIZE=100
# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=path:storage/qdrant
QDRANT_API_KEY=your-qdrant-api-key
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
QDRANT_URL=http://localhost:6333
QDRANT_API_KEY=difyai123456
# Mail configuration, support: resend
MAIL_TYPE=
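
With the new defaults above, a quick connectivity check against a local Qdrant instance might look like this — a minimal sketch, assuming the qdrant-client package is installed and a server is listening on the configured URL:

from qdrant_client import QdrantClient

# URL and API key mirror the new .env defaults above.
client = QdrantClient(url="http://localhost:6333", api_key="difyai123456")
print(client.get_collections())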

View File

@@ -92,7 +92,7 @@ class Config:
self.CONSOLE_URL = get_env('CONSOLE_URL')
self.API_URL = get_env('API_URL')
self.APP_URL = get_env('APP_URL')
self.CURRENT_VERSION = "0.3.24"
self.CURRENT_VERSION = "0.3.25"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
import flask_restful
from flask_restful import Resource, fields, marshal_with
from werkzeug.exceptions import Forbidden
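
The same one-line change recurs across the console controllers below: the decorator moves from core.login.login to libs.login. A minimal sketch of the new import path in use (the resource itself is hypothetical, for illustration only):

from flask_restful import Resource
from libs.login import login_required

class ExampleResource(Resource):  # hypothetical resource
    @login_required
    def get(self):
        # Only reached for authenticated sessions.
        return {"result": "success"}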

View File

@@ -3,10 +3,9 @@ import json
import logging
from datetime import datetime
import flask
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
from werkzeug.exceptions import Forbidden
from constants.model_template import model_templates, demo_model_templates
@@ -17,11 +16,9 @@ from controllers.console.wraps import account_initialization_required
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
from core.model_providers.model_factory import ModelFactory
from core.model_providers.model_provider_factory import ModelProviderFactory
from core.model_providers.models.entity.model_params import ModelType
from events.app_event import app_was_created, app_was_deleted
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
app_detail_fields_with_site
from libs.helper import TimestampField
from extensions.ext_database import db
from models.model import App, AppModelConfig, Site
from services.app_model_config_service import AppModelConfigService

View File

@@ -2,8 +2,8 @@
import logging
from flask import request
from core.login.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
from libs.login import login_required
from werkzeug.exceptions import InternalServerError
import services
from controllers.console import api

View File

@@ -5,7 +5,7 @@ from typing import Generator, Union
import flask_login
from flask import Response, stream_with_context
from core.login.login import login_required
from libs.login import login_required
from werkzeug.exceptions import InternalServerError, NotFound
import services

View File

@@ -2,8 +2,8 @@ from datetime import datetime
import pytz
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with
from flask_restful.inputs import int_range
from sqlalchemy import or_, func
from sqlalchemy.orm import joinedload
@@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
conversation_message_detail_fields, conversation_with_summary_pagination_fields
from libs.helper import TimestampField, datetime_string, uuid_value
from libs.helper import datetime_string
from extensions.ext_database import db
from models.model import Message, MessageAnnotation, Conversation

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -16,9 +16,9 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
from core.login.login import login_required
from libs.login import login_required
from fields.conversation_fields import message_detail_fields
from libs.helper import uuid_value, TimestampField
from libs.helper import uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from extensions.ext_database import db
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback

View File

@@ -1,5 +1,4 @@
# -*- coding:utf-8 -*-
import json
from flask import request
from flask_restful import Resource
@@ -9,7 +8,7 @@ from controllers.console import api
from controllers.console.app import _get_app
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.login.login import login_required
from libs.login import login_required
from events.app_event import app_model_config_was_updated
from extensions.ext_database import db
from models.model import AppModelConfig

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
from controllers.console import api

View File

@@ -5,7 +5,7 @@ from datetime import datetime
import pytz
from flask import jsonify
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,16 +1,13 @@
import logging
from datetime import datetime
from typing import Optional
import flask_login
import requests
from flask import request, redirect, current_app, session
from flask import request, redirect, current_app
from flask_login import current_user
from flask_restful import Resource
from werkzeug.exceptions import Forbidden
from core.login.login import login_required
from libs.login import login_required
from libs.oauth_data_source import NotionOAuth
from controllers.console import api
from ..setup import setup_required

View File

@@ -2,10 +2,10 @@ import datetime
import json
from cachetools import TTLCache
from flask import request, current_app
from flask import request
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
from libs.login import login_required
from flask_restful import Resource, marshal_with, reqparse
from werkzeug.exceptions import NotFound
from controllers.console import api
@@ -15,7 +15,6 @@ from core.data_loader.loader.notion import NotionLoader
from core.indexing_runner import IndexingRunner
from extensions.ext_database import db
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
from libs.helper import TimestampField
from models.dataset import Document
from models.source import DataSourceBinding
from services.dataset_service import DatasetService, DocumentService

View File

@@ -4,8 +4,8 @@ from flask import request, current_app
from flask_login import current_user
from controllers.console.apikey import api_key_list, api_key_fields
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal, marshal_with
from werkzeug.exceptions import NotFound, Forbidden
import services
from controllers.console import api

View File

@@ -1,11 +1,10 @@
# -*- coding:utf-8 -*-
import random
from datetime import datetime
from typing import List
from flask import request, current_app
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
from sqlalchemy import desc, asc
from werkzeug.exceptions import NotFound, Forbidden
@@ -25,7 +24,6 @@ from core.model_providers.model_factory import ModelFactory
from extensions.ext_redis import redis_client
from fields.document_fields import document_with_segments_fields, document_fields, \
dataset_and_document_fields, document_status_fields
from libs.helper import TimestampField
from extensions.ext_database import db
from models.dataset import DatasetProcessRule, Dataset
from models.dataset import Document, DocumentSegment

View File

@@ -14,13 +14,12 @@ from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
from core.model_providers.model_factory import ModelFactory
from core.login.login import login_required
from libs.login import login_required
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from fields.segment_fields import segment_fields
from models.dataset import DocumentSegment
from libs.helper import TimestampField
from services.dataset_service import DatasetService, DocumentService, SegmentService
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
from tasks.disable_segment_from_index_task import disable_segment_from_index_task

View File

@@ -2,8 +2,8 @@ from cachetools import TTLCache
from flask import request, current_app
import services
from core.login.login import login_required
from flask_restful import Resource, marshal_with, fields
from libs.login import login_required
from flask_restful import Resource, marshal_with
from controllers.console import api
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \

View File

@@ -1,7 +1,7 @@
import logging
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden

View File

@@ -2,8 +2,8 @@
from datetime import datetime
from flask_login import current_user
from core.login.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, inputs
from sqlalchemy import and_
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
@@ -12,7 +12,6 @@ from controllers.console.explore.wraps import InstalledAppResource
from controllers.console.wraps import account_initialization_required
from extensions.ext_database import db
from fields.installed_app_fields import installed_app_list_fields
from libs.helper import TimestampField
from models.model import App, InstalledApp, RecommendedApp
from services.account_service import TenantService

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal_with
from sqlalchemy import and_

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource
from functools import wraps

View File

@@ -2,7 +2,7 @@ import json
from functools import wraps
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required

View File

@@ -4,7 +4,7 @@ from datetime import datetime
import pytz
from flask import current_app, request
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, fields, marshal_with
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError

View File

@@ -1,7 +1,7 @@
# -*- coding:utf-8 -*-
from flask import current_app
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
import services

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,5 +1,5 @@
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from controllers.console import api

View File

@@ -1,6 +1,6 @@
# -*- coding:utf-8 -*-
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -1,7 +1,7 @@
import json
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, abort, reqparse
from werkzeug.exceptions import Forbidden

View File

@@ -3,9 +3,8 @@ import logging
from flask import request
from flask_login import current_user
from core.login.login import login_required
from libs.login import login_required
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
from flask_restful.inputs import int_range
from controllers.console import api
from controllers.console.admin import admin_required

View File

@@ -4,12 +4,9 @@ import services.dataset_service
from controllers.service_api import api
from controllers.service_api.dataset.error import DatasetNameDuplicateError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from libs.login import current_user
from core.model_providers.models.entity.model_params import ModelType
from extensions.ext_database import db
from fields.dataset_fields import dataset_detail_fields
from models.account import Account, TenantAccountJoin
from models.dataset import Dataset
from services.dataset_service import DatasetService
from services.provider_service import ProviderService

View File

@@ -1,8 +1,6 @@
import datetime
import json
import uuid
from flask import current_app, request
from flask import request
from flask_restful import reqparse, marshal
from sqlalchemy import desc
from werkzeug.exceptions import NotFound
@@ -13,13 +11,11 @@ from controllers.service_api.app.error import ProviderNotInitializeError
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
NoFileUploadedError, TooManyFilesError
from controllers.service_api.wraps import DatasetApiResource
from core.login.login import current_user
from libs.login import current_user
from core.model_providers.error import ProviderTokenNotInitError
from extensions.ext_database import db
from extensions.ext_storage import storage
from fields.document_fields import document_fields, document_status_fields
from models.dataset import Dataset, Document, DocumentSegment
from models.model import UploadFile
from services.dataset_service import DocumentService
from services.file_service import FileService

View File

@@ -7,10 +7,9 @@ from flask_login import user_logged_in
from flask_restful import Resource
from werkzeug.exceptions import NotFound, Unauthorized
from core.login.login import _get_user
from libs.login import _get_user
from extensions.ext_database import db
from models.account import Tenant, TenantAccountJoin, Account
from models.dataset import Dataset
from models.model import ApiToken, App

View File

@@ -94,7 +94,7 @@ class ConversationMessageTask:
if not self.conversation:
self.is_new_conversation = True
self.conversation = Conversation(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
app_model_config_id=self.app_model_config.id,
model_provider=self.provider_name,
model_id=self.model_name,
@@ -115,7 +115,7 @@ class ConversationMessageTask:
db.session.commit()
self.message = Message(
app_id=self.app_model_config.app_id,
app_id=self.app.id,
model_provider=self.provider_name,
model_id=self.model_name,
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,

View File

@@ -51,6 +51,9 @@ class ModelProviderFactory:
elif provider_name == 'chatglm':
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
return ChatGLMProvider
elif provider_name == 'baichuan':
from core.model_providers.providers.baichuan_provider import BaichuanProvider
return BaichuanProvider
elif provider_name == 'azure_openai':
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
return AzureOpenAIProvider
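
Each branch imports its provider class lazily, so an unused provider never pulls in its SDK at startup. A stripped-down sketch of the same pattern (function name hypothetical):

def resolve_provider_class(provider_name: str):
    # Import inside the branch: Baichuan code only loads when requested.
    if provider_name == 'baichuan':
        from core.model_providers.providers.baichuan_provider import BaichuanProvider
        return BaichuanProvider
    raise ValueError(f"unsupported provider: {provider_name}")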

View File

@@ -0,0 +1,22 @@
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
from core.model_providers.models.embedding.base import BaseEmbedding
class OpenLLMEmbedding(BaseEmbedding):
def __init__(self, model_provider: BaseModelProvider, name: str):
credentials = model_provider.get_model_credentials(
model_name=name,
model_type=self.type
)
client = OpenLLMEmbeddings(
server_url=credentials['server_url']
)
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}")

View File

@@ -1,5 +1,4 @@
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
from replicate.exceptions import ModelError, ReplicateError
from core.model_providers.error import LLMBadRequestError
from core.model_providers.providers.base import BaseModelProvider
@@ -21,7 +20,4 @@ class XinferenceEmbedding(BaseEmbedding):
super().__init__(model_provider, client, name)
def handle_exceptions(self, ex: Exception) -> Exception:
if isinstance(ex, (ModelError, ReplicateError)):
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
else:
return ex
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")

View File

@@ -0,0 +1,67 @@
from typing import List, Optional, Any
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from core.model_providers.error import LLMBadRequestError
from core.model_providers.models.llm.base import BaseLLM
from core.model_providers.models.entity.message import PromptMessage
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
class BaichuanModel(BaseLLM):
model_mode: ModelMode = ModelMode.CHAT
def _init_client(self) -> Any:
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
return BaichuanChatLLM(
streaming=self.streaming,
callbacks=self.callbacks,
**self.credentials,
**provider_model_kwargs
)
def _run(self, messages: List[PromptMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs) -> LLMResult:
"""
run predict by prompt messages and stop words.
:param messages:
:param stop:
:param callbacks:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return self._client.generate([prompts], stop, callbacks)
def prompt_file_name(self, mode: str) -> str:
if mode == 'completion':
return 'baichuan_completion'
else:
return 'baichuan_chat'
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
"""
get num tokens of prompt messages.
:param messages:
:return:
"""
prompts = self._get_prompt_from_messages(messages)
return max(self._client.get_num_tokens_from_messages(prompts), 0)
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
for k, v in provider_model_kwargs.items():
if hasattr(self.client, k):
setattr(self.client, k, v)
def handle_exceptions(self, ex: Exception) -> Exception:
return LLMBadRequestError(f"Baichuan: {str(ex)}")
@property
def support_streaming(self):
return True
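
A construction sketch mirroring the integration test near the end of this diff; `provider` stands in for a configured BaichuanProvider (an assumption here — in practice it is built from encrypted provider credentials):

from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs

model = BaichuanModel(
    model_provider=provider,  # assumption: a configured BaichuanProvider
    name='baichuan2-53b',
    model_kwargs=ModelKwargs(temperature=0.3),
    streaming=False,
)
result = model.run([PromptMessage(type=MessageType.HUMAN, content='ping')])
print(result.content)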

View File

@@ -132,8 +132,6 @@ class BaseLLM(BaseProviderModel):
if self.deduct_quota:
self.model_provider.check_quota_over_limit()
db.session.commit()
if not callbacks:
callbacks = self.callbacks
else:

View File

@@ -5,6 +5,7 @@ from typing import List, Optional, Any
import openai
from langchain.callbacks.manager import Callbacks
from langchain.schema import LLMResult
from openai import api_requestor
from core.model_providers.providers.base import BaseModelProvider
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI

View File

@@ -0,0 +1,167 @@
import json
from json import JSONDecodeError
from typing import Type
from langchain.schema import HumanMessage
from core.helper import encrypter
from core.model_providers.models.base import BaseProviderModel
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
from models.provider import ProviderType
class BaichuanProvider(BaseModelProvider):
@property
def provider_name(self):
"""
Returns the name of a provider.
"""
return 'baichuan'
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
if model_type == ModelType.TEXT_GENERATION:
return [
{
'id': 'baichuan2-53b',
'name': 'Baichuan2-53B',
}
]
else:
return []
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
"""
Returns the model class.
:param model_type:
:return:
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = BaichuanModel
else:
raise NotImplementedError
return model_class
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
"""
get model parameter rules.
:param model_name:
:param model_type:
:return:
"""
return ModelKwargsRules(
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
presence_penalty=KwargRule[float](enabled=False),
frequency_penalty=KwargRule[float](enabled=False),
max_tokens=KwargRule[int](enabled=False),
)
@classmethod
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
"""
Validates the given credentials.
"""
if 'api_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
if 'secret_key' not in credentials:
raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
try:
credential_kwargs = {
'api_key': credentials['api_key'],
'secret_key': credentials['secret_key'],
}
llm = BaichuanChatLLM(
temperature=0,
**credential_kwargs
)
llm([HumanMessage(content='ping')])
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@classmethod
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
return credentials
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
if self.provider.provider_type == ProviderType.CUSTOM.value:
try:
credentials = json.loads(self.provider.encrypted_config)
except JSONDecodeError:
credentials = {
'api_key': None,
'secret_key': None,
}
if credentials['api_key']:
credentials['api_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['api_key']
)
if obfuscated:
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
if credentials['secret_key']:
credentials['secret_key'] = encrypter.decrypt_token(
self.provider.tenant_id,
credentials['secret_key']
)
if obfuscated:
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
return credentials
else:
return {}
def should_deduct_quota(self):
return True
@classmethod
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
"""
check model credentials valid.
:param model_name:
:param model_type:
:param credentials:
"""
return
@classmethod
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
credentials: dict) -> dict:
"""
encrypt model credentials for save.
:param tenant_id:
:param model_name:
:param model_type:
:param credentials:
:return:
"""
return {}
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
"""
get credentials for llm use.
:param model_name:
:param model_type:
:param obfuscated:
:return:
"""
return self.get_provider_credentials(obfuscated)

View File

@@ -2,11 +2,13 @@ import json
from typing import Type
from core.helper import encrypter
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
from core.model_providers.models.llm.openllm_model import OpenLLMModel
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
from core.model_providers.models.base import BaseProviderModel
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
from core.third_party.langchain.llms.openllm import OpenLLM
from models.provider import ProviderType
@@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
"""
if model_type == ModelType.TEXT_GENERATION:
model_class = OpenLLMModel
elif model_type == ModelType.EMBEDDINGS:
model_class = OpenLLMEmbedding
else:
raise NotImplementedError
@@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider):
'server_url': credentials['server_url']
}
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
if model_type == ModelType.TEXT_GENERATION:
llm = OpenLLM(
llm_kwargs={
'max_new_tokens': 10
},
**credential_kwargs
)
llm("ping")
llm("ping")
elif model_type == ModelType.EMBEDDINGS:
embedding = OpenLLMEmbeddings(
**credential_kwargs
)
embedding.embed_query("ping")
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))

View File

@@ -7,10 +7,11 @@
"spark",
"wenxin",
"zhipuai",
"baichuan",
"chatglm",
"replicate",
"huggingface_hub",
"xinference",
"openllm",
"localai"
]

View File

@@ -0,0 +1,15 @@
{
"support_provider_types": [
"custom"
],
"system_config": null,
"model_flexibility": "fixed",
"price_config": {
"baichuan2-53b": {
"prompt": "0.01",
"completion": "0.01",
"unit": "0.001",
"currency": "RMB"
}
}
}
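
Assuming `unit` scales raw token counts to the quoted per-1,000-token prices (an assumption — the exact semantics are defined elsewhere in the provider framework), a hypothetical call with 2,000 prompt tokens and 500 completion tokens would cost:

# Hypothetical token counts; prices taken from the config above.
prompt_tokens, completion_tokens = 2000, 500
unit, prompt_price, completion_price = 0.001, 0.01, 0.01
cost = (prompt_tokens * prompt_price + completion_tokens * completion_price) * unit
print(f"{cost:.3f} RMB")  # 0.025 RMB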

View File

@@ -0,0 +1,67 @@
"""Wrapper around OpenLLM embedding models."""
from typing import Any, List, Optional
import requests
from pydantic import BaseModel, Extra
from langchain.embeddings.base import Embeddings
class OpenLLMEmbeddings(BaseModel, Embeddings):
"""Wrapper around OpenLLM embedding models.
"""
client: Any #: :meta private:
server_url: Optional[str] = None
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Call out to OpenLLM's embedding endpoint.
Args:
texts: The list of texts to embed.
Returns:
List of embeddings, one for each text.
"""
embeddings = []
for text in texts:
result = self.invoke_embedding(text=text)
embeddings.append(result)
return [list(map(float, e)) for e in embeddings]
def invoke_embedding(self, text):
params = [
text
]
headers = {"Content-Type": "application/json"}
response = requests.post(
f'{self.server_url}/v1/embeddings',
headers=headers,
json=params
)
if not response.ok:
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
json_response = response.json()
return json_response[0]["embeddings"][0]
def embed_query(self, text: str) -> List[float]:
"""Call out to OpenLLM's embedding endpoint.
Args:
text: The text to embed.
Returns:
Embeddings for the text.
"""
return self.embed_documents([text])[0]
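
A minimal usage sketch for the wrapper above; the server URL is hypothetical and assumes an OpenLLM server started with `openllm start` is reachable there:

embeddings = OpenLLMEmbeddings(server_url="http://localhost:3000")  # hypothetical URL
vector = embeddings.embed_query("ping")  # POSTs ["ping"] to /v1/embeddings
print(len(vector))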

View File

@@ -0,0 +1,315 @@
"""Wrapper around Baichuan APIs."""
from __future__ import annotations
import hashlib
import json
import logging
import time
from typing import (
Any,
Dict,
List,
Optional, Iterator,
)
import requests
from langchain.chat_models.base import BaseChatModel
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
from langchain.schema.messages import AIMessageChunk
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
from pydantic import Extra, root_validator, BaseModel
from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
)
from langchain.utils import get_from_dict_or_env
logger = logging.getLogger(__name__)
class BaichuanModelAPI(BaseModel):
api_key: str
secret_key: str
base_url: str = "https://api.baichuan-ai.com/v1"
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
stream = 'stream' in kwargs and kwargs['stream']
url = self.base_url + ("/stream/chat" if stream else "/chat")
data = {
"model": model,
"messages": messages,
"parameters": parameters
}
json_data = json.dumps(data)
time_stamp = int(time.time())
signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))
headers = {
"Content-Type": "application/json",
"Authorization": "Bearer " + self.api_key,
"X-BC-Request-Id": "your requestId",
"X-BC-Timestamp": str(time_stamp),
"X-BC-Signature": signature,
"X-BC-Sign-Algo": "MD5",
}
response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))
if not response.ok:
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
if not stream:
json_response = response.json()
if json_response['code'] != 0:
raise ValueError(
f"API {json_response['code']}"
f" error: {json_response['msg']}"
)
return json_response
else:
return response
def _calculate_md5(self, input_string):
md5 = hashlib.md5()
md5.update(input_string.encode('utf-8'))
encrypted = md5.hexdigest()
return encrypted
class BaichuanChatLLM(BaseChatModel):
"""Wrapper around Baichuan large language models.
To use, you should pass the api_key as a named parameter to the constructor.
Example:
.. code-block:: python
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
"""
@property
def lc_secrets(self) -> Dict[str, str]:
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
@property
def lc_serializable(self) -> bool:
return True
client: Any = None #: :meta private:
model: str = "Baichuan2-53B"
"""Model name to use."""
temperature: float = 0.3
"""A non-negative float that tunes the degree of randomness in generation."""
top_p: float = 0.85
"""Total probability mass of tokens to consider at each step."""
streaming: bool = False
"""Whether to stream the response or return it all at once."""
api_key: Optional[str] = None
secret_key: Optional[str] = None
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
values["api_key"] = get_from_dict_or_env(
values, "api_key", "BAICHUAN_API_KEY"
)
values["secret_key"] = get_from_dict_or_env(
values, "secret_key", "BAICHUAN_SECRET_KEY"
)
values['client'] = BaichuanModelAPI(
api_key=values['api_key'],
secret_key=values['secret_key']
)
return values
@property
def _default_params(self) -> Dict[str, Any]:
"""Get the default parameters for calling OpenAI API."""
return {
"model": self.model,
"parameters": {
"temperature": self.temperature,
"top_p": self.top_p
}
}
@property
def _identifying_params(self) -> Dict[str, Any]:
"""Get the identifying parameters."""
return self._default_params
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "baichuan"
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
if isinstance(message, ChatMessage):
message_dict = {"role": message.role, "content": message.content}
elif isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "user", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
role = _dict["role"]
if role == "user":
return HumanMessage(content=_dict["content"])
elif role == "assistant":
return AIMessage(content=_dict["content"])
elif role == "system":
return SystemMessage(content=_dict["content"])
else:
return ChatMessage(content=_dict["content"], role=role)
def _create_message_dicts(
self, messages: List[BaseMessage]
) -> List[Dict[str, Any]]:
dict_messages = []
for m in messages:
message = self._convert_message_to_dict(m)
if dict_messages:
previous_message = dict_messages[-1]
if previous_message['role'] == message['role']:
dict_messages[-1]['content'] += f"\n{message['content']}"
else:
dict_messages.append(message)
else:
dict_messages.append(message)
return dict_messages
def _generate(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
if self.streaming:
generation: Optional[ChatGenerationChunk] = None
llm_output: Optional[Dict] = None
for chunk in self._stream(
messages=messages, stop=stop, run_manager=run_manager, **kwargs
):
if generation is None:
generation = chunk
else:
generation += chunk
if chunk.generation_info is not None \
and 'token_usage' in chunk.generation_info:
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
assert generation is not None
return ChatResult(generations=[generation], llm_output=llm_output)
else:
message_dicts = self._create_message_dicts(messages)
params = self._default_params
params["messages"] = message_dicts
params.update(kwargs)
response = self.client.do_request(**params)
return self._create_chat_result(response)
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
message_dicts = self._create_message_dicts(messages)
params = self._default_params
params["messages"] = message_dicts
params.update(kwargs)
for event in self.client.do_request(stream=True, **params).iter_lines():
if event:
event = event.decode("utf-8")
meta = json.loads(event)
if meta['code'] != 0:
raise ValueError(
f"API {meta['code']}"
f" error: {meta['msg']}"
)
content = meta['data']['messages'][0]['content']
chunk_kwargs = {
'message': AIMessageChunk(content=content),
}
if 'usage' in meta:
token_usage = meta['usage']
overall_token_usage = {
'prompt_tokens': token_usage.get('prompt_tokens', 0),
'completion_tokens': token_usage.get('answer_tokens', 0),
'total_tokens': token_usage.get('total_tokens', 0)
}
chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
yield ChatGenerationChunk(**chunk_kwargs)
if run_manager:
run_manager.on_llm_new_token(content)
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
data = response["data"]
generations = []
for res in data["messages"]:
message = self._convert_dict_to_message(res)
gen = ChatGeneration(
message=message
)
generations.append(gen)
usage = response.get("usage")
token_usage = {
'prompt_tokens': usage.get('prompt_tokens', 0),
'completion_tokens': usage.get('answer_tokens', 0),
'total_tokens': usage.get('total_tokens', 0)
}
llm_output = {"token_usage": token_usage, "model_name": self.model}
return ChatResult(generations=generations, llm_output=llm_output)
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
"""Get the number of tokens in the messages.
Useful for checking if an input will fit in a model's context window.
Args:
messages: The message inputs to tokenize.
Returns:
The sum of the number of tokens across the messages.
"""
return sum([self.get_num_tokens(m.content) for m in messages])
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
token_usage: dict = {}
for output in llm_outputs:
if output is None:
# Happens in streaming
continue
token_usage = output["token_usage"]
return {"token_usage": token_usage, "model_name": self.model}

View File

@@ -49,6 +49,7 @@ huggingface_hub~=0.16.4
transformers~=4.31.0
stripe~=5.5.0
pandas==1.5.3
xinference==0.4.2
xinference==0.5.2
safetensors==0.3.2
zhipuai==1.0.7
werkzeug==2.3.7

View File

@@ -3,7 +3,7 @@ import logging
import threading
import time
import uuid
from typing import Generator, Union, Any
from typing import Generator, Union, Any, Optional
from flask import current_app, Flask
from redis.client import PubSub
@@ -141,12 +141,12 @@ class CompletionService:
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'app_model': app_model,
'detached_app_model': app_model,
'app_model_config': app_model_config,
'query': query,
'inputs': inputs,
'user': user,
'conversation': conversation,
'detached_user': user,
'detached_conversation': conversation,
'streaming': streaming,
'is_model_config_override': is_model_config_override,
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
@@ -155,7 +155,7 @@ class CompletionService:
generate_worker_thread.start()
# wait for 10 minutes to close the thread
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@@ -171,18 +171,22 @@ class CompletionService:
return user
@classmethod
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, user: Union[Account, EndUser],
conversation: Conversation, streaming: bool, is_model_config_override: bool,
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
query: str, inputs: dict, detached_user: Union[Account, EndUser],
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
retriever_from: str = 'dev'):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
if detached_conversation:
conversation = db.session.merge(detached_conversation)
else:
conversation = None
try:
if conversation:
# fixed the state of the conversation object when it detached from the original session
conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
# run
Completion.generate(
task_id=generate_task_id,
app=app_model,
@@ -200,36 +204,38 @@ class CompletionService:
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
# wait for 10 minutes to close the thread
timeout = 600
def close_pubsub():
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
with flask_app.app_context():
user = db.session.merge(detached_user)
time.sleep(1)
sleep_iterations += 1
sleep_iterations = 0
while sleep_iterations < timeout and worker_thread.is_alive():
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
PubHandler.ping(user, generate_task_id)
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass
time.sleep(1)
sleep_iterations += 1
if worker_thread.is_alive():
PubHandler.stop(user, generate_task_id)
try:
pubsub.close()
except:
pass
countdown_thread = threading.Thread(target=close_pubsub)
countdown_thread.start()
@@ -279,25 +285,30 @@ class CompletionService:
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
'flask_app': current_app._get_current_object(),
'generate_task_id': generate_task_id,
'app_model': app_model,
'detached_app_model': app_model,
'app_model_config': app_model_config,
'message': message,
'detached_message': message,
'pre_prompt': pre_prompt,
'user': user,
'detached_user': user,
'streaming': streaming
})
generate_worker_thread.start()
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
return cls.compact_response(pubsub, streaming)
@classmethod
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
app_model_config: AppModelConfig, message: Message, pre_prompt: str,
user: Union[Account, EndUser], streaming: bool):
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
detached_user: Union[Account, EndUser], streaming: bool):
with flask_app.app_context():
# fixed the state of the model object when it detached from the original session
user = db.session.merge(detached_user)
app_model = db.session.merge(detached_app_model)
message = db.session.merge(detached_message)
try:
# run
Completion.generate_more_like_this(
@@ -314,15 +325,14 @@ class CompletionService:
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
ModelCurrentlyNotSupportError) as e:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, e)
except LLMAuthorizationError:
db.session.rollback()
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
except Exception as e:
db.session.rollback()
logging.exception("Unknown Error in completion")
PubHandler.pub_error(user, generate_task_id, e)
finally:
db.session.commit()
@classmethod
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
@@ -386,6 +396,8 @@ class CompletionService:
logging.exception(e)
raise
finally:
db.session.commit()
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
@@ -423,6 +435,8 @@ class CompletionService:
logging.exception(e)
raise
finally:
db.session.commit()
try:
pubsub.unsubscribe(generate_channel)
except ConnectionError:
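
The fix threads the Flask app object into each worker and re-attaches every ORM instance with db.session.merge, since objects loaded in the request-scoped session become detached once the worker opens its own context. The core of the pattern, reduced to a sketch (worker body hypothetical):

import threading
from flask import Flask, current_app
from extensions.ext_database import db

def worker(flask_app: Flask, detached_user):
    with flask_app.app_context():
        # merge() binds the detached instance to this thread's session,
        # so attribute access no longer raises DetachedInstanceError.
        user = db.session.merge(detached_user)
        ...  # continue with `user` bound to this thread's session

def start_worker(detached_user):
    # current_app is a proxy; pass the real app object into the thread.
    threading.Thread(target=worker, kwargs={
        'flask_app': current_app._get_current_object(),
        'detached_user': detached_user,
    }).start()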

View File

@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
# ZhipuAI Credentials
ZHIPUAI_API_KEY=
# Baichuan Credentials
BAICHUAN_API_KEY=
BAICHUAN_SECRET_KEY=
# ChatGLM Credentials
CHATGLM_API_BASE=

View File

@@ -0,0 +1,63 @@
import json
import os
from unittest.mock import patch, MagicMock
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
from core.model_providers.models.entity.model_params import ModelType
from core.model_providers.providers.openllm_provider import OpenLLMProvider
from models.provider import Provider, ProviderType, ProviderModel
def get_mock_provider():
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='openllm',
provider_type=ProviderType.CUSTOM.value,
encrypted_config='',
is_valid=True,
)
def get_mock_embedding_model(mocker):
model_name = 'facebook/opt-125m'
server_url = os.environ['OPENLLM_SERVER_URL']
model_provider = OpenLLMProvider(provider=get_mock_provider())
mock_query = MagicMock()
mock_query.filter.return_value.first.return_value = ProviderModel(
provider_name='openllm',
model_name=model_name,
model_type=ModelType.EMBEDDINGS.value,
encrypted_config=json.dumps({
'server_url': server_url
}),
is_valid=True,
)
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
return OpenLLMEmbedding(
model_provider=model_provider,
name=model_name
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_documents(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_documents(['test', 'test1'])
assert isinstance(rst, list)
assert len(rst) == 2
assert len(rst[0]) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_embed_query(mock_decrypt, mocker):
embedding_model = get_mock_embedding_model(mocker)
rst = embedding_model.client.embed_query('test')
assert isinstance(rst, list)
assert len(rst) > 0

View File

@@ -0,0 +1,81 @@
import json
import os
from unittest.mock import patch
from core.model_providers.models.entity.message import PromptMessage, MessageType
from core.model_providers.models.entity.model_params import ModelKwargs
from core.model_providers.models.llm.baichuan_model import BaichuanModel
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from models.provider import Provider, ProviderType
def get_mock_provider(valid_api_key, valid_secret_key):
return Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name='baichuan',
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps({
'api_key': valid_api_key,
'secret_key': valid_secret_key,
}),
is_valid=True,
)
def get_mock_model(model_name: str, streaming: bool = False):
model_kwargs = ModelKwargs(
temperature=0.01,
)
valid_api_key = os.environ['BAICHUAN_API_KEY']
valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
return BaichuanModel(
model_provider=model_provider,
name=model_name,
model_kwargs=model_kwargs,
streaming=streaming
)
def decrypt_side_effect(tenant_id, encrypted_api_key):
return encrypted_api_key
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_get_num_tokens(mock_decrypt):
model = get_mock_model('baichuan2-53b')
rst = model.get_num_tokens([
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
])
assert rst > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('baichuan2-53b')
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages,
)
assert len(rst.content) > 0
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_chat_stream_run(mock_decrypt, mocker):
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
model = get_mock_model('baichuan2-53b', streaming=True)
messages = [
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
]
rst = model.run(
messages
)
assert len(rst.content) > 0

View File

@@ -0,0 +1,97 @@
import pytest
from unittest.mock import patch
import json
from langchain.schema import ChatResult, ChatGeneration, AIMessage
from core.model_providers.providers.baichuan_provider import BaichuanProvider
from core.model_providers.providers.base import CredentialsValidateFailedError
from models.provider import ProviderType, Provider
PROVIDER_NAME = 'baichuan'
MODEL_PROVIDER_CLASS = BaichuanProvider
VALIDATE_CREDENTIAL = {
'api_key': 'valid_key',
'secret_key': 'valid_key',
}
def encrypt_side_effect(tenant_id, encrypt_key):
return f'encrypted_{encrypt_key}'
def decrypt_side_effect(tenant_id, encrypted_key):
return encrypted_key.replace('encrypted_', '')
def test_is_provider_credentials_valid_or_raise_valid(mocker):
mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
def test_is_provider_credentials_valid_or_raise_invalid():
# raise CredentialsValidateFailedError if api_key is not in credentials
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
credential = VALIDATE_CREDENTIAL.copy()
credential['api_key'] = 'invalid_key'
credential['secret_key'] = 'invalid_key'
# raise CredentialsValidateFailedError if api_key is invalid
with pytest.raises(CredentialsValidateFailedError):
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
def test_encrypt_credentials(mock_encrypt):
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_custom(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials()
assert result['api_key'] == 'valid_key'
assert result['secret_key'] == 'valid_key'
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
def test_get_credentials_obfuscated(mock_decrypt):
encrypted_credential = VALIDATE_CREDENTIAL.copy()
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
provider = Provider(
id='provider_id',
tenant_id='tenant_id',
provider_name=PROVIDER_NAME,
provider_type=ProviderType.CUSTOM.value,
encrypted_config=json.dumps(encrypted_credential),
is_valid=True,
)
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
result = model_provider.get_provider_credentials(obfuscated=True)
middle_token = result['api_key'][6:-2]
secret_key_middle_token = result['secret_key'][6:-2]
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
assert all(char == '*' for char in middle_token)
assert all(char == '*' for char in secret_key_middle_token)

View File

@@ -49,4 +49,18 @@ services:
AUTHORIZATION_ADMINLIST_ENABLED: 'true'
AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai'
ports:
- "8080:8080"
- "8080:8080"
# Qdrant vector store.
# uncomment to use qdrant as vector store.
# (if uncommented, you need to comment out the weaviate service above,
# and set VECTOR_STORE to qdrant in the api & worker service.)
# qdrant:
# image: qdrant/qdrant:latest
# restart: always
# volumes:
# - ./volumes/qdrant:/qdrant/storage
# environment:
# QDRANT__API_KEY: 'difyai123456'
# ports:
# - "6333:6333"

View File

@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.3.24
image: langgenius/dify-api:0.3.25
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -85,9 +85,9 @@ services:
# The Weaviate API key.
WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
QDRANT_URL: 'https://your-qdrant-cluster-url.qdrant.tech/'
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: 'ak-difyai'
QDRANT_API_KEY: difyai123456
# Mail configuration, support: resend
MAIL_TYPE: ''
# default send from email address, if not specified
@@ -103,15 +103,17 @@ services:
depends_on:
- db
- redis
- weaviate
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# uncomment to expose dify-api port to host
# ports:
# - "5001:5001"
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.3.24
image: langgenius/dify-api:0.3.25
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -143,10 +145,16 @@ services:
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
STORAGE_TYPE: local
STORAGE_LOCAL_PATH: storage
# The Vector store configurations.
# The type of vector store to use. Supported values are `weaviate`, `qdrant`.
VECTOR_STORE: weaviate
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
WEAVIATE_ENDPOINT: http://weaviate:8080
# The Weaviate API key.
WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
QDRANT_URL: http://qdrant:6333
# The Qdrant API key.
QDRANT_API_KEY: difyai123456
# Mail configuration, support: resend
MAIL_TYPE: ''
# default send from email address, if not specified
@@ -156,14 +164,13 @@ services:
depends_on:
- db
- redis
- weaviate
volumes:
# Mount the storage directory to the container, for storing user files.
- ./volumes/app/storage:/app/api/storage
# Frontend web application.
web:
image: langgenius/dify-web:0.3.24
image: langgenius/dify-web:0.3.25
restart: always
environment:
EDITION: SELF_HOSTED
@@ -177,6 +184,9 @@ services:
APP_API_URL: ''
# The DSN for Sentry error reporting. If not set, Sentry error reporting will be disabled.
SENTRY_DSN: ''
# uncomment to expose dify-web port to host
# ports:
# - "3000:3000"
# The postgres database.
db:
@@ -211,6 +221,9 @@ services:
command: redis-server --requirepass difyai123456
healthcheck:
test: ["CMD", "redis-cli","ping"]
# uncomment to expose redis port to host
# ports:
# - "6379:6379"
# The Weaviate vector store.
weaviate:
@@ -232,6 +245,24 @@ services:
AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai'
AUTHORIZATION_ADMINLIST_ENABLED: 'true'
AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai'
# uncomment to expose weaviate port to host
# ports:
# - "8080:8080"
# Qdrant vector store.
# uncomment to use qdrant as vector store.
# (if uncommented, you need to comment out the weaviate service above,
# and set VECTOR_STORE to qdrant in the api & worker service.)
# qdrant:
# image: qdrant/qdrant:latest
# restart: always
# volumes:
# - ./volumes/qdrant:/qdrant/storage
# environment:
# QDRANT__API_KEY: 'difyai123456'
## uncomment to expose qdrant port to host
## ports:
## - "6333:6333"
# The nginx reverse proxy.
# used for reverse proxying the API service and Web service.
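
With the qdrant service enabled and port 6333 exposed to the host, a quick sanity check is to list collections against the configured key (assuming the default values above):

curl http://localhost:6333/collections -H 'api-key: difyai123456'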

View File

@@ -22,6 +22,7 @@ import Tag from '@/app/components/base/tag'
import Switch from '@/app/components/base/switch'
import Divider from '@/app/components/base/divider'
import CopyFeedback from '@/app/components/base/copy-feedback'
import ShareQRCode from '@/app/components/base/qrcode'
import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button'
import type { AppDetailResponse } from '@/models/app'
import { AppType } from '@/types/app'
@@ -168,6 +169,7 @@ function AppCard({
</div>
</div>
<Divider type="vertical" className="!h-3.5 shrink-0 !mx-0.5" />
{isApp && <ShareQRCode content={isApp ? appUrl : apiUrl} selectorId={randomString(8)} className={'hover:bg-gray-200'} />}
<CopyFeedback
content={isApp ? appUrl : apiUrl}
selectorId={randomString(8)}

Binary file not shown. (New image, 4.5 KiB.)

View File

@@ -0,0 +1,19 @@
<svg width="130" height="24" viewBox="0 0 130 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z" fill="url(#paint0_radial_11622_96091)"/>
<path d="M129.722 6.83203V18H127.482V6.83203H129.722Z" fill="#FF6A34"/>
<path d="M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z" fill="#FF6A34"/>
<path d="M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z" fill="#FF6A34"/>
<path d="M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z" fill="#FF6A34"/>
<path d="M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z" fill="#FF6A34"/>
<path d="M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z" fill="#FF6A34"/>
<path d="M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z" fill="#FF6A34"/>
<path d="M54.4373 6.83203V18H52.1973V6.83203H54.4373Z" fill="#FF6A34"/>
<path d="M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z" fill="#FF6A34"/>
<path d="M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z" fill="#FF6A34"/>
<defs>
<radialGradient id="paint0_radial_11622_96091" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)">
<stop stop-color="#FEBD3F"/>
<stop offset="0.77608" stop-color="#FF6933"/>
</radialGradient>
</defs>
</svg>

(New SVG file, 4.0 KiB.)

View File

@@ -0,0 +1,11 @@
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
<g id="Baichuan">
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z" fill="url(#paint0_radial_11622_96084)"/>
</g>
<defs>
<radialGradient id="paint0_radial_11622_96084" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)">
<stop stop-color="#FEBD3F"/>
<stop offset="0.77608" stop-color="#FF6933"/>
</radialGradient>
</defs>
</svg>

(New SVG file, 748 B.)

View File

@@ -115,6 +115,8 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = '<%= fileName %>'
export default Icon
`.trim())

View File

@@ -0,0 +1,5 @@
.wrapper {
display: inline-flex;
background: url(~@/app/components/base/icons/assets/image/llm/baichuan-text-cn.png) center center no-repeat;
background-size: contain;
}

View File

@@ -0,0 +1,15 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import cn from 'classnames'
import s from './BaichuanTextCn.module.css'
const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement>>((
{ className, ...restProps },
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'BaichuanTextCn'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'Minimax'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'MinimaxText'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'Tongyi'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'TongyiText'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'TongyiTextCn'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'Wxyy'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'WxyyText'
export default Icon

View File

@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
ref,
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
Icon.displayName = 'WxyyTextCn'
export default Icon

View File

@@ -1,3 +1,4 @@
export { default as BaichuanTextCn } from './BaichuanTextCn'
export { default as MinimaxText } from './MinimaxText'
export { default as Minimax } from './Minimax'
export { default as TongyiTextCn } from './TongyiTextCn'

View File

@@ -0,0 +1,76 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "24",
"height": "24",
"viewBox": "0 0 24 24",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "g",
"attributes": {
"id": "Baichuan"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"id": "Union",
"fill-rule": "evenodd",
"clip-rule": "evenodd",
"d": "M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z",
"fill": "url(#paint0_radial_11622_96084)"
},
"children": []
}
]
},
{
"type": "element",
"name": "defs",
"attributes": {},
"children": [
{
"type": "element",
"name": "radialGradient",
"attributes": {
"id": "paint0_radial_11622_96084",
"cx": "0",
"cy": "0",
"r": "1",
"gradientUnits": "userSpaceOnUse",
"gradientTransform": "translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)"
},
"children": [
{
"type": "element",
"name": "stop",
"attributes": {
"stop-color": "#FEBD3F"
},
"children": []
},
{
"type": "element",
"name": "stop",
"attributes": {
"offset": "0.77608",
"stop-color": "#FF6933"
},
"children": []
}
]
}
]
}
]
},
"name": "Baichuan"
}

View File

@@ -0,0 +1,16 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './Baichuan.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'Baichuan'
export default Icon

View File

@@ -0,0 +1,156 @@
{
"icon": {
"type": "element",
"isRootNode": true,
"name": "svg",
"attributes": {
"width": "130",
"height": "24",
"viewBox": "0 0 130 24",
"fill": "none",
"xmlns": "http://www.w3.org/2000/svg"
},
"children": [
{
"type": "element",
"name": "path",
"attributes": {
"fill-rule": "evenodd",
"clip-rule": "evenodd",
"d": "M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z",
"fill": "url(#paint0_radial_11622_96091)"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M129.722 6.83203V18H127.482V6.83203H129.722Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M54.4373 6.83203V18H52.1973V6.83203H54.4373Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "path",
"attributes": {
"d": "M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z",
"fill": "#FF6A34"
},
"children": []
},
{
"type": "element",
"name": "defs",
"attributes": {},
"children": [
{
"type": "element",
"name": "radialGradient",
"attributes": {
"id": "paint0_radial_11622_96091",
"cx": "0",
"cy": "0",
"r": "1",
"gradientUnits": "userSpaceOnUse",
"gradientTransform": "translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)"
},
"children": [
{
"type": "element",
"name": "stop",
"attributes": {
"stop-color": "#FEBD3F"
},
"children": []
},
{
"type": "element",
"name": "stop",
"attributes": {
"offset": "0.77608",
"stop-color": "#FF6933"
},
"children": []
}
]
}
]
}
]
},
"name": "BaichuanText"
}

View File

@@ -0,0 +1,16 @@
// GENERATED BY script
// DO NOT EDIT IT MANUALLY
import * as React from 'react'
import data from './BaichuanText.json'
import IconBase from '@/app/components/base/icons/IconBase'
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
props,
ref,
) => <IconBase {...props} ref={ref} data={data as IconData} />)
Icon.displayName = 'BaichuanText'
export default Icon

View File

@@ -4,6 +4,8 @@ export { default as AzureOpenaiServiceText } from './AzureOpenaiServiceText'
export { default as AzureOpenaiService } from './AzureOpenaiService'
export { default as AzureaiText } from './AzureaiText'
export { default as Azureai } from './Azureai'
export { default as BaichuanText } from './BaichuanText'
export { default as Baichuan } from './Baichuan'
export { default as ChatglmText } from './ChatglmText'
export { default as Chatglm } from './Chatglm'
export { default as Gpt3 } from './Gpt3'

View File

@@ -0,0 +1,61 @@
'use client'
import React, { useState } from 'react'
import { useTranslation } from 'react-i18next'
import { debounce } from 'lodash-es'
import QRCode from 'qrcode.react'
import Tooltip from '../tooltip'
import QrcodeStyle from './style.module.css'
type Props = {
content: string
selectorId: string
className?: string
}
const prefixEmbedded = 'appOverview.overview.appInfo.qrcode.title'
const ShareQRCode = ({ content, selectorId, className }: Props) => {
const { t } = useTranslation()
const [isShow, setisShow] = useState<boolean>(false)
const onClickShow = debounce(() => {
setisShow(true)
}, 100)
const downloadQR = () => {
const canvas = document.getElementsByTagName('canvas')[0]
const link = document.createElement('a')
link.download = 'qrcode.png'
link.href = canvas.toDataURL()
link.click()
}
const onMouseLeave = debounce(() => {
setisShow(false)
}, 500)
return (
<Tooltip
selector={`common-qrcode-show-${selectorId}`}
content={t(`${prefixEmbedded}`) || ''}
>
<div
className={`w-8 h-8 cursor-pointer rounded-lg ${className ?? ''}`}
onMouseLeave={onMouseLeave}
onClick={onClickShow}
>
<div className={`w-full h-full ${QrcodeStyle.QrcodeIcon} ${isShow ? QrcodeStyle.show : ''}`} />
{isShow && <div className={QrcodeStyle.qrcodeform}>
<QRCode size={160} value={content} className={QrcodeStyle.qrcodeimage}/>
<div className={QrcodeStyle.text}>
<div className={`text-gray-500 ${QrcodeStyle.scan}`}>{t('appOverview.overview.appInfo.qrcode.scan')}</div>
<div className={`text-gray-500 ${QrcodeStyle.scan}`}>·</div>
<div className={QrcodeStyle.download} onClick={downloadQR}>{t('appOverview.overview.appInfo.qrcode.download')}</div>
</div>
</div>
}
</div>
</Tooltip>
)
}
export default ShareQRCode
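
Usage mirrors the AppCard wiring earlier in this diff — selectorId just has to be unique per rendered instance (appUrl and randomString come from that surrounding component):

<ShareQRCode content={appUrl} selectorId={randomString(8)} className='hover:bg-gray-200' />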

View File

@@ -0,0 +1,61 @@
.QrcodeIcon {
background-image: url(~@/app/components/develop/secret-key/assets/qrcode.svg);
background-position: center;
background-repeat: no-repeat;
}
.QrcodeIcon:hover {
background-image: url(~@/app/components/develop/secret-key/assets/qrcode-hover.svg);
background-position: center;
background-repeat: no-repeat;
}
.QrcodeIcon.show {
background-image: url(~@/app/components/develop/secret-key/assets/qrcode-hover.svg);
background-position: center;
background-repeat: no-repeat;
}
.qrcodeimage {
position: relative;
object-fit: cover;
}
.scan {
margin: 0;
line-height: 1rem;
font-size: 0.75rem;
}
.download {
position: relative;
color: #155eef;
font-size: 0.75rem;
line-height: 1rem;
}
.text {
align-self: stretch;
display: flex;
flex-direction: row;
align-items: center;
justify-content: center;
gap: 4px;
}
.qrcodeform {
border: 0.5px solid #eaecf0;
display: flex;
flex-direction: column;
margin: 0 !important;
margin-top: 4px !important;
margin-left: -75px !important;
position: absolute;
border-radius: 8px;
background-color: #fff;
box-shadow: 0 12px 16px -4px rgba(16, 24, 40, 0.08),
0 4px 6px -2px rgba(16, 24, 40, 0.03);
overflow: hidden;
align-items: center;
justify-content: center;
padding: 12px;
gap: 8px;
z-index: 3;
font-family: "PingFang SC", serif;
}

View File

@@ -127,7 +127,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
{(step === 2 && (!datasetId || (datasetId && !!detail))) && <StepTwo
hasSetAPIKEY={!!embeddingsDefaultModel}
onSetting={showSetAPIKey}
indexingType={detail?.indexing_technique || ''}
indexingType={detail?.indexing_technique}
datasetId={datasetId}
dataSourceType={dataSourceType}
files={fileList.map(file => file.file)}

View File

@@ -1,4 +1,3 @@
/* eslint-disable no-mixed-operators */
'use client'
import React, { useEffect, useLayoutEffect, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
@@ -11,7 +10,7 @@ import { groupBy } from 'lodash-es'
import PreviewItem, { PreviewType } from './preview-item'
import LanguageSelect from './language-select'
import s from './index.module.css'
import type { CreateDocumentReq, CustomFile, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets'
import type { CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, IndexingEstimateResponse, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
import {
createDocument,
createFirstDocument,
@@ -33,13 +32,14 @@ import { useDatasetDetailContext } from '@/context/dataset-detail'
import I18n from '@/context/i18n'
import { IS_CE_EDITION } from '@/config'
type ValueOf<T> = T[keyof T]
type StepTwoProps = {
isSetting?: boolean
documentDetail?: FullDocumentDetail
hasSetAPIKEY: boolean
onSetting: () => void
datasetId?: string
indexingType?: string
indexingType?: ValueOf<IndexingType>
dataSourceType: DataSourceType
files: CustomFile[]
notionPages?: NotionPage[]
@@ -89,21 +89,23 @@ const StepTwo = ({
const [rules, setRules] = useState<PreProcessingRule[]>([])
const [defaultConfig, setDefaultConfig] = useState<Rules>()
const hasSetIndexType = !!indexingType
const [indexType, setIndexType] = useState<IndexingType>(
indexingType
|| hasSetAPIKEY
const [indexType, setIndexType] = useState<ValueOf<IndexingType>>(
(indexingType
|| hasSetAPIKEY)
? IndexingType.QUALIFIED
: IndexingType.ECONOMICAL,
)
const [docForm, setDocForm] = useState<DocForm | string>(
datasetId && documentDetail ? documentDetail.doc_form : DocForm.TEXT,
(datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT,
)
const [docLanguage, setDocLanguage] = useState<string>(locale === 'en' ? 'English' : 'Chinese')
const [QATipHide, setQATipHide] = useState(false)
const [previewSwitched, setPreviewSwitched] = useState(false)
const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean()
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
const [estimateTokes, setEstimateTokes] = useState<Pick<IndexingEstimateResponse, 'tokens' | 'total_price'> | null>(null)
const fileIndexingEstimate = (() => {
return segmentationType === SegmentType.AUTO ? automaticFileIndexingEstimate : customFileIndexingEstimate
})()
@@ -153,7 +155,7 @@ const StepTwo = ({
}
const resetRules = () => {
if (defaultConfig) {
setSegmentIdentifier(defaultConfig.segmentation.separator === '\n' ? '\\n' : defaultConfig.segmentation.separator || '\\n')
setSegmentIdentifier((defaultConfig.segmentation.separator === '\n' ? '\\n' : defaultConfig.segmentation.separator) || '\\n')
setMax(defaultConfig.segmentation.max_tokens)
setRules(defaultConfig.pre_processing_rules)
}
@@ -161,12 +163,14 @@ const StepTwo = ({
const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => {
// eslint-disable-next-line @typescript-eslint/no-use-before-define
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm))
if (segmentationType === SegmentType.CUSTOM)
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!)
if (segmentationType === SegmentType.CUSTOM) {
setCustomFileIndexingEstimate(res)
else
}
else {
setAutomaticFileIndexingEstimate(res)
indexType === IndexingType.QUALIFIED && setEstimateTokes({ tokens: res.tokens, total_price: res.total_price })
}
}
const confirmChangeCustomConfig = () => {
@@ -179,8 +183,8 @@ const StepTwo = ({
const getIndexing_technique = () => indexingType || indexType
const getProcessRule = () => {
const processRule: any = {
rules: {}, // api will check this. It will be removed after api refactored.
const processRule: ProcessRule = {
rules: {} as any, // api will check this. It will be removed after api refactored.
mode: segmentationType,
}
if (segmentationType === SegmentType.CUSTOM) {
@@ -220,37 +224,35 @@ const StepTwo = ({
}) as NotionInfo[]
}
const getFileIndexingEstimateParams = (docForm: DocForm) => {
let params
const getFileIndexingEstimateParams = (docForm: DocForm): IndexingEstimateParams | undefined => {
if (dataSourceType === DataSourceType.FILE) {
params = {
return {
info_list: {
data_source_type: dataSourceType,
file_info_list: {
file_ids: files.map(file => file.id),
file_ids: files.map(file => file.id) as string[],
},
},
indexing_technique: getIndexing_technique(),
indexing_technique: getIndexing_technique() as string,
process_rule: getProcessRule(),
doc_form: docForm,
doc_language: docLanguage,
dataset_id: datasetId,
dataset_id: datasetId as string,
}
}
if (dataSourceType === DataSourceType.NOTION) {
params = {
return {
info_list: {
data_source_type: dataSourceType,
notion_info_list: getNotionInfo(),
},
indexing_technique: getIndexing_technique(),
indexing_technique: getIndexing_technique() as string,
process_rule: getProcessRule(),
doc_form: docForm,
doc_language: docLanguage,
dataset_id: datasetId,
dataset_id: datasetId as string,
}
}
return params
}
const getCreationParams = () => {
@@ -291,7 +293,7 @@ const StepTwo = ({
try {
const res = await fetchDefaultProcessRule({ url: '/datasets/process-rule' })
const separator = res.rules.segmentation.separator
setSegmentIdentifier(separator === '\n' ? '\\n' : separator || '\\n')
setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n')
setMax(res.rules.segmentation.max_tokens)
setRules(res.rules.pre_processing_rules)
setDefaultConfig(res.rules)
@@ -306,7 +308,7 @@ const StepTwo = ({
const rules = documentDetail.dataset_process_rule.rules
const separator = rules.segmentation.separator
const max = rules.segmentation.max_tokens
setSegmentIdentifier(separator === '\n' ? '\\n' : separator || '\\n')
setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n')
setMax(max)
setRules(rules.pre_processing_rules)
setDefaultConfig(rules)
@@ -330,7 +332,7 @@ const StepTwo = ({
res = await createFirstDocument({
body: params,
})
updateIndexingTypeCache && updateIndexingTypeCache(indexType)
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(res)
}
else {
@@ -338,7 +340,7 @@ const StepTwo = ({
datasetId,
body: params,
})
updateIndexingTypeCache && updateIndexingTypeCache(indexType)
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
updateResultCache && updateResultCache(res)
}
if (mutateDatasetRes)
@@ -549,9 +551,9 @@ const StepTwo = ({
<div className={s.tip}>{t('datasetCreation.stepTwo.qualifiedTip')}</div>
<div className='pb-0.5 text-xs font-medium text-gray-500'>{t('datasetCreation.stepTwo.emstimateCost')}</div>
{
fileIndexingEstimate
estimateTokes
? (
<div className='text-xs font-medium text-gray-800'>{formatNumber(fileIndexingEstimate.tokens)} tokens(<span className='text-yellow-500'>${formatNumber(fileIndexingEstimate.total_price)}</span>)</div>
<div className='text-xs font-medium text-gray-800'>{formatNumber(estimateTokes.tokens)} tokens(<span className='text-yellow-500'>${formatNumber(estimateTokes.total_price)}</span>)</div>
)
: (
<div className={s.calculating}>{t('datasetCreation.stepTwo.calculating')}</div>

View File

@@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4.33333 4.33333H4.34M11.6667 4.33333H11.6733M4.33333 11.6667H4.34M8.66667 8.66667H8.67333M11.6667 11.6667H11.6733M11.3333 14H14V11.3333M9.33333 11V14M14 9.33333H11M10.4 6.66667H12.9333C13.3067 6.66667 13.4934 6.66667 13.636 6.594C13.7614 6.53009 13.8634 6.4281 13.9273 6.30266C14 6.16005 14 5.97337 14 5.6V3.06667C14 2.6933 14 2.50661 13.9273 2.36401C13.8634 2.23856 13.7614 2.13658 13.636 2.07266C13.4934 2 13.3067 2 12.9333 2H10.4C10.0266 2 9.83995 2 9.69734 2.07266C9.5719 2.13658 9.46991 2.23856 9.406 2.36401C9.33333 2.50661 9.33333 2.6933 9.33333 3.06667V5.6C9.33333 5.97337 9.33333 6.16005 9.406 6.30266C9.46991 6.4281 9.5719 6.53009 9.69734 6.594C9.83995 6.66667 10.0266 6.66667 10.4 6.66667ZM3.06667 6.66667H5.6C5.97337 6.66667 6.16005 6.66667 6.30266 6.594C6.4281 6.53009 6.53009 6.4281 6.594 6.30266C6.66667 6.16005 6.66667 5.97337 6.66667 5.6V3.06667C6.66667 2.6933 6.66667 2.50661 6.594 2.36401C6.53009 2.23856 6.4281 2.13658 6.30266 2.07266C6.16005 2 5.97337 2 5.6 2H3.06667C2.6933 2 2.50661 2 2.36401 2.07266C2.23856 2.13658 2.13658 2.23856 2.07266 2.36401C2 2.50661 2 2.6933 2 3.06667V5.6C2 5.97337 2 6.16005 2.07266 6.30266C2.13658 6.4281 2.23856 6.53009 2.36401 6.594C2.50661 6.66667 2.6933 6.66667 3.06667 6.66667ZM3.06667 14H5.6C5.97337 14 6.16005 14 6.30266 13.9273C6.4281 13.8634 6.53009 13.7614 6.594 13.636C6.66667 13.4934 6.66667 13.3067 6.66667 12.9333V10.4C6.66667 10.0266 6.66667 9.83995 6.594 9.69734C6.53009 9.5719 6.4281 9.46991 6.30266 9.406C6.16005 9.33333 5.97337 9.33333 5.6 9.33333H3.06667C2.6933 9.33333 2.50661 9.33333 2.36401 9.406C2.23856 9.46991 2.13658 9.5719 2.07266 9.69734C2 9.83995 2 10.0266 2 10.4V12.9333C2 13.3067 2 13.4934 2.07266 13.636C2.13658 13.7614 2.23856 13.8634 2.36401 13.9273C2.50661 14 2.6933 14 3.06667 14Z" stroke="#1D2939" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

(New SVG file, 1.9 KiB.)

View File

@@ -0,0 +1,4 @@
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
<path d="M4.33333 4.33333H4.34M11.6667 4.33333H11.6733M4.33333 11.6667H4.34M8.66667 8.66667H8.67333M11.6667 11.6667H11.6733M11.3333 14H14V11.3333M9.33333 11V14M14 9.33333H11M10.4 6.66667H12.9333C13.3067 6.66667 13.4934 6.66667 13.636 6.594C13.7614 6.53009 13.8634 6.4281 13.9273 6.30266C14 6.16005 14 5.97337 14 5.6V3.06667C14 2.6933 14 2.50661 13.9273 2.36401C13.8634 2.23856 13.7614 2.13658 13.636 2.07266C13.4934 2 13.3067 2 12.9333 2H10.4C10.0266 2 9.83995 2 9.69734 2.07266C9.5719 2.13658 9.46991 2.23856 9.406 2.36401C9.33333 2.50661 9.33333 2.6933 9.33333 3.06667V5.6C9.33333 5.97337 9.33333 6.16005 9.406 6.30266C9.46991 6.4281 9.5719 6.53009 9.69734 6.594C9.83995 6.66667 10.0266 6.66667 10.4 6.66667ZM3.06667 6.66667H5.6C5.97337 6.66667 6.16005 6.66667 6.30266 6.594C6.4281 6.53009 6.53009 6.4281 6.594 6.30266C6.66667 6.16005 6.66667 5.97337 6.66667 5.6V3.06667C6.66667 2.6933 6.66667 2.50661 6.594 2.36401C6.53009 2.23856 6.4281 2.13658 6.30266 2.07266C6.16005 2 5.97337 2 5.6 2H3.06667C2.6933 2 2.50661 2 2.36401 2.07266C2.23856 2.13658 2.13658 2.23856 2.07266 2.36401C2 2.50661 2 2.6933 2 3.06667V5.6C2 5.97337 2 6.16005 2.07266 6.30266C2.13658 6.4281 2.23856 6.53009 2.36401 6.594C2.50661 6.66667 2.6933 6.66667 3.06667 6.66667ZM3.06667 14H5.6C5.97337 14 6.16005 14 6.30266 13.9273C6.4281 13.8634 6.53009 13.7614 6.594 13.636C6.66667 13.4934 6.66667 13.3067 6.66667 12.9333V10.4C6.66667 10.0266 6.66667 9.83995 6.594 9.69734C6.53009 9.5719 6.4281 9.46991 6.30266 9.406C6.16005 9.33333 5.97337 9.33333 5.6 9.33333H3.06667C2.6933 9.33333 2.50661 9.33333 2.36401 9.406C2.23856 9.46991 2.13658 9.5719 2.07266 9.69734C2 9.83995 2 10.0266 2 10.4V12.9333C2 13.3067 2 13.4934 2.07266 13.636C2.13658 13.7614 2.23856 13.8634 2.36401 13.9273C2.50661 14 2.6933 14 3.06667 14Z" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
</svg>

(New SVG file, 1.9 KiB.)

View File

@@ -0,0 +1,70 @@
import { ProviderEnum } from '../declarations'
import type { ProviderConfig } from '../declarations'
import { BaichuanTextCn } from '@/app/components/base/icons/src/image/llm'
import {
Baichuan,
BaichuanText,
} from '@/app/components/base/icons/src/public/llm'
const config: ProviderConfig = {
selector: {
name: {
'en': 'BAICHUAN AI',
'zh-Hans': '百川智能',
},
icon: <Baichuan className='w-full h-full' />,
},
item: {
key: ProviderEnum.baichuan,
titleIcon: {
'en': <BaichuanText className='w-[124px] h-6' />,
'zh-Hans': <BaichuanTextCn className='w-[100px] h-6' />,
},
},
modal: {
key: ProviderEnum.baichuan,
title: {
'en': 'BAICHUAN AI',
'zh-Hans': '百川智能',
},
icon: <Baichuan className='w-6 h-6' />,
link: {
href: 'https://platform.baichuan-ai.com/console/apikey',
label: {
'en': 'Get your API key from BAICHUAN AI',
'zh-Hans': '从百川智能获取 API Key',
},
},
validateKeys: ['api_key', 'secret_key'],
fields: [
{
type: 'text',
key: 'api_key',
required: true,
label: {
'en': 'API Key',
'zh-Hans': 'API Key',
},
placeholder: {
'en': 'Enter your API key here',
'zh-Hans': '在此输入您的 API Key',
},
},
{
type: 'text',
key: 'secret_key',
required: true,
label: {
'en': 'Secret Key',
'zh-Hans': 'Secret Key',
},
placeholder: {
'en': 'Enter your Secret key here',
'zh-Hans': '在此输入您的 Secret Key',
},
},
],
},
}
export default config

View File

@@ -12,6 +12,7 @@ import xinference from './xinference'
import openllm from './openllm'
import localai from './localai'
import zhipuai from './zhipuai'
import baichuan from './baichuan'
export default {
openai,
@@ -28,4 +29,5 @@ export default {
openllm,
localai,
zhipuai,
baichuan,
}

View File

@@ -56,6 +56,31 @@ const config: ProviderConfig = {
'server_url',
],
fields: [
{
type: 'radio',
key: 'model_type',
required: true,
label: {
'en': 'Model Type',
'zh-Hans': '模型类型',
},
options: [
{
key: 'text-generation',
label: {
'en': 'Text Generation',
'zh-Hans': '文本生成',
},
},
{
key: 'embeddings',
label: {
'en': 'Embeddings',
'zh-Hans': 'Embeddings',
},
},
],
},
{
type: 'text',
key: 'model_name',

View File

@@ -43,6 +43,7 @@ export enum ProviderEnum {
'openllm' = 'openllm',
'localai' = 'localai',
'zhipuai' = 'zhipuai',
'baichuan' = 'baichuan',
}
export type ProviderConfigItem = {

View File

@@ -79,6 +79,7 @@ const ModelPage = () => {
config.replicate,
config.huggingface_hub,
config.zhipuai,
config.baichuan,
config.spark,
config.minimax,
config.tongyi,
@@ -93,6 +94,7 @@ const ModelPage = () => {
modelList = [
config.huggingface_hub,
config.zhipuai,
config.baichuan,
config.spark,
config.minimax,
config.azure_openai,

View File

@@ -61,6 +61,11 @@ const translation = {
copied: 'Copied',
copy: 'Copy',
},
qrcode: {
title: 'QR code to share',
scan: 'Scan to share the application',
download: 'Download QR Code',
},
customize: {
way: 'way',
entry: 'Customize',

View File

@@ -61,6 +61,11 @@ const translation = {
copied: '已复制',
copy: '复制',
},
qrcode: {
title: '二维码分享',
scan: '扫码分享应用',
download: '下载二维码',
},
customize: {
way: '方法',
entry: '定制化',

View File

@@ -183,15 +183,22 @@ export type DocumentListResponse = {
limit: number
}
export type CreateDocumentReq = {
export type DocumentReq = {
original_document_id?: string
indexing_technique?: string
doc_form: 'text_model' | 'qa_model'
doc_language: string
data_source: DataSource
process_rule: ProcessRule
}
export type CreateDocumentReq = DocumentReq & {
data_source: DataSource
}
export type IndexingEstimateParams = DocumentReq & Partial<DataSource> & {
dataset_id: string
}
export type DataSource = {
type: DataSourceType
info_list: {
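
Concretely, IndexingEstimateParams now describes the payload that getFileIndexingEstimateParams builds in step-two above; a sketch for the file data source (field values are illustrative, not defaults):

const params: IndexingEstimateParams = {
  info_list: {
    data_source_type: DataSourceType.FILE,
    file_info_list: { file_ids: ['a-file-id'] },
  },
  indexing_technique: 'high_quality',
  process_rule: { rules: {} as any, mode: 'automatic' },
  doc_form: 'text_model',
  doc_language: 'English',
  dataset_id: 'a-dataset-id',
}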

View File

@@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "0.3.24",
"version": "0.3.25",
"private": true,
"scripts": {
"dev": "next dev",
@@ -46,6 +46,7 @@
"mermaid": "10.4.0",
"negotiator": "^0.6.3",
"next": "13.3.1",
"qrcode.react": "^3.1.0",
"qs": "^6.11.1",
"react": "^18.2.0",
"react-18-input-autosize": "^3.0.0",

View File

@@ -10,6 +10,7 @@ import type {
FileIndexingEstimateResponse,
HitTestingRecordsResponse,
HitTestingResponse,
IndexingEstimateParams,
IndexingEstimateResponse,
IndexingStatusBatchResponse,
IndexingStatusResponse,
@@ -189,7 +190,7 @@ export const fetchTestingRecords: Fetcher<HitTestingRecordsResponse, { datasetId
return get<HitTestingRecordsResponse>(`/datasets/${datasetId}/queries`, { params })
}
export const fetchFileIndexingEstimate: Fetcher<FileIndexingEstimateResponse, any> = (body: any) => {
export const fetchFileIndexingEstimate: Fetcher<FileIndexingEstimateResponse, IndexingEstimateParams> = (body: IndexingEstimateParams) => {
return post<FileIndexingEstimateResponse>('/datasets/indexing-estimate', { body })
}