mirror of
https://github.com/langgenius/dify.git
synced 2026-01-09 07:44:12 +00:00
Compare commits
15 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
8bf892b306 | ||
|
|
8480b0197b | ||
|
|
df07fb5951 | ||
|
|
4ab4bcc074 | ||
|
|
1d4f019de4 | ||
|
|
677aacc8e3 | ||
|
|
fda937175d | ||
|
|
024250803a | ||
|
|
b711ce33b7 | ||
|
|
52bec63275 | ||
|
|
657fa80f4d | ||
|
|
373e90ee6d | ||
|
|
41d4c5b424 | ||
|
|
86a9dea428 | ||
|
|
8606d80c66 |
@@ -1,11 +1,8 @@
|
||||
FROM mcr.microsoft.com/devcontainers/anaconda:0-3
|
||||
FROM mcr.microsoft.com/devcontainers/python:3.10
|
||||
|
||||
COPY . .
|
||||
|
||||
# Copy environment.yml (if found) to a temp location so we update the environment. Also
|
||||
# copy "noop.txt" so the COPY instruction does not fail if no environment.yml exists.
|
||||
COPY environment.yml* .devcontainer/noop.txt /tmp/conda-tmp/
|
||||
RUN if [ -f "/tmp/conda-tmp/environment.yml" ]; then umask 0002 && /opt/conda/bin/conda env update -n base -f /tmp/conda-tmp/environment.yml; fi \
|
||||
&& rm -rf /tmp/conda-tmp
|
||||
|
||||
# [Optional] Uncomment this section to install additional OS packages.
|
||||
# RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
||||
# && apt-get -y install --no-install-recommends <your-package-list-here>
|
||||
@@ -1,13 +1,12 @@
|
||||
// For format details, see https://aka.ms/devcontainer.json. For config options, see the
|
||||
// README at: https://github.com/devcontainers/templates/tree/main/src/anaconda
|
||||
{
|
||||
"name": "Anaconda (Python 3)",
|
||||
"name": "Python 3.10",
|
||||
"build": {
|
||||
"context": "..",
|
||||
"dockerfile": "Dockerfile"
|
||||
},
|
||||
"features": {
|
||||
"ghcr.io/dhoeric/features/act:1": {},
|
||||
"ghcr.io/devcontainers/features/node:1": {
|
||||
"nodeGypDependencies": true,
|
||||
"version": "lts"
|
||||
|
||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -144,6 +144,7 @@ docker/volumes/app/storage/*
|
||||
docker/volumes/db/data/*
|
||||
docker/volumes/redis/data/*
|
||||
docker/volumes/weaviate/*
|
||||
docker/volumes/qdrant/*
|
||||
|
||||
sdks/python-client/build
|
||||
sdks/python-client/dist
|
||||
|
||||
@@ -59,9 +59,9 @@ WEAVIATE_API_KEY=WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
WEAVIATE_GRPC_ENABLED=false
|
||||
WEAVIATE_BATCH_SIZE=100
|
||||
|
||||
# Qdrant configuration, use `path:` prefix for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=path:storage/qdrant
|
||||
QDRANT_API_KEY=your-qdrant-api-key
|
||||
# Qdrant configuration, use `http://localhost:6333` for local mode or `https://your-qdrant-cluster-url.qdrant.io` for remote mode
|
||||
QDRANT_URL=http://localhost:6333
|
||||
QDRANT_API_KEY=difyai123456
|
||||
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE=
|
||||
|
||||
@@ -92,7 +92,7 @@ class Config:
|
||||
self.CONSOLE_URL = get_env('CONSOLE_URL')
|
||||
self.API_URL = get_env('API_URL')
|
||||
self.APP_URL = get_env('APP_URL')
|
||||
self.CURRENT_VERSION = "0.3.24"
|
||||
self.CURRENT_VERSION = "0.3.25"
|
||||
self.COMMIT_SHA = get_env('COMMIT_SHA')
|
||||
self.EDITION = "SELF_HOSTED"
|
||||
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
import flask_restful
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
@@ -3,10 +3,9 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
import flask
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, abort, inputs
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, inputs
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from constants.model_template import model_templates, demo_model_templates
|
||||
@@ -17,11 +16,9 @@ from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import ProviderTokenNotInitError, LLMBadRequestError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from events.app_event import app_was_created, app_was_deleted
|
||||
from fields.app_fields import app_pagination_fields, app_detail_fields, template_list_fields, \
|
||||
app_detail_fields_with_site
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.model import App, AppModelConfig, Site
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
import logging
|
||||
|
||||
from flask import request
|
||||
from core.login.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError
|
||||
|
||||
import services
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,7 +5,7 @@ from typing import Generator, Union
|
||||
|
||||
import flask_login
|
||||
from flask import Response, stream_with_context
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from werkzeug.exceptions import InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
|
||||
@@ -2,8 +2,8 @@ from datetime import datetime
|
||||
|
||||
import pytz
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
from flask_restful.inputs import int_range
|
||||
from sqlalchemy import or_, func
|
||||
from sqlalchemy.orm import joinedload
|
||||
@@ -15,7 +15,7 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from fields.conversation_fields import conversation_pagination_fields, conversation_detail_fields, \
|
||||
conversation_message_detail_fields, conversation_with_summary_pagination_fields
|
||||
from libs.helper import TimestampField, datetime_string, uuid_value
|
||||
from libs.helper import datetime_string
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageAnnotation, Conversation
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -16,9 +16,9 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMRateLimitError, LLMBadRequestError, LLMAuthorizationError, LLMAPIConnectionError, \
|
||||
ProviderTokenNotInitError, LLMAPIUnavailableError, QuotaExceededError, ModelCurrentlyNotSupportError
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from fields.conversation_fields import message_detail_fields
|
||||
from libs.helper import uuid_value, TimestampField
|
||||
from libs.helper import uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from extensions.ext_database import db
|
||||
from models.model import MessageAnnotation, Conversation, Message, MessageFeedback
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import json
|
||||
|
||||
from flask import request
|
||||
from flask_restful import Resource
|
||||
@@ -9,7 +8,7 @@ from controllers.console import api
|
||||
from controllers.console.app import _get_app
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from events.app_event import app_model_config_was_updated
|
||||
from extensions.ext_database import db
|
||||
from models.model import AppModelConfig
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -5,7 +5,7 @@ from datetime import datetime
|
||||
import pytz
|
||||
from flask import jsonify
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,16 +1,13 @@
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
import flask_login
|
||||
import requests
|
||||
from flask import request, redirect, current_app, session
|
||||
from flask import request, redirect, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from libs.oauth_data_source import NotionOAuth
|
||||
from controllers.console import api
|
||||
from ..setup import setup_required
|
||||
|
||||
@@ -2,10 +2,10 @@ import datetime
|
||||
import json
|
||||
|
||||
from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields, reqparse, marshal
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, marshal_with, reqparse
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
from controllers.console import api
|
||||
@@ -15,7 +15,6 @@ from core.data_loader.loader.notion import NotionLoader
|
||||
from core.indexing_runner import IndexingRunner
|
||||
from extensions.ext_database import db
|
||||
from fields.data_source_fields import integrate_notion_info_list_fields, integrate_list_fields
|
||||
from libs.helper import TimestampField
|
||||
from models.dataset import Document
|
||||
from models.source import DataSourceBinding
|
||||
from services.dataset_service import DatasetService, DocumentService
|
||||
|
||||
@@ -4,8 +4,8 @@ from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
|
||||
from controllers.console.apikey import api_key_list, api_key_fields
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal, marshal_with
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal, marshal_with
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
import services
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
import random
|
||||
from datetime import datetime
|
||||
from typing import List
|
||||
|
||||
from flask import request, current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal, marshal_with, reqparse
|
||||
from sqlalchemy import desc, asc
|
||||
from werkzeug.exceptions import NotFound, Forbidden
|
||||
@@ -25,7 +24,6 @@ from core.model_providers.model_factory import ModelFactory
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.document_fields import document_with_segments_fields, document_fields, \
|
||||
dataset_and_document_fields, document_status_fields
|
||||
from libs.helper import TimestampField
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import DatasetProcessRule, Dataset
|
||||
from models.dataset import Document, DocumentSegment
|
||||
|
||||
@@ -14,13 +14,12 @@ from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from core.model_providers.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_providers.model_factory import ModelFactory
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from fields.segment_fields import segment_fields
|
||||
from models.dataset import DocumentSegment
|
||||
|
||||
from libs.helper import TimestampField
|
||||
from services.dataset_service import DatasetService, DocumentService, SegmentService
|
||||
from tasks.enable_segment_to_index_task import enable_segment_to_index_task
|
||||
from tasks.disable_segment_from_index_task import disable_segment_from_index_task
|
||||
|
||||
@@ -2,8 +2,8 @@ from cachetools import TTLCache
|
||||
from flask import request, current_app
|
||||
|
||||
import services
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, marshal_with, fields
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, marshal_with
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.datasets.error import NoFileUploadedError, TooManyFilesError, FileTooLargeError, \
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal
|
||||
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
|
||||
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
from datetime import datetime
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with, inputs
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, inputs
|
||||
from sqlalchemy import and_
|
||||
from werkzeug.exceptions import NotFound, Forbidden, BadRequest
|
||||
|
||||
@@ -12,7 +12,6 @@ from controllers.console.explore.wraps import InstalledAppResource
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
from extensions.ext_database import db
|
||||
from fields.installed_app_fields import installed_app_list_fields
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, InstalledApp, RecommendedApp
|
||||
from services.account_service import TenantService
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with
|
||||
from sqlalchemy import and_
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource
|
||||
from functools import wraps
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from functools import wraps
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource
|
||||
from controllers.console.setup import setup_required
|
||||
from controllers.console.wraps import account_initialization_required
|
||||
|
||||
@@ -4,7 +4,7 @@ from datetime import datetime
|
||||
import pytz
|
||||
from flask import current_app, request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, fields, marshal_with
|
||||
|
||||
from services.errors.account import CurrentPasswordIncorrectError as ServiceCurrentPasswordIncorrectError
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask import current_app
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse, marshal_with, abort, fields, marshal
|
||||
|
||||
import services
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
|
||||
from controllers.console import api
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
# -*- coding:utf-8 -*-
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, abort, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
|
||||
@@ -3,9 +3,8 @@ import logging
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from core.login.login import login_required
|
||||
from libs.login import login_required
|
||||
from flask_restful import Resource, fields, marshal_with, reqparse, marshal, inputs
|
||||
from flask_restful.inputs import int_range
|
||||
|
||||
from controllers.console import api
|
||||
from controllers.console.admin import admin_required
|
||||
|
||||
@@ -4,12 +4,9 @@ import services.dataset_service
|
||||
from controllers.service_api import api
|
||||
from controllers.service_api.dataset.error import DatasetNameDuplicateError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.login.login import current_user
|
||||
from libs.login import current_user
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from extensions.ext_database import db
|
||||
from fields.dataset_fields import dataset_detail_fields
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.dataset import Dataset
|
||||
from services.dataset_service import DatasetService
|
||||
from services.provider_service import ProviderService
|
||||
|
||||
|
||||
@@ -1,8 +1,6 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
|
||||
from flask import current_app, request
|
||||
from flask import request
|
||||
from flask_restful import reqparse, marshal
|
||||
from sqlalchemy import desc
|
||||
from werkzeug.exceptions import NotFound
|
||||
@@ -13,13 +11,11 @@ from controllers.service_api.app.error import ProviderNotInitializeError
|
||||
from controllers.service_api.dataset.error import ArchivedDocumentImmutableError, DocumentIndexingError, \
|
||||
NoFileUploadedError, TooManyFilesError
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.login.login import current_user
|
||||
from libs.login import current_user
|
||||
from core.model_providers.error import ProviderTokenNotInitError
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from fields.document_fields import document_fields, document_status_fields
|
||||
from models.dataset import Dataset, Document, DocumentSegment
|
||||
from models.model import UploadFile
|
||||
from services.dataset_service import DocumentService
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@@ -7,10 +7,9 @@ from flask_login import user_logged_in
|
||||
from flask_restful import Resource
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from core.login.login import _get_user
|
||||
from libs.login import _get_user
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant, TenantAccountJoin, Account
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
|
||||
|
||||
|
||||
@@ -94,7 +94,7 @@ class ConversationMessageTask:
|
||||
if not self.conversation:
|
||||
self.is_new_conversation = True
|
||||
self.conversation = Conversation(
|
||||
app_id=self.app_model_config.app_id,
|
||||
app_id=self.app.id,
|
||||
app_model_config_id=self.app_model_config.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
@@ -115,7 +115,7 @@ class ConversationMessageTask:
|
||||
db.session.commit()
|
||||
|
||||
self.message = Message(
|
||||
app_id=self.app_model_config.app_id,
|
||||
app_id=self.app.id,
|
||||
model_provider=self.provider_name,
|
||||
model_id=self.model_name,
|
||||
override_model_configs=json.dumps(override_model_configs) if override_model_configs else None,
|
||||
|
||||
@@ -51,6 +51,9 @@ class ModelProviderFactory:
|
||||
elif provider_name == 'chatglm':
|
||||
from core.model_providers.providers.chatglm_provider import ChatGLMProvider
|
||||
return ChatGLMProvider
|
||||
elif provider_name == 'baichuan':
|
||||
from core.model_providers.providers.baichuan_provider import BaichuanProvider
|
||||
return BaichuanProvider
|
||||
elif provider_name == 'azure_openai':
|
||||
from core.model_providers.providers.azure_openai_provider import AzureOpenAIProvider
|
||||
return AzureOpenAIProvider
|
||||
|
||||
@@ -0,0 +1,22 @@
|
||||
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.model_providers.models.embedding.base import BaseEmbedding
|
||||
|
||||
|
||||
class OpenLLMEmbedding(BaseEmbedding):
|
||||
def __init__(self, model_provider: BaseModelProvider, name: str):
|
||||
credentials = model_provider.get_model_credentials(
|
||||
model_name=name,
|
||||
model_type=self.type
|
||||
)
|
||||
|
||||
client = OpenLLMEmbeddings(
|
||||
server_url=credentials['server_url']
|
||||
)
|
||||
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"OpenLLM embedding: {str(ex)}")
|
||||
@@ -1,5 +1,4 @@
|
||||
from core.third_party.langchain.embeddings.xinference_embedding import XinferenceEmbedding as XinferenceEmbeddings
|
||||
from replicate.exceptions import ModelError, ReplicateError
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
@@ -21,7 +20,4 @@ class XinferenceEmbedding(BaseEmbedding):
|
||||
super().__init__(model_provider, client, name)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
if isinstance(ex, (ModelError, ReplicateError)):
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
else:
|
||||
return ex
|
||||
return LLMBadRequestError(f"Xinference embedding: {str(ex)}")
|
||||
|
||||
67
api/core/model_providers/models/llm/baichuan_model.py
Normal file
67
api/core/model_providers/models/llm/baichuan_model.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from typing import List, Optional, Any
|
||||
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
from core.model_providers.error import LLMBadRequestError
|
||||
from core.model_providers.models.llm.base import BaseLLM
|
||||
from core.model_providers.models.entity.message import PromptMessage
|
||||
from core.model_providers.models.entity.model_params import ModelMode, ModelKwargs
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
|
||||
|
||||
class BaichuanModel(BaseLLM):
|
||||
model_mode: ModelMode = ModelMode.CHAT
|
||||
|
||||
def _init_client(self) -> Any:
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, self.model_kwargs)
|
||||
return BaichuanChatLLM(
|
||||
streaming=self.streaming,
|
||||
callbacks=self.callbacks,
|
||||
**self.credentials,
|
||||
**provider_model_kwargs
|
||||
)
|
||||
|
||||
def _run(self, messages: List[PromptMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
callbacks: Callbacks = None,
|
||||
**kwargs) -> LLMResult:
|
||||
"""
|
||||
run predict by prompt messages and stop words.
|
||||
|
||||
:param messages:
|
||||
:param stop:
|
||||
:param callbacks:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return self._client.generate([prompts], stop, callbacks)
|
||||
|
||||
def prompt_file_name(self, mode: str) -> str:
|
||||
if mode == 'completion':
|
||||
return 'baichuan_completion'
|
||||
else:
|
||||
return 'baichuan_chat'
|
||||
|
||||
def get_num_tokens(self, messages: List[PromptMessage]) -> int:
|
||||
"""
|
||||
get num tokens of prompt messages.
|
||||
|
||||
:param messages:
|
||||
:return:
|
||||
"""
|
||||
prompts = self._get_prompt_from_messages(messages)
|
||||
return max(self._client.get_num_tokens_from_messages(prompts), 0)
|
||||
|
||||
def _set_model_kwargs(self, model_kwargs: ModelKwargs):
|
||||
provider_model_kwargs = self._to_model_kwargs_input(self.model_rules, model_kwargs)
|
||||
for k, v in provider_model_kwargs.items():
|
||||
if hasattr(self.client, k):
|
||||
setattr(self.client, k, v)
|
||||
|
||||
def handle_exceptions(self, ex: Exception) -> Exception:
|
||||
return LLMBadRequestError(f"Baichuan: {str(ex)}")
|
||||
|
||||
@property
|
||||
def support_streaming(self):
|
||||
return True
|
||||
@@ -132,8 +132,6 @@ class BaseLLM(BaseProviderModel):
|
||||
if self.deduct_quota:
|
||||
self.model_provider.check_quota_over_limit()
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if not callbacks:
|
||||
callbacks = self.callbacks
|
||||
else:
|
||||
|
||||
@@ -5,6 +5,7 @@ from typing import List, Optional, Any
|
||||
import openai
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
from langchain.schema import LLMResult
|
||||
from openai import api_requestor
|
||||
|
||||
from core.model_providers.providers.base import BaseModelProvider
|
||||
from core.third_party.langchain.llms.chat_open_ai import EnhanceChatOpenAI
|
||||
|
||||
167
api/core/model_providers/providers/baichuan_provider.py
Normal file
167
api/core/model_providers/providers/baichuan_provider.py
Normal file
@@ -0,0 +1,167 @@
|
||||
import json
|
||||
from json import JSONDecodeError
|
||||
from typing import Type
|
||||
|
||||
from langchain.schema import HumanMessage
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.model_providers.models.entity.model_params import ModelKwargsRules, KwargRule, ModelType
|
||||
from core.model_providers.models.llm.baichuan_model import BaichuanModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
|
||||
class BaichuanProvider(BaseModelProvider):
|
||||
|
||||
@property
|
||||
def provider_name(self):
|
||||
"""
|
||||
Returns the name of a provider.
|
||||
"""
|
||||
return 'baichuan'
|
||||
|
||||
def _get_fixed_model_list(self, model_type: ModelType) -> list[dict]:
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
return [
|
||||
{
|
||||
'id': 'baichuan2-53b',
|
||||
'name': 'Baichuan2-53B',
|
||||
}
|
||||
]
|
||||
else:
|
||||
return []
|
||||
|
||||
def get_model_class(self, model_type: ModelType) -> Type[BaseProviderModel]:
|
||||
"""
|
||||
Returns the model class.
|
||||
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = BaichuanModel
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
return model_class
|
||||
|
||||
def get_model_parameter_rules(self, model_name: str, model_type: ModelType) -> ModelKwargsRules:
|
||||
"""
|
||||
get model parameter rules.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:return:
|
||||
"""
|
||||
return ModelKwargsRules(
|
||||
temperature=KwargRule[float](min=0, max=1, default=0.3, precision=2),
|
||||
top_p=KwargRule[float](min=0, max=0.99, default=0.85, precision=2),
|
||||
presence_penalty=KwargRule[float](enabled=False),
|
||||
frequency_penalty=KwargRule[float](enabled=False),
|
||||
max_tokens=KwargRule[int](enabled=False),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def is_provider_credentials_valid_or_raise(cls, credentials: dict):
|
||||
"""
|
||||
Validates the given credentials.
|
||||
"""
|
||||
if 'api_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Baichuan api_key must be provided.')
|
||||
|
||||
if 'secret_key' not in credentials:
|
||||
raise CredentialsValidateFailedError('Baichuan secret_key must be provided.')
|
||||
|
||||
try:
|
||||
credential_kwargs = {
|
||||
'api_key': credentials['api_key'],
|
||||
'secret_key': credentials['secret_key'],
|
||||
}
|
||||
|
||||
llm = BaichuanChatLLM(
|
||||
temperature=0,
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm([HumanMessage(content='ping')])
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
@classmethod
|
||||
def encrypt_provider_credentials(cls, tenant_id: str, credentials: dict) -> dict:
|
||||
credentials['api_key'] = encrypter.encrypt_token(tenant_id, credentials['api_key'])
|
||||
credentials['secret_key'] = encrypter.encrypt_token(tenant_id, credentials['secret_key'])
|
||||
return credentials
|
||||
|
||||
def get_provider_credentials(self, obfuscated: bool = False) -> dict:
|
||||
if self.provider.provider_type == ProviderType.CUSTOM.value:
|
||||
try:
|
||||
credentials = json.loads(self.provider.encrypted_config)
|
||||
except JSONDecodeError:
|
||||
credentials = {
|
||||
'api_key': None,
|
||||
'secret_key': None,
|
||||
}
|
||||
|
||||
if credentials['api_key']:
|
||||
credentials['api_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['api_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['api_key'] = encrypter.obfuscated_token(credentials['api_key'])
|
||||
|
||||
if credentials['secret_key']:
|
||||
credentials['secret_key'] = encrypter.decrypt_token(
|
||||
self.provider.tenant_id,
|
||||
credentials['secret_key']
|
||||
)
|
||||
|
||||
if obfuscated:
|
||||
credentials['secret_key'] = encrypter.obfuscated_token(credentials['secret_key'])
|
||||
|
||||
return credentials
|
||||
else:
|
||||
return {}
|
||||
|
||||
def should_deduct_quota(self):
|
||||
return True
|
||||
|
||||
@classmethod
|
||||
def is_model_credentials_valid_or_raise(cls, model_name: str, model_type: ModelType, credentials: dict):
|
||||
"""
|
||||
check model credentials valid.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
"""
|
||||
return
|
||||
|
||||
@classmethod
|
||||
def encrypt_model_credentials(cls, tenant_id: str, model_name: str, model_type: ModelType,
|
||||
credentials: dict) -> dict:
|
||||
"""
|
||||
encrypt model credentials for save.
|
||||
|
||||
:param tenant_id:
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param credentials:
|
||||
:return:
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_model_credentials(self, model_name: str, model_type: ModelType, obfuscated: bool = False) -> dict:
|
||||
"""
|
||||
get credentials for llm use.
|
||||
|
||||
:param model_name:
|
||||
:param model_type:
|
||||
:param obfuscated:
|
||||
:return:
|
||||
"""
|
||||
return self.get_provider_credentials(obfuscated)
|
||||
@@ -2,11 +2,13 @@ import json
|
||||
from typing import Type
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
|
||||
from core.model_providers.models.entity.model_params import KwargRule, ModelKwargsRules, ModelType
|
||||
from core.model_providers.models.llm.openllm_model import OpenLLMModel
|
||||
from core.model_providers.providers.base import BaseModelProvider, CredentialsValidateFailedError
|
||||
|
||||
from core.model_providers.models.base import BaseProviderModel
|
||||
from core.third_party.langchain.embeddings.openllm_embedding import OpenLLMEmbeddings
|
||||
from core.third_party.langchain.llms.openllm import OpenLLM
|
||||
from models.provider import ProviderType
|
||||
|
||||
@@ -31,6 +33,8 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
"""
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
model_class = OpenLLMModel
|
||||
elif model_type== ModelType.EMBEDDINGS:
|
||||
model_class = OpenLLMEmbedding
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -69,14 +73,21 @@ class OpenLLMProvider(BaseModelProvider):
|
||||
'server_url': credentials['server_url']
|
||||
}
|
||||
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
if model_type == ModelType.TEXT_GENERATION:
|
||||
llm = OpenLLM(
|
||||
llm_kwargs={
|
||||
'max_new_tokens': 10
|
||||
},
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
llm("ping")
|
||||
llm("ping")
|
||||
elif model_type == ModelType.EMBEDDINGS:
|
||||
embedding = OpenLLMEmbeddings(
|
||||
**credential_kwargs
|
||||
)
|
||||
|
||||
embedding.embed_query("ping")
|
||||
except Exception as ex:
|
||||
raise CredentialsValidateFailedError(str(ex))
|
||||
|
||||
|
||||
@@ -7,10 +7,11 @@
|
||||
"spark",
|
||||
"wenxin",
|
||||
"zhipuai",
|
||||
"baichuan",
|
||||
"chatglm",
|
||||
"replicate",
|
||||
"huggingface_hub",
|
||||
"xinference",
|
||||
"openllm",
|
||||
"localai"
|
||||
]
|
||||
]
|
||||
|
||||
15
api/core/model_providers/rules/baichuan.json
Normal file
15
api/core/model_providers/rules/baichuan.json
Normal file
@@ -0,0 +1,15 @@
|
||||
{
|
||||
"support_provider_types": [
|
||||
"custom"
|
||||
],
|
||||
"system_config": null,
|
||||
"model_flexibility": "fixed",
|
||||
"price_config": {
|
||||
"baichuan2-53b": {
|
||||
"prompt": "0.01",
|
||||
"completion": "0.01",
|
||||
"unit": "0.001",
|
||||
"currency": "RMB"
|
||||
}
|
||||
}
|
||||
}
|
||||
67
api/core/third_party/langchain/embeddings/openllm_embedding.py
vendored
Normal file
67
api/core/third_party/langchain/embeddings/openllm_embedding.py
vendored
Normal file
@@ -0,0 +1,67 @@
|
||||
"""Wrapper around OpenLLM embedding models."""
|
||||
from typing import Any, List, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
|
||||
|
||||
class OpenLLMEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around OpenLLM embedding models.
|
||||
"""
|
||||
|
||||
client: Any #: :meta private:
|
||||
|
||||
server_url: Optional[str] = None
|
||||
"""Optional server URL that currently runs a LLMServer with 'openllm start'."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Call out to OpenLLM's embedding endpoint.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
embeddings = []
|
||||
for text in texts:
|
||||
result = self.invoke_embedding(text=text)
|
||||
embeddings.append(result)
|
||||
|
||||
return [list(map(float, e)) for e in embeddings]
|
||||
|
||||
def invoke_embedding(self, text):
|
||||
params = [
|
||||
text
|
||||
]
|
||||
|
||||
headers = {"Content-Type": "application/json"}
|
||||
response = requests.post(
|
||||
f'{self.server_url}/v1/embeddings',
|
||||
headers=headers,
|
||||
json=params
|
||||
)
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"OpenLLM HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
json_response = response.json()
|
||||
return json_response[0]["embeddings"][0]
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Call out to OpenLLM's embedding endpoint.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
return self.embed_documents([text])[0]
|
||||
315
api/core/third_party/langchain/llms/baichuan_llm.py
vendored
Normal file
315
api/core/third_party/langchain/llms/baichuan_llm.py
vendored
Normal file
@@ -0,0 +1,315 @@
|
||||
"""Wrapper around Baichuan APIs."""
|
||||
from __future__ import annotations
|
||||
|
||||
import hashlib
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional, Iterator,
|
||||
)
|
||||
|
||||
import requests
|
||||
from langchain.chat_models.base import BaseChatModel
|
||||
from langchain.schema import BaseMessage, ChatMessage, HumanMessage, AIMessage, SystemMessage
|
||||
from langchain.schema.messages import AIMessageChunk
|
||||
from langchain.schema.output import ChatResult, ChatGenerationChunk, ChatGeneration
|
||||
from pydantic import Extra, root_validator, BaseModel
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
CallbackManagerForLLMRun,
|
||||
)
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BaichuanModelAPI(BaseModel):
|
||||
api_key: str
|
||||
secret_key: str
|
||||
|
||||
base_url: str = "https://api.baichuan-ai.com/v1"
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
def do_request(self, model: str, messages: list[dict], parameters: dict, **kwargs: Any):
|
||||
stream = 'stream' in kwargs and kwargs['stream']
|
||||
|
||||
url = self.base_url + ("/stream/chat" if stream else "/chat")
|
||||
|
||||
data = {
|
||||
"model": model,
|
||||
"messages": messages,
|
||||
"parameters": parameters
|
||||
}
|
||||
|
||||
json_data = json.dumps(data)
|
||||
time_stamp = int(time.time())
|
||||
signature = self._calculate_md5(self.secret_key + json_data + str(time_stamp))
|
||||
|
||||
headers = {
|
||||
"Content-Type": "application/json",
|
||||
"Authorization": "Bearer " + self.api_key,
|
||||
"X-BC-Request-Id": "your requestId",
|
||||
"X-BC-Timestamp": str(time_stamp),
|
||||
"X-BC-Signature": signature,
|
||||
"X-BC-Sign-Algo": "MD5",
|
||||
}
|
||||
|
||||
response = requests.post(url, data=json_data, headers=headers, stream=stream, timeout=(5, 60))
|
||||
|
||||
if not response.ok:
|
||||
raise ValueError(f"HTTP {response.status_code} error: {response.text}")
|
||||
|
||||
if not stream:
|
||||
json_response = response.json()
|
||||
if json_response['code'] != 0:
|
||||
raise ValueError(
|
||||
f"API {json_response['code']}"
|
||||
f" error: {json_response['msg']}"
|
||||
)
|
||||
return json_response
|
||||
else:
|
||||
return response
|
||||
|
||||
def _calculate_md5(self, input_string):
|
||||
md5 = hashlib.md5()
|
||||
md5.update(input_string.encode('utf-8'))
|
||||
encrypted = md5.hexdigest()
|
||||
return encrypted
|
||||
|
||||
|
||||
class BaichuanChatLLM(BaseChatModel):
|
||||
"""Wrapper around Baichuan large language models.
|
||||
To use, you should pass the api_key as a named parameter to the constructor.
|
||||
Example:
|
||||
.. code-block:: python
|
||||
from core.third_party.langchain.llms.baichuan_llm import BaichuanChatLLM
|
||||
model = BaichuanChatLLM(model="<model_name>", api_key="my-api-key", secret_key="my-secret-key")
|
||||
"""
|
||||
|
||||
@property
|
||||
def lc_secrets(self) -> Dict[str, str]:
|
||||
return {"api_key": "API_KEY", "secret_key": "SECRET_KEY"}
|
||||
|
||||
@property
|
||||
def lc_serializable(self) -> bool:
|
||||
return True
|
||||
|
||||
client: Any = None #: :meta private:
|
||||
model: str = "Baichuan2-53B"
|
||||
"""Model name to use."""
|
||||
temperature: float = 0.3
|
||||
"""A non-negative float that tunes the degree of randomness in generation."""
|
||||
top_p: float = 0.85
|
||||
"""Total probability mass of tokens to consider at each step."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the response or return it all at once."""
|
||||
api_key: Optional[str] = None
|
||||
secret_key: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
values["api_key"] = get_from_dict_or_env(
|
||||
values, "api_key", "BAICHUAN_API_KEY"
|
||||
)
|
||||
|
||||
values["secret_key"] = get_from_dict_or_env(
|
||||
values, "secret_key", "BAICHUAN_SECRET_KEY"
|
||||
)
|
||||
|
||||
values['client'] = BaichuanModelAPI(
|
||||
api_key=values['api_key'],
|
||||
secret_key=values['secret_key']
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _default_params(self) -> Dict[str, Any]:
|
||||
"""Get the default parameters for calling OpenAI API."""
|
||||
return {
|
||||
"model": self.model,
|
||||
"parameters": {
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p
|
||||
}
|
||||
}
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Dict[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return self._default_params
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "baichuan"
|
||||
|
||||
def _convert_message_to_dict(self, message: BaseMessage) -> dict:
|
||||
if isinstance(message, ChatMessage):
|
||||
message_dict = {"role": message.role, "content": message.content}
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
def _convert_dict_to_message(self, _dict: Dict[str, Any]) -> BaseMessage:
|
||||
role = _dict["role"]
|
||||
if role == "user":
|
||||
return HumanMessage(content=_dict["content"])
|
||||
elif role == "assistant":
|
||||
return AIMessage(content=_dict["content"])
|
||||
elif role == "system":
|
||||
return SystemMessage(content=_dict["content"])
|
||||
else:
|
||||
return ChatMessage(content=_dict["content"], role=role)
|
||||
|
||||
def _create_message_dicts(
|
||||
self, messages: List[BaseMessage]
|
||||
) -> List[Dict[str, Any]]:
|
||||
dict_messages = []
|
||||
for m in messages:
|
||||
message = self._convert_message_to_dict(m)
|
||||
if dict_messages:
|
||||
previous_message = dict_messages[-1]
|
||||
if previous_message['role'] == message['role']:
|
||||
dict_messages[-1]['content'] += f"\n{message['content']}"
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
else:
|
||||
dict_messages.append(message)
|
||||
|
||||
return dict_messages
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
if self.streaming:
|
||||
generation: Optional[ChatGenerationChunk] = None
|
||||
llm_output: Optional[Dict] = None
|
||||
for chunk in self._stream(
|
||||
messages=messages, stop=stop, run_manager=run_manager, **kwargs
|
||||
):
|
||||
if generation is None:
|
||||
generation = chunk
|
||||
else:
|
||||
generation += chunk
|
||||
|
||||
if chunk.generation_info is not None \
|
||||
and 'token_usage' in chunk.generation_info:
|
||||
llm_output = {"token_usage": chunk.generation_info['token_usage'], "model_name": self.model}
|
||||
|
||||
assert generation is not None
|
||||
return ChatResult(generations=[generation], llm_output=llm_output)
|
||||
else:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params.update(kwargs)
|
||||
response = self.client.do_request(**params)
|
||||
return self._create_chat_result(response)
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
message_dicts = self._create_message_dicts(messages)
|
||||
params = self._default_params
|
||||
params["messages"] = message_dicts
|
||||
params.update(kwargs)
|
||||
|
||||
for event in self.client.do_request(stream=True, **params).iter_lines():
|
||||
if event:
|
||||
event = event.decode("utf-8")
|
||||
|
||||
meta = json.loads(event)
|
||||
|
||||
if meta['code'] != 0:
|
||||
raise ValueError(
|
||||
f"API {meta['code']}"
|
||||
f" error: {meta['msg']}"
|
||||
)
|
||||
|
||||
content = meta['data']['messages'][0]['content']
|
||||
|
||||
chunk_kwargs = {
|
||||
'message': AIMessageChunk(content=content),
|
||||
}
|
||||
|
||||
if 'usage' in meta:
|
||||
token_usage = meta['usage']
|
||||
overall_token_usage = {
|
||||
'prompt_tokens': token_usage.get('prompt_tokens', 0),
|
||||
'completion_tokens': token_usage.get('answer_tokens', 0),
|
||||
'total_tokens': token_usage.get('total_tokens', 0)
|
||||
}
|
||||
chunk_kwargs['generation_info'] = {'token_usage': overall_token_usage}
|
||||
|
||||
yield ChatGenerationChunk(**chunk_kwargs)
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content)
|
||||
|
||||
def _create_chat_result(self, response: Dict[str, Any]) -> ChatResult:
|
||||
data = response["data"]
|
||||
generations = []
|
||||
for res in data["messages"]:
|
||||
message = self._convert_dict_to_message(res)
|
||||
gen = ChatGeneration(
|
||||
message=message
|
||||
)
|
||||
generations.append(gen)
|
||||
usage = response.get("usage")
|
||||
token_usage = {
|
||||
'prompt_tokens': usage.get('prompt_tokens', 0),
|
||||
'completion_tokens': usage.get('answer_tokens', 0),
|
||||
'total_tokens': usage.get('total_tokens', 0)
|
||||
}
|
||||
llm_output = {"token_usage": token_usage, "model_name": self.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
def get_num_tokens_from_messages(self, messages: List[BaseMessage]) -> int:
|
||||
"""Get the number of tokens in the messages.
|
||||
|
||||
Useful for checking if an input will fit in a model's context window.
|
||||
|
||||
Args:
|
||||
messages: The message inputs to tokenize.
|
||||
|
||||
Returns:
|
||||
The sum of the number of tokens across the messages.
|
||||
"""
|
||||
return sum([self.get_num_tokens(m.content) for m in messages])
|
||||
|
||||
def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
|
||||
token_usage: dict = {}
|
||||
for output in llm_outputs:
|
||||
if output is None:
|
||||
# Happens in streaming
|
||||
continue
|
||||
token_usage = output["token_usage"]
|
||||
|
||||
return {"token_usage": token_usage, "model_name": self.model}
|
||||
@@ -49,6 +49,7 @@ huggingface_hub~=0.16.4
|
||||
transformers~=4.31.0
|
||||
stripe~=5.5.0
|
||||
pandas==1.5.3
|
||||
xinference==0.4.2
|
||||
xinference==0.5.2
|
||||
safetensors==0.3.2
|
||||
zhipuai==1.0.7
|
||||
zhipuai==1.0.7
|
||||
werkzeug==2.3.7
|
||||
@@ -3,7 +3,7 @@ import logging
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
from typing import Generator, Union, Any
|
||||
from typing import Generator, Union, Any, Optional
|
||||
|
||||
from flask import current_app, Flask
|
||||
from redis.client import PubSub
|
||||
@@ -141,12 +141,12 @@ class CompletionService:
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'app_model': app_model,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config,
|
||||
'query': query,
|
||||
'inputs': inputs,
|
||||
'user': user,
|
||||
'conversation': conversation,
|
||||
'detached_user': user,
|
||||
'detached_conversation': conversation,
|
||||
'streaming': streaming,
|
||||
'is_model_config_override': is_model_config_override,
|
||||
'retriever_from': args['retriever_from'] if 'retriever_from' in args else 'dev'
|
||||
@@ -155,7 +155,7 @@ class CompletionService:
|
||||
generate_worker_thread.start()
|
||||
|
||||
# wait for 10 minutes to close the thread
|
||||
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@@ -171,18 +171,22 @@ class CompletionService:
|
||||
return user
|
||||
|
||||
@classmethod
|
||||
def generate_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, user: Union[Account, EndUser],
|
||||
conversation: Conversation, streaming: bool, is_model_config_override: bool,
|
||||
def generate_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App, app_model_config: AppModelConfig,
|
||||
query: str, inputs: dict, detached_user: Union[Account, EndUser],
|
||||
detached_conversation: Optional[Conversation], streaming: bool, is_model_config_override: bool,
|
||||
retriever_from: str = 'dev'):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
app_model = db.session.merge(detached_app_model)
|
||||
|
||||
if detached_conversation:
|
||||
conversation = db.session.merge(detached_conversation)
|
||||
else:
|
||||
conversation = None
|
||||
|
||||
try:
|
||||
if conversation:
|
||||
# fixed the state of the conversation object when it detached from the original session
|
||||
conversation = db.session.query(Conversation).filter_by(id=conversation.id).first()
|
||||
|
||||
# run
|
||||
|
||||
Completion.generate(
|
||||
task_id=generate_task_id,
|
||||
app=app_model,
|
||||
@@ -200,36 +204,38 @@ class CompletionService:
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
db.session.rollback()
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
except LLMAuthorizationError:
|
||||
db.session.rollback()
|
||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.exception("Unknown Error in completion")
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def countdown_and_close(cls, worker_thread, pubsub, user, generate_task_id) -> threading.Thread:
|
||||
def countdown_and_close(cls, flask_app: Flask, worker_thread, pubsub, detached_user, generate_task_id) -> threading.Thread:
|
||||
# wait for 10 minutes to close the thread
|
||||
timeout = 600
|
||||
|
||||
def close_pubsub():
|
||||
sleep_iterations = 0
|
||||
while sleep_iterations < timeout and worker_thread.is_alive():
|
||||
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
||||
PubHandler.ping(user, generate_task_id)
|
||||
with flask_app.app_context():
|
||||
user = db.session.merge(detached_user)
|
||||
|
||||
time.sleep(1)
|
||||
sleep_iterations += 1
|
||||
sleep_iterations = 0
|
||||
while sleep_iterations < timeout and worker_thread.is_alive():
|
||||
if sleep_iterations > 0 and sleep_iterations % 10 == 0:
|
||||
PubHandler.ping(user, generate_task_id)
|
||||
|
||||
if worker_thread.is_alive():
|
||||
PubHandler.stop(user, generate_task_id)
|
||||
try:
|
||||
pubsub.close()
|
||||
except:
|
||||
pass
|
||||
time.sleep(1)
|
||||
sleep_iterations += 1
|
||||
|
||||
if worker_thread.is_alive():
|
||||
PubHandler.stop(user, generate_task_id)
|
||||
try:
|
||||
pubsub.close()
|
||||
except:
|
||||
pass
|
||||
|
||||
countdown_thread = threading.Thread(target=close_pubsub)
|
||||
countdown_thread.start()
|
||||
@@ -279,25 +285,30 @@ class CompletionService:
|
||||
generate_worker_thread = threading.Thread(target=cls.generate_more_like_this_worker, kwargs={
|
||||
'flask_app': current_app._get_current_object(),
|
||||
'generate_task_id': generate_task_id,
|
||||
'app_model': app_model,
|
||||
'detached_app_model': app_model,
|
||||
'app_model_config': app_model_config,
|
||||
'message': message,
|
||||
'detached_message': message,
|
||||
'pre_prompt': pre_prompt,
|
||||
'user': user,
|
||||
'detached_user': user,
|
||||
'streaming': streaming
|
||||
})
|
||||
|
||||
generate_worker_thread.start()
|
||||
|
||||
cls.countdown_and_close(generate_worker_thread, pubsub, user, generate_task_id)
|
||||
cls.countdown_and_close(current_app._get_current_object(), generate_worker_thread, pubsub, user, generate_task_id)
|
||||
|
||||
return cls.compact_response(pubsub, streaming)
|
||||
|
||||
@classmethod
|
||||
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, app_model: App,
|
||||
app_model_config: AppModelConfig, message: Message, pre_prompt: str,
|
||||
user: Union[Account, EndUser], streaming: bool):
|
||||
def generate_more_like_this_worker(cls, flask_app: Flask, generate_task_id: str, detached_app_model: App,
|
||||
app_model_config: AppModelConfig, detached_message: Message, pre_prompt: str,
|
||||
detached_user: Union[Account, EndUser], streaming: bool):
|
||||
with flask_app.app_context():
|
||||
# fixed the state of the model object when it detached from the original session
|
||||
user = db.session.merge(detached_user)
|
||||
app_model = db.session.merge(detached_app_model)
|
||||
message = db.session.merge(detached_message)
|
||||
|
||||
try:
|
||||
# run
|
||||
Completion.generate_more_like_this(
|
||||
@@ -314,15 +325,14 @@ class CompletionService:
|
||||
except (LLMBadRequestError, LLMAPIConnectionError, LLMAPIUnavailableError,
|
||||
LLMRateLimitError, ProviderTokenNotInitError, QuotaExceededError,
|
||||
ModelCurrentlyNotSupportError) as e:
|
||||
db.session.rollback()
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
except LLMAuthorizationError:
|
||||
db.session.rollback()
|
||||
PubHandler.pub_error(user, generate_task_id, LLMAuthorizationError('Incorrect API key provided'))
|
||||
except Exception as e:
|
||||
db.session.rollback()
|
||||
logging.exception("Unknown Error in completion")
|
||||
PubHandler.pub_error(user, generate_task_id, e)
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
@classmethod
|
||||
def get_cleaned_inputs(cls, user_inputs: dict, app_model_config: AppModelConfig):
|
||||
@@ -386,6 +396,8 @@ class CompletionService:
|
||||
logging.exception(e)
|
||||
raise
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
pubsub.unsubscribe(generate_channel)
|
||||
except ConnectionError:
|
||||
@@ -423,6 +435,8 @@ class CompletionService:
|
||||
logging.exception(e)
|
||||
raise
|
||||
finally:
|
||||
db.session.commit()
|
||||
|
||||
try:
|
||||
pubsub.unsubscribe(generate_channel)
|
||||
except ConnectionError:
|
||||
|
||||
@@ -35,6 +35,10 @@ WENXIN_SECRET_KEY=
|
||||
# ZhipuAI Credentials
|
||||
ZHIPUAI_API_KEY=
|
||||
|
||||
# Baichuan Credentials
|
||||
BAICHUAN_API_KEY=
|
||||
BAICHUAN_SECRET_KEY=
|
||||
|
||||
# ChatGLM Credentials
|
||||
CHATGLM_API_BASE=
|
||||
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from core.model_providers.models.embedding.openllm_embedding import OpenLLMEmbedding
|
||||
from core.model_providers.models.entity.model_params import ModelType
|
||||
from core.model_providers.providers.openllm_provider import OpenLLMProvider
|
||||
from models.provider import Provider, ProviderType, ProviderModel
|
||||
|
||||
|
||||
def get_mock_provider():
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='openllm',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config='',
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_embedding_model(mocker):
|
||||
model_name = 'facebook/opt-125m'
|
||||
server_url = os.environ['OPENLLM_SERVER_URL']
|
||||
model_provider = OpenLLMProvider(provider=get_mock_provider())
|
||||
|
||||
mock_query = MagicMock()
|
||||
mock_query.filter.return_value.first.return_value = ProviderModel(
|
||||
provider_name='openllm',
|
||||
model_name=model_name,
|
||||
model_type=ModelType.EMBEDDINGS.value,
|
||||
encrypted_config=json.dumps({
|
||||
'server_url': server_url
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
mocker.patch('extensions.ext_database.db.session.query', return_value=mock_query)
|
||||
|
||||
return OpenLLMEmbedding(
|
||||
model_provider=model_provider,
|
||||
name=model_name
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_documents(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_documents(['test', 'test1'])
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) == 2
|
||||
assert len(rst[0]) > 0
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_embed_query(mock_decrypt, mocker):
|
||||
embedding_model = get_mock_embedding_model(mocker)
|
||||
rst = embedding_model.client.embed_query('test')
|
||||
assert isinstance(rst, list)
|
||||
assert len(rst) > 0
|
||||
@@ -0,0 +1,81 @@
|
||||
import json
|
||||
import os
|
||||
from unittest.mock import patch
|
||||
|
||||
|
||||
from core.model_providers.models.entity.message import PromptMessage, MessageType
|
||||
from core.model_providers.models.entity.model_params import ModelKwargs
|
||||
from core.model_providers.models.llm.baichuan_model import BaichuanModel
|
||||
from core.model_providers.providers.baichuan_provider import BaichuanProvider
|
||||
from models.provider import Provider, ProviderType
|
||||
|
||||
|
||||
def get_mock_provider(valid_api_key, valid_secret_key):
|
||||
return Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name='baichuan',
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps({
|
||||
'api_key': valid_api_key,
|
||||
'secret_key': valid_secret_key,
|
||||
}),
|
||||
is_valid=True,
|
||||
)
|
||||
|
||||
|
||||
def get_mock_model(model_name: str, streaming: bool = False):
|
||||
model_kwargs = ModelKwargs(
|
||||
temperature=0.01,
|
||||
)
|
||||
valid_api_key = os.environ['BAICHUAN_API_KEY']
|
||||
valid_secret_key = os.environ['BAICHUAN_SECRET_KEY']
|
||||
model_provider = BaichuanProvider(provider=get_mock_provider(valid_api_key, valid_secret_key))
|
||||
return BaichuanModel(
|
||||
model_provider=model_provider,
|
||||
name=model_name,
|
||||
model_kwargs=model_kwargs,
|
||||
streaming=streaming
|
||||
)
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_api_key):
|
||||
return encrypted_api_key
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_chat_get_num_tokens(mock_decrypt):
|
||||
model = get_mock_model('baichuan2-53b')
|
||||
rst = model.get_num_tokens([
|
||||
PromptMessage(type=MessageType.SYSTEM, content='you are a kindness Assistant.'),
|
||||
PromptMessage(type=MessageType.HUMAN, content='Who is your manufacturer?')
|
||||
])
|
||||
assert rst > 0
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_chat_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('baichuan2-53b')
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages,
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_chat_stream_run(mock_decrypt, mocker):
|
||||
mocker.patch('core.model_providers.providers.base.BaseModelProvider.update_last_used', return_value=None)
|
||||
|
||||
model = get_mock_model('baichuan2-53b', streaming=True)
|
||||
messages = [
|
||||
PromptMessage(type=MessageType.HUMAN, content='Are you Human? you MUST only answer `y` or `n`?')
|
||||
]
|
||||
rst = model.run(
|
||||
messages
|
||||
)
|
||||
assert len(rst.content) > 0
|
||||
@@ -0,0 +1,97 @@
|
||||
import pytest
|
||||
from unittest.mock import patch
|
||||
import json
|
||||
|
||||
from langchain.schema import ChatResult, ChatGeneration, AIMessage
|
||||
|
||||
from core.model_providers.providers.baichuan_provider import BaichuanProvider
|
||||
from core.model_providers.providers.base import CredentialsValidateFailedError
|
||||
from models.provider import ProviderType, Provider
|
||||
|
||||
|
||||
PROVIDER_NAME = 'baichuan'
|
||||
MODEL_PROVIDER_CLASS = BaichuanProvider
|
||||
VALIDATE_CREDENTIAL = {
|
||||
'api_key': 'valid_key',
|
||||
'secret_key': 'valid_key',
|
||||
}
|
||||
|
||||
|
||||
def encrypt_side_effect(tenant_id, encrypt_key):
|
||||
return f'encrypted_{encrypt_key}'
|
||||
|
||||
|
||||
def decrypt_side_effect(tenant_id, encrypted_key):
|
||||
return encrypted_key.replace('encrypted_', '')
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_valid(mocker):
|
||||
mocker.patch('core.third_party.langchain.llms.baichuan_llm.BaichuanChatLLM._generate',
|
||||
return_value=ChatResult(generations=[ChatGeneration(message=AIMessage(content='abc'))]))
|
||||
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(VALIDATE_CREDENTIAL)
|
||||
|
||||
|
||||
def test_is_provider_credentials_valid_or_raise_invalid():
|
||||
# raise CredentialsValidateFailedError if api_key is not in credentials
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise({})
|
||||
|
||||
credential = VALIDATE_CREDENTIAL.copy()
|
||||
credential['api_key'] = 'invalid_key'
|
||||
credential['secret_key'] = 'invalid_key'
|
||||
|
||||
# raise CredentialsValidateFailedError if api_key is invalid
|
||||
with pytest.raises(CredentialsValidateFailedError):
|
||||
MODEL_PROVIDER_CLASS.is_provider_credentials_valid_or_raise(credential)
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.encrypt_token', side_effect=encrypt_side_effect)
|
||||
def test_encrypt_credentials(mock_encrypt):
|
||||
result = MODEL_PROVIDER_CLASS.encrypt_provider_credentials('tenant_id', VALIDATE_CREDENTIAL.copy())
|
||||
assert result['api_key'] == f'encrypted_{VALIDATE_CREDENTIAL["api_key"]}'
|
||||
assert result['secret_key'] == f'encrypted_{VALIDATE_CREDENTIAL["secret_key"]}'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_custom(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials()
|
||||
assert result['api_key'] == 'valid_key'
|
||||
assert result['secret_key'] == 'valid_key'
|
||||
|
||||
|
||||
@patch('core.helper.encrypter.decrypt_token', side_effect=decrypt_side_effect)
|
||||
def test_get_credentials_obfuscated(mock_decrypt):
|
||||
encrypted_credential = VALIDATE_CREDENTIAL.copy()
|
||||
encrypted_credential['api_key'] = 'encrypted_' + encrypted_credential['api_key']
|
||||
encrypted_credential['secret_key'] = 'encrypted_' + encrypted_credential['secret_key']
|
||||
|
||||
provider = Provider(
|
||||
id='provider_id',
|
||||
tenant_id='tenant_id',
|
||||
provider_name=PROVIDER_NAME,
|
||||
provider_type=ProviderType.CUSTOM.value,
|
||||
encrypted_config=json.dumps(encrypted_credential),
|
||||
is_valid=True,
|
||||
)
|
||||
model_provider = MODEL_PROVIDER_CLASS(provider=provider)
|
||||
result = model_provider.get_provider_credentials(obfuscated=True)
|
||||
middle_token = result['api_key'][6:-2]
|
||||
secret_key_middle_token = result['secret_key'][6:-2]
|
||||
assert len(middle_token) == max(len(VALIDATE_CREDENTIAL['api_key']) - 8, 0)
|
||||
assert len(secret_key_middle_token) == max(len(VALIDATE_CREDENTIAL['secret_key']) - 8, 0)
|
||||
assert all(char == '*' for char in middle_token)
|
||||
assert all(char == '*' for char in secret_key_middle_token)
|
||||
@@ -49,4 +49,18 @@ services:
|
||||
AUTHORIZATION_ADMINLIST_ENABLED: 'true'
|
||||
AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai'
|
||||
ports:
|
||||
- "8080:8080"
|
||||
- "8080:8080"
|
||||
|
||||
# Qdrant vector store.
|
||||
# uncomment to use qdrant as vector store.
|
||||
# (if uncommented, you need to comment out the weaviate service above,
|
||||
# and set VECTOR_STORE to qdrant in the api & worker service.)
|
||||
# qdrant:
|
||||
# image: qdrant/qdrant:latest
|
||||
# restart: always
|
||||
# volumes:
|
||||
# - ./volumes/qdrant:/qdrant/storage
|
||||
# environment:
|
||||
# QDRANT__API_KEY: 'difyai123456'
|
||||
# ports:
|
||||
# - "6333:6333"
|
||||
@@ -2,7 +2,7 @@ version: '3.1'
|
||||
services:
|
||||
# API service
|
||||
api:
|
||||
image: langgenius/dify-api:0.3.24
|
||||
image: langgenius/dify-api:0.3.25
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'api' starts the API server.
|
||||
@@ -85,9 +85,9 @@ services:
|
||||
# The Weaviate API key.
|
||||
WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
|
||||
QDRANT_URL: 'https://your-qdrant-cluster-url.qdrant.tech/'
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
# The Qdrant API key.
|
||||
QDRANT_API_KEY: 'ak-difyai'
|
||||
QDRANT_API_KEY: difyai123456
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
@@ -103,15 +103,17 @@ services:
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- weaviate
|
||||
volumes:
|
||||
# Mount the storage directory to the container, for storing user files.
|
||||
- ./volumes/app/storage:/app/api/storage
|
||||
# uncomment to expose dify-api port to host
|
||||
# ports:
|
||||
# - "5001:5001"
|
||||
|
||||
# worker service
|
||||
# The Celery worker for processing the queue.
|
||||
worker:
|
||||
image: langgenius/dify-api:0.3.24
|
||||
image: langgenius/dify-api:0.3.25
|
||||
restart: always
|
||||
environment:
|
||||
# Startup mode, 'worker' starts the Celery worker for processing the queue.
|
||||
@@ -143,10 +145,16 @@ services:
|
||||
# The type of storage to use for storing user files. Supported values are `local` and `s3`, Default: `local`
|
||||
STORAGE_TYPE: local
|
||||
STORAGE_LOCAL_PATH: storage
|
||||
# The Vector store configurations.
|
||||
# The type of vector store to use. Supported values are `weaviate`, `qdrant`.
|
||||
VECTOR_STORE: weaviate
|
||||
# The Weaviate endpoint URL. Only available when VECTOR_STORE is `weaviate`.
|
||||
WEAVIATE_ENDPOINT: http://weaviate:8080
|
||||
# The Weaviate API key.
|
||||
WEAVIATE_API_KEY: WVF5YThaHlkYwhGUSmCRgsX3tD5ngdN8pkih
|
||||
# The Qdrant endpoint URL. Only available when VECTOR_STORE is `qdrant`.
|
||||
QDRANT_URL: http://qdrant:6333
|
||||
# The Qdrant API key.
|
||||
QDRANT_API_KEY: difyai123456
|
||||
# Mail configuration, support: resend
|
||||
MAIL_TYPE: ''
|
||||
# default send from email address, if not specified
|
||||
@@ -156,14 +164,13 @@ services:
|
||||
depends_on:
|
||||
- db
|
||||
- redis
|
||||
- weaviate
|
||||
volumes:
|
||||
# Mount the storage directory to the container, for storing user files.
|
||||
- ./volumes/app/storage:/app/api/storage
|
||||
|
||||
# Frontend web application.
|
||||
web:
|
||||
image: langgenius/dify-web:0.3.24
|
||||
image: langgenius/dify-web:0.3.25
|
||||
restart: always
|
||||
environment:
|
||||
EDITION: SELF_HOSTED
|
||||
@@ -177,6 +184,9 @@ services:
|
||||
APP_API_URL: ''
|
||||
# The DSN for Sentry error reporting. If not set, Sentry error reporting will be disabled.
|
||||
SENTRY_DSN: ''
|
||||
# uncomment to expose dify-web port to host
|
||||
# ports:
|
||||
# - "3000:3000"
|
||||
|
||||
# The postgres database.
|
||||
db:
|
||||
@@ -211,6 +221,9 @@ services:
|
||||
command: redis-server --requirepass difyai123456
|
||||
healthcheck:
|
||||
test: ["CMD", "redis-cli","ping"]
|
||||
# uncomment to expose redis port to host
|
||||
# ports:
|
||||
# - "6379:6379"
|
||||
|
||||
# The Weaviate vector store.
|
||||
weaviate:
|
||||
@@ -232,6 +245,24 @@ services:
|
||||
AUTHENTICATION_APIKEY_USERS: 'hello@dify.ai'
|
||||
AUTHORIZATION_ADMINLIST_ENABLED: 'true'
|
||||
AUTHORIZATION_ADMINLIST_USERS: 'hello@dify.ai'
|
||||
# uncomment to expose weaviate port to host
|
||||
# ports:
|
||||
# - "8080:8080"
|
||||
|
||||
# Qdrant vector store.
|
||||
# uncomment to use qdrant as vector store.
|
||||
# (if uncommented, you need to comment out the weaviate service above,
|
||||
# and set VECTOR_STORE to qdrant in the api & worker service.)
|
||||
# qdrant:
|
||||
# image: qdrant/qdrant:latest
|
||||
# restart: always
|
||||
# volumes:
|
||||
# - ./volumes/qdrant:/qdrant/storage
|
||||
# environment:
|
||||
# QDRANT__API_KEY: 'difyai123456'
|
||||
## uncomment to expose qdrant port to host
|
||||
## ports:
|
||||
## - "6333:6333"
|
||||
|
||||
# The nginx reverse proxy.
|
||||
# used for reverse proxying the API service and Web service.
|
||||
|
||||
@@ -22,6 +22,7 @@ import Tag from '@/app/components/base/tag'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import CopyFeedback from '@/app/components/base/copy-feedback'
|
||||
import ShareQRCode from '@/app/components/base/qrcode'
|
||||
import SecretKeyButton from '@/app/components/develop/secret-key/secret-key-button'
|
||||
import type { AppDetailResponse } from '@/models/app'
|
||||
import { AppType } from '@/types/app'
|
||||
@@ -168,6 +169,7 @@ function AppCard({
|
||||
</div>
|
||||
</div>
|
||||
<Divider type="vertical" className="!h-3.5 shrink-0 !mx-0.5" />
|
||||
{isApp && <ShareQRCode content={isApp ? appUrl : apiUrl} selectorId={randomString(8)} className={'hover:bg-gray-200'} />}
|
||||
<CopyFeedback
|
||||
content={isApp ? appUrl : apiUrl}
|
||||
selectorId={randomString(8)}
|
||||
|
||||
Binary file not shown.
|
After Width: | Height: | Size: 4.5 KiB |
@@ -0,0 +1,19 @@
|
||||
<svg width="130" height="24" viewBox="0 0 130 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path fill-rule="evenodd" clip-rule="evenodd" d="M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z" fill="url(#paint0_radial_11622_96091)"/>
|
||||
<path d="M129.722 6.83203V18H127.482V6.83203H129.722Z" fill="#FF6A34"/>
|
||||
<path d="M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z" fill="#FF6A34"/>
|
||||
<path d="M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z" fill="#FF6A34"/>
|
||||
<path d="M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z" fill="#FF6A34"/>
|
||||
<path d="M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z" fill="#FF6A34"/>
|
||||
<path d="M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z" fill="#FF6A34"/>
|
||||
<path d="M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z" fill="#FF6A34"/>
|
||||
<path d="M54.4373 6.83203V18H52.1973V6.83203H54.4373Z" fill="#FF6A34"/>
|
||||
<path d="M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z" fill="#FF6A34"/>
|
||||
<path d="M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z" fill="#FF6A34"/>
|
||||
<defs>
|
||||
<radialGradient id="paint0_radial_11622_96091" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)">
|
||||
<stop stop-color="#FEBD3F"/>
|
||||
<stop offset="0.77608" stop-color="#FF6933"/>
|
||||
</radialGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 4.0 KiB |
11
web/app/components/base/icons/assets/public/llm/baichuan.svg
Normal file
11
web/app/components/base/icons/assets/public/llm/baichuan.svg
Normal file
@@ -0,0 +1,11 @@
|
||||
<svg width="24" height="24" viewBox="0 0 24 24" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<g id="Baichuan">
|
||||
<path id="Union" fill-rule="evenodd" clip-rule="evenodd" d="M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z" fill="url(#paint0_radial_11622_96084)"/>
|
||||
</g>
|
||||
<defs>
|
||||
<radialGradient id="paint0_radial_11622_96084" cx="0" cy="0" r="1" gradientUnits="userSpaceOnUse" gradientTransform="translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)">
|
||||
<stop stop-color="#FEBD3F"/>
|
||||
<stop offset="0.77608" stop-color="#FF6933"/>
|
||||
</radialGradient>
|
||||
</defs>
|
||||
</svg>
|
||||
|
After Width: | Height: | Size: 748 B |
@@ -115,6 +115,8 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = '<%= fileName %>'
|
||||
|
||||
export default Icon
|
||||
`.trim())
|
||||
|
||||
|
||||
@@ -0,0 +1,5 @@
|
||||
.wrapper {
|
||||
display: inline-flex;
|
||||
background: url(~@/app/components/base/icons/assets/image/llm/baichuan-text-cn.png) center center no-repeat;
|
||||
background-size: contain;
|
||||
}
|
||||
@@ -0,0 +1,15 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import cn from 'classnames'
|
||||
import s from './BaichuanTextCn.module.css'
|
||||
|
||||
const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTMLAttributes<HTMLSpanElement>, HTMLSpanElement>>((
|
||||
{ className, ...restProps },
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'BaichuanTextCn'
|
||||
|
||||
export default Icon
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'Minimax'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'MinimaxText'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'Tongyi'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'TongyiText'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'TongyiTextCn'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'Wxyy'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'WxyyText'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -10,4 +10,6 @@ const Icon = React.forwardRef<HTMLSpanElement, React.DetailedHTMLProps<React.HTM
|
||||
ref,
|
||||
) => <span className={cn(s.wrapper, className)} {...restProps} ref={ref} />)
|
||||
|
||||
Icon.displayName = 'WxyyTextCn'
|
||||
|
||||
export default Icon
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
export { default as BaichuanTextCn } from './BaichuanTextCn'
|
||||
export { default as MinimaxText } from './MinimaxText'
|
||||
export { default as Minimax } from './Minimax'
|
||||
export { default as TongyiTextCn } from './TongyiTextCn'
|
||||
|
||||
76
web/app/components/base/icons/src/public/llm/Baichuan.json
Normal file
76
web/app/components/base/icons/src/public/llm/Baichuan.json
Normal file
@@ -0,0 +1,76 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "24",
|
||||
"height": "24",
|
||||
"viewBox": "0 0 24 24",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "g",
|
||||
"attributes": {
|
||||
"id": "Baichuan"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"id": "Union",
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
"d": "M8.58154 1.7793H5.52779L3.34655 6.20409V17.7335L0.916016 22.2206H6.21333L8.58154 17.7335V1.7793ZM10.5761 1.7793H15.8111V22.2206H10.5761V1.7793ZM22.9166 1.7793H17.6816V6.01712H22.9166V1.7793ZM22.9166 7.38818H17.6816V22.2206H22.9166V7.38818Z",
|
||||
"fill": "url(#paint0_radial_11622_96084)"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "defs",
|
||||
"attributes": {},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "radialGradient",
|
||||
"attributes": {
|
||||
"id": "paint0_radial_11622_96084",
|
||||
"cx": "0",
|
||||
"cy": "0",
|
||||
"r": "1",
|
||||
"gradientUnits": "userSpaceOnUse",
|
||||
"gradientTransform": "translate(5.5 5.5) rotate(45) scale(20.5061 22.0704)"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "stop",
|
||||
"attributes": {
|
||||
"stop-color": "#FEBD3F"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "stop",
|
||||
"attributes": {
|
||||
"offset": "0.77608",
|
||||
"stop-color": "#FF6933"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "Baichuan"
|
||||
}
|
||||
16
web/app/components/base/icons/src/public/llm/Baichuan.tsx
Normal file
16
web/app/components/base/icons/src/public/llm/Baichuan.tsx
Normal file
@@ -0,0 +1,16 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './Baichuan.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
|
||||
props,
|
||||
ref,
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />)
|
||||
|
||||
Icon.displayName = 'Baichuan'
|
||||
|
||||
export default Icon
|
||||
156
web/app/components/base/icons/src/public/llm/BaichuanText.json
Normal file
156
web/app/components/base/icons/src/public/llm/BaichuanText.json
Normal file
@@ -0,0 +1,156 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "130",
|
||||
"height": "24",
|
||||
"viewBox": "0 0 130 24",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"fill-rule": "evenodd",
|
||||
"clip-rule": "evenodd",
|
||||
"d": "M9.58154 1.7793H6.52779L4.34655 6.20409V17.7335L1.91602 22.2206H7.21333L9.58154 17.7335V1.7793ZM11.5761 1.7793H16.8111V22.2206H11.5761V1.7793ZM23.9166 1.7793H18.6816V6.01712H23.9166V1.7793ZM23.9166 7.38818H18.6816V22.2206H23.9166V7.38818Z",
|
||||
"fill": "url(#paint0_radial_11622_96091)"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M129.722 6.83203V18H127.482V6.83203H129.722Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M123.196 15.872H118.748L118.012 18H115.66L119.676 6.81604H122.284L126.3 18H123.932L123.196 15.872ZM122.588 14.08L120.972 9.40804L119.356 14.08H122.588Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M110.962 18H108.722L103.65 10.336V18H101.41V6.81598H103.65L108.722 14.496V6.81598H110.962V18Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M97.1258 15.872H92.6778L91.9418 18H89.5898L93.6058 6.81604H96.2138L100.23 18H97.8618L97.1258 15.872ZM96.5178 14.08L94.9018 9.40804L93.2858 14.08H96.5178Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M81.6482 6.83203V13.744C81.6482 14.5014 81.8455 15.0827 82.2402 15.488C82.6349 15.8827 83.1895 16.08 83.9042 16.08C84.6295 16.08 85.1895 15.8827 85.5842 15.488C85.9789 15.0827 86.1762 14.5014 86.1762 13.744V6.83203H88.4322V13.728C88.4322 14.6774 88.2242 15.4827 87.8082 16.144C87.4029 16.7947 86.8535 17.2854 86.1602 17.616C85.4775 17.9467 84.7149 18.112 83.8722 18.112C83.0402 18.112 82.2829 17.9467 81.6002 17.616C80.9282 17.2854 80.3949 16.7947 80.0002 16.144C79.6055 15.4827 79.4082 14.6774 79.4082 13.728V6.83203H81.6482Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M77.557 6.83203V18H75.317V13.248H70.533V18H68.293V6.83203H70.533V11.424H75.317V6.83203H77.557Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M55.7871 12.4C55.7871 11.3013 56.0324 10.32 56.5231 9.45599C57.0244 8.58132 57.7018 7.90399 58.5551 7.42399C59.4191 6.93332 60.3844 6.68799 61.4511 6.68799C62.6991 6.68799 63.7924 7.00799 64.7311 7.64799C65.6698 8.28799 66.3258 9.17332 66.6991 10.304H64.1231C63.8671 9.77065 63.5044 9.37065 63.0351 9.10399C62.5764 8.83732 62.0431 8.70399 61.4351 8.70399C60.7844 8.70399 60.2031 8.85865 59.6911 9.16799C59.1898 9.46665 58.7951 9.89332 58.5071 10.448C58.2298 11.0027 58.0911 11.6533 58.0911 12.4C58.0911 13.136 58.2298 13.7867 58.5071 14.352C58.7951 14.9067 59.1898 15.3387 59.6911 15.648C60.2031 15.9467 60.7844 16.096 61.4351 16.096C62.0431 16.096 62.5764 15.9627 63.0351 15.696C63.5044 15.4187 63.8671 15.0133 64.1231 14.48H66.6991C66.3258 15.6213 65.6698 16.512 64.7311 17.152C63.8031 17.7813 62.7098 18.096 61.4511 18.096C60.3844 18.096 59.4191 17.856 58.5551 17.376C57.7018 16.8853 57.0244 16.208 56.5231 15.344C56.0324 14.48 55.7871 13.4987 55.7871 12.4Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M54.4373 6.83203V18H52.1973V6.83203H54.4373Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M47.913 15.872H43.465L42.729 18H40.377L44.393 6.81598H47.001L51.017 18H48.649L47.913 15.872ZM47.305 14.08L45.689 9.40798L44.073 14.08H47.305Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M37.4395 12.272C38.0688 12.3893 38.5862 12.704 38.9915 13.216C39.3968 13.728 39.5995 14.3146 39.5995 14.976C39.5995 15.5733 39.4502 16.1013 39.1515 16.56C38.8635 17.008 38.4422 17.36 37.8875 17.616C37.3328 17.872 36.6768 18 35.9195 18H31.1035V6.83197H35.7115C36.4688 6.83197 37.1195 6.95464 37.6635 7.19997C38.2182 7.4453 38.6342 7.78664 38.9115 8.22397C39.1995 8.6613 39.3435 9.1573 39.3435 9.71197C39.3435 10.3626 39.1675 10.9066 38.8155 11.344C38.4742 11.7813 38.0155 12.0906 37.4395 12.272ZM33.3435 11.44H35.3915C35.9248 11.44 36.3355 11.3226 36.6235 11.088C36.9115 10.8426 37.0555 10.496 37.0555 10.048C37.0555 9.59997 36.9115 9.2533 36.6235 9.00797C36.3355 8.76264 35.9248 8.63997 35.3915 8.63997H33.3435V11.44ZM35.5995 16.176C36.1435 16.176 36.5648 16.048 36.8635 15.792C37.1728 15.536 37.3275 15.1733 37.3275 14.704C37.3275 14.224 37.1675 13.8506 36.8475 13.584C36.5275 13.3066 36.0955 13.168 35.5515 13.168H33.3435V16.176H35.5995Z",
|
||||
"fill": "#FF6A34"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "defs",
|
||||
"attributes": {},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "radialGradient",
|
||||
"attributes": {
|
||||
"id": "paint0_radial_11622_96091",
|
||||
"cx": "0",
|
||||
"cy": "0",
|
||||
"r": "1",
|
||||
"gradientUnits": "userSpaceOnUse",
|
||||
"gradientTransform": "translate(6.5 5.5) rotate(45) scale(20.5061 22.0704)"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "stop",
|
||||
"attributes": {
|
||||
"stop-color": "#FEBD3F"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "stop",
|
||||
"attributes": {
|
||||
"offset": "0.77608",
|
||||
"stop-color": "#FF6933"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "BaichuanText"
|
||||
}
|
||||
@@ -0,0 +1,16 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import * as React from 'react'
|
||||
import data from './BaichuanText.json'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import type { IconBaseProps, IconData } from '@/app/components/base/icons/IconBase'
|
||||
|
||||
const Icon = React.forwardRef<React.MutableRefObject<SVGElement>, Omit<IconBaseProps, 'data'>>((
|
||||
props,
|
||||
ref,
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />)
|
||||
|
||||
Icon.displayName = 'BaichuanText'
|
||||
|
||||
export default Icon
|
||||
@@ -4,6 +4,8 @@ export { default as AzureOpenaiServiceText } from './AzureOpenaiServiceText'
|
||||
export { default as AzureOpenaiService } from './AzureOpenaiService'
|
||||
export { default as AzureaiText } from './AzureaiText'
|
||||
export { default as Azureai } from './Azureai'
|
||||
export { default as BaichuanText } from './BaichuanText'
|
||||
export { default as Baichuan } from './Baichuan'
|
||||
export { default as ChatglmText } from './ChatglmText'
|
||||
export { default as Chatglm } from './Chatglm'
|
||||
export { default as Gpt3 } from './Gpt3'
|
||||
|
||||
61
web/app/components/base/qrcode/index.tsx
Normal file
61
web/app/components/base/qrcode/index.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
'use client'
|
||||
import React, { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { debounce } from 'lodash-es'
|
||||
import QRCode from 'qrcode.react'
|
||||
import Tooltip from '../tooltip'
|
||||
import QrcodeStyle from './style.module.css'
|
||||
|
||||
type Props = {
|
||||
content: string
|
||||
selectorId: string
|
||||
className?: string
|
||||
}
|
||||
|
||||
const prefixEmbedded = 'appOverview.overview.appInfo.qrcode.title'
|
||||
|
||||
const ShareQRCode = ({ content, selectorId, className }: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const [isShow, setisShow] = useState<boolean>(false)
|
||||
const onClickShow = debounce(() => {
|
||||
setisShow(true)
|
||||
}, 100)
|
||||
|
||||
const downloadQR = () => {
|
||||
const canvas = document.getElementsByTagName('canvas')[0]
|
||||
const link = document.createElement('a')
|
||||
link.download = 'qrcode.png'
|
||||
link.href = canvas.toDataURL()
|
||||
link.click()
|
||||
}
|
||||
|
||||
const onMouseLeave = debounce(() => {
|
||||
setisShow(false)
|
||||
}, 500)
|
||||
|
||||
return (
|
||||
<Tooltip
|
||||
selector={`common-qrcode-show-${selectorId}`}
|
||||
content={t(`${prefixEmbedded}`) || ''}
|
||||
>
|
||||
<div
|
||||
className={`w-8 h-8 cursor-pointer rounded-lg ${className ?? ''}`}
|
||||
onMouseLeave={onMouseLeave}
|
||||
onClick={onClickShow}
|
||||
>
|
||||
<div className={`w-full h-full ${QrcodeStyle.QrcodeIcon} ${isShow ? QrcodeStyle.show : ''}`} />
|
||||
{isShow && <div className={QrcodeStyle.qrcodeform}>
|
||||
<QRCode size={160} value={content} className={QrcodeStyle.qrcodeimage}/>
|
||||
<div className={QrcodeStyle.text}>
|
||||
<div className={`text-gray-500 ${QrcodeStyle.scan}`}>{t('appOverview.overview.appInfo.qrcode.scan')}</div>
|
||||
<div className={`text-gray-500 ${QrcodeStyle.scan}`}>·</div>
|
||||
<div className={QrcodeStyle.download} onClick={downloadQR}>{t('appOverview.overview.appInfo.qrcode.download')}</div>
|
||||
</div>
|
||||
</div>
|
||||
}
|
||||
</div>
|
||||
</Tooltip>
|
||||
)
|
||||
}
|
||||
|
||||
export default ShareQRCode
|
||||
61
web/app/components/base/qrcode/style.module.css
Normal file
61
web/app/components/base/qrcode/style.module.css
Normal file
@@ -0,0 +1,61 @@
|
||||
.QrcodeIcon {
|
||||
background-image: url(~@/app/components/develop/secret-key/assets/qrcode.svg);
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.QrcodeIcon:hover {
|
||||
background-image: url(~@/app/components/develop/secret-key/assets/qrcode-hover.svg);
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.QrcodeIcon.show {
|
||||
background-image: url(~@/app/components/develop/secret-key/assets/qrcode-hover.svg);
|
||||
background-position: center;
|
||||
background-repeat: no-repeat;
|
||||
}
|
||||
|
||||
.qrcodeimage {
|
||||
position: relative;
|
||||
object-fit: cover;
|
||||
}
|
||||
.scan {
|
||||
margin: 0;
|
||||
line-height: 1rem;
|
||||
font-size: 0.75rem;
|
||||
}
|
||||
.download {
|
||||
position: relative;
|
||||
color: #155eef;
|
||||
font-size: 0.75rem;
|
||||
line-height: 1rem;
|
||||
}
|
||||
.text {
|
||||
align-self: stretch;
|
||||
display: flex;
|
||||
flex-direction: row;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
gap: 4px;
|
||||
}
|
||||
.qrcodeform {
|
||||
border: 0.5px solid #eaecf0;
|
||||
display: flex;
|
||||
flex-direction: column;
|
||||
margin: 0 !important;
|
||||
margin-top: 4px !important;
|
||||
margin-left: -75px !important;
|
||||
position: absolute;
|
||||
border-radius: 8px;
|
||||
background-color: #fff;
|
||||
box-shadow: 0 12px 16px -4px rgba(16, 24, 40, 0.08),
|
||||
0 4px 6px -2px rgba(16, 24, 40, 0.03);
|
||||
overflow: hidden;
|
||||
align-items: center;
|
||||
justify-content: center;
|
||||
padding: 12px;
|
||||
gap: 8px;
|
||||
z-index: 3;
|
||||
font-family: "PingFang SC", serif;
|
||||
}
|
||||
@@ -127,7 +127,7 @@ const DatasetUpdateForm = ({ datasetId }: DatasetUpdateFormProps) => {
|
||||
{(step === 2 && (!datasetId || (datasetId && !!detail))) && <StepTwo
|
||||
hasSetAPIKEY={!!embeddingsDefaultModel}
|
||||
onSetting={showSetAPIKey}
|
||||
indexingType={detail?.indexing_technique || ''}
|
||||
indexingType={detail?.indexing_technique}
|
||||
datasetId={datasetId}
|
||||
dataSourceType={dataSourceType}
|
||||
files={fileList.map(file => file.file)}
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
/* eslint-disable no-mixed-operators */
|
||||
'use client'
|
||||
import React, { useEffect, useLayoutEffect, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
@@ -11,7 +10,7 @@ import { groupBy } from 'lodash-es'
|
||||
import PreviewItem, { PreviewType } from './preview-item'
|
||||
import LanguageSelect from './language-select'
|
||||
import s from './index.module.css'
|
||||
import type { CreateDocumentReq, CustomFile, FullDocumentDetail, FileIndexingEstimateResponse as IndexingEstimateResponse, NotionInfo, PreProcessingRule, Rules, createDocumentResponse } from '@/models/datasets'
|
||||
import type { CreateDocumentReq, CustomFile, FileIndexingEstimateResponse, FullDocumentDetail, IndexingEstimateParams, IndexingEstimateResponse, NotionInfo, PreProcessingRule, ProcessRule, Rules, createDocumentResponse } from '@/models/datasets'
|
||||
import {
|
||||
createDocument,
|
||||
createFirstDocument,
|
||||
@@ -33,13 +32,14 @@ import { useDatasetDetailContext } from '@/context/dataset-detail'
|
||||
import I18n from '@/context/i18n'
|
||||
import { IS_CE_EDITION } from '@/config'
|
||||
|
||||
type ValueOf<T> = T[keyof T]
|
||||
type StepTwoProps = {
|
||||
isSetting?: boolean
|
||||
documentDetail?: FullDocumentDetail
|
||||
hasSetAPIKEY: boolean
|
||||
onSetting: () => void
|
||||
datasetId?: string
|
||||
indexingType?: string
|
||||
indexingType?: ValueOf<IndexingType>
|
||||
dataSourceType: DataSourceType
|
||||
files: CustomFile[]
|
||||
notionPages?: NotionPage[]
|
||||
@@ -89,21 +89,23 @@ const StepTwo = ({
|
||||
const [rules, setRules] = useState<PreProcessingRule[]>([])
|
||||
const [defaultConfig, setDefaultConfig] = useState<Rules>()
|
||||
const hasSetIndexType = !!indexingType
|
||||
const [indexType, setIndexType] = useState<IndexingType>(
|
||||
indexingType
|
||||
|| hasSetAPIKEY
|
||||
const [indexType, setIndexType] = useState<ValueOf<IndexingType>>(
|
||||
(indexingType
|
||||
|| hasSetAPIKEY)
|
||||
? IndexingType.QUALIFIED
|
||||
: IndexingType.ECONOMICAL,
|
||||
)
|
||||
const [docForm, setDocForm] = useState<DocForm | string>(
|
||||
datasetId && documentDetail ? documentDetail.doc_form : DocForm.TEXT,
|
||||
(datasetId && documentDetail) ? documentDetail.doc_form : DocForm.TEXT,
|
||||
)
|
||||
const [docLanguage, setDocLanguage] = useState<string>(locale === 'en' ? 'English' : 'Chinese')
|
||||
const [QATipHide, setQATipHide] = useState(false)
|
||||
const [previewSwitched, setPreviewSwitched] = useState(false)
|
||||
const [showPreview, { setTrue: setShowPreview, setFalse: hidePreview }] = useBoolean()
|
||||
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
|
||||
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<IndexingEstimateResponse | null>(null)
|
||||
const [customFileIndexingEstimate, setCustomFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
|
||||
const [automaticFileIndexingEstimate, setAutomaticFileIndexingEstimate] = useState<FileIndexingEstimateResponse | null>(null)
|
||||
const [estimateTokes, setEstimateTokes] = useState<Pick<IndexingEstimateResponse, 'tokens' | 'total_price'> | null>(null)
|
||||
|
||||
const fileIndexingEstimate = (() => {
|
||||
return segmentationType === SegmentType.AUTO ? automaticFileIndexingEstimate : customFileIndexingEstimate
|
||||
})()
|
||||
@@ -153,7 +155,7 @@ const StepTwo = ({
|
||||
}
|
||||
const resetRules = () => {
|
||||
if (defaultConfig) {
|
||||
setSegmentIdentifier(defaultConfig.segmentation.separator === '\n' ? '\\n' : defaultConfig.segmentation.separator || '\\n')
|
||||
setSegmentIdentifier((defaultConfig.segmentation.separator === '\n' ? '\\n' : defaultConfig.segmentation.separator) || '\\n')
|
||||
setMax(defaultConfig.segmentation.max_tokens)
|
||||
setRules(defaultConfig.pre_processing_rules)
|
||||
}
|
||||
@@ -161,12 +163,14 @@ const StepTwo = ({
|
||||
|
||||
const fetchFileIndexingEstimate = async (docForm = DocForm.TEXT) => {
|
||||
// eslint-disable-next-line @typescript-eslint/no-use-before-define
|
||||
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm))
|
||||
if (segmentationType === SegmentType.CUSTOM)
|
||||
const res = await didFetchFileIndexingEstimate(getFileIndexingEstimateParams(docForm)!)
|
||||
if (segmentationType === SegmentType.CUSTOM) {
|
||||
setCustomFileIndexingEstimate(res)
|
||||
|
||||
else
|
||||
}
|
||||
else {
|
||||
setAutomaticFileIndexingEstimate(res)
|
||||
indexType === IndexingType.QUALIFIED && setEstimateTokes({ tokens: res.tokens, total_price: res.total_price })
|
||||
}
|
||||
}
|
||||
|
||||
const confirmChangeCustomConfig = () => {
|
||||
@@ -179,8 +183,8 @@ const StepTwo = ({
|
||||
const getIndexing_technique = () => indexingType || indexType
|
||||
|
||||
const getProcessRule = () => {
|
||||
const processRule: any = {
|
||||
rules: {}, // api will check this. It will be removed after api refactored.
|
||||
const processRule: ProcessRule = {
|
||||
rules: {} as any, // api will check this. It will be removed after api refactored.
|
||||
mode: segmentationType,
|
||||
}
|
||||
if (segmentationType === SegmentType.CUSTOM) {
|
||||
@@ -220,37 +224,35 @@ const StepTwo = ({
|
||||
}) as NotionInfo[]
|
||||
}
|
||||
|
||||
const getFileIndexingEstimateParams = (docForm: DocForm) => {
|
||||
let params
|
||||
const getFileIndexingEstimateParams = (docForm: DocForm): IndexingEstimateParams | undefined => {
|
||||
if (dataSourceType === DataSourceType.FILE) {
|
||||
params = {
|
||||
return {
|
||||
info_list: {
|
||||
data_source_type: dataSourceType,
|
||||
file_info_list: {
|
||||
file_ids: files.map(file => file.id),
|
||||
file_ids: files.map(file => file.id) as string[],
|
||||
},
|
||||
},
|
||||
indexing_technique: getIndexing_technique(),
|
||||
indexing_technique: getIndexing_technique() as string,
|
||||
process_rule: getProcessRule(),
|
||||
doc_form: docForm,
|
||||
doc_language: docLanguage,
|
||||
dataset_id: datasetId,
|
||||
dataset_id: datasetId as string,
|
||||
}
|
||||
}
|
||||
if (dataSourceType === DataSourceType.NOTION) {
|
||||
params = {
|
||||
return {
|
||||
info_list: {
|
||||
data_source_type: dataSourceType,
|
||||
notion_info_list: getNotionInfo(),
|
||||
},
|
||||
indexing_technique: getIndexing_technique(),
|
||||
indexing_technique: getIndexing_technique() as string,
|
||||
process_rule: getProcessRule(),
|
||||
doc_form: docForm,
|
||||
doc_language: docLanguage,
|
||||
dataset_id: datasetId,
|
||||
dataset_id: datasetId as string,
|
||||
}
|
||||
}
|
||||
return params
|
||||
}
|
||||
|
||||
const getCreationParams = () => {
|
||||
@@ -291,7 +293,7 @@ const StepTwo = ({
|
||||
try {
|
||||
const res = await fetchDefaultProcessRule({ url: '/datasets/process-rule' })
|
||||
const separator = res.rules.segmentation.separator
|
||||
setSegmentIdentifier(separator === '\n' ? '\\n' : separator || '\\n')
|
||||
setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n')
|
||||
setMax(res.rules.segmentation.max_tokens)
|
||||
setRules(res.rules.pre_processing_rules)
|
||||
setDefaultConfig(res.rules)
|
||||
@@ -306,7 +308,7 @@ const StepTwo = ({
|
||||
const rules = documentDetail.dataset_process_rule.rules
|
||||
const separator = rules.segmentation.separator
|
||||
const max = rules.segmentation.max_tokens
|
||||
setSegmentIdentifier(separator === '\n' ? '\\n' : separator || '\\n')
|
||||
setSegmentIdentifier((separator === '\n' ? '\\n' : separator) || '\\n')
|
||||
setMax(max)
|
||||
setRules(rules.pre_processing_rules)
|
||||
setDefaultConfig(rules)
|
||||
@@ -330,7 +332,7 @@ const StepTwo = ({
|
||||
res = await createFirstDocument({
|
||||
body: params,
|
||||
})
|
||||
updateIndexingTypeCache && updateIndexingTypeCache(indexType)
|
||||
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
|
||||
updateResultCache && updateResultCache(res)
|
||||
}
|
||||
else {
|
||||
@@ -338,7 +340,7 @@ const StepTwo = ({
|
||||
datasetId,
|
||||
body: params,
|
||||
})
|
||||
updateIndexingTypeCache && updateIndexingTypeCache(indexType)
|
||||
updateIndexingTypeCache && updateIndexingTypeCache(indexType as string)
|
||||
updateResultCache && updateResultCache(res)
|
||||
}
|
||||
if (mutateDatasetRes)
|
||||
@@ -549,9 +551,9 @@ const StepTwo = ({
|
||||
<div className={s.tip}>{t('datasetCreation.stepTwo.qualifiedTip')}</div>
|
||||
<div className='pb-0.5 text-xs font-medium text-gray-500'>{t('datasetCreation.stepTwo.emstimateCost')}</div>
|
||||
{
|
||||
fileIndexingEstimate
|
||||
estimateTokes
|
||||
? (
|
||||
<div className='text-xs font-medium text-gray-800'>{formatNumber(fileIndexingEstimate.tokens)} tokens(<span className='text-yellow-500'>${formatNumber(fileIndexingEstimate.total_price)}</span>)</div>
|
||||
<div className='text-xs font-medium text-gray-800'>{formatNumber(estimateTokes.tokens)} tokens(<span className='text-yellow-500'>${formatNumber(estimateTokes.total_price)}</span>)</div>
|
||||
)
|
||||
: (
|
||||
<div className={s.calculating}>{t('datasetCreation.stepTwo.calculating')}</div>
|
||||
|
||||
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.33333 4.33333H4.34M11.6667 4.33333H11.6733M4.33333 11.6667H4.34M8.66667 8.66667H8.67333M11.6667 11.6667H11.6733M11.3333 14H14V11.3333M9.33333 11V14M14 9.33333H11M10.4 6.66667H12.9333C13.3067 6.66667 13.4934 6.66667 13.636 6.594C13.7614 6.53009 13.8634 6.4281 13.9273 6.30266C14 6.16005 14 5.97337 14 5.6V3.06667C14 2.6933 14 2.50661 13.9273 2.36401C13.8634 2.23856 13.7614 2.13658 13.636 2.07266C13.4934 2 13.3067 2 12.9333 2H10.4C10.0266 2 9.83995 2 9.69734 2.07266C9.5719 2.13658 9.46991 2.23856 9.406 2.36401C9.33333 2.50661 9.33333 2.6933 9.33333 3.06667V5.6C9.33333 5.97337 9.33333 6.16005 9.406 6.30266C9.46991 6.4281 9.5719 6.53009 9.69734 6.594C9.83995 6.66667 10.0266 6.66667 10.4 6.66667ZM3.06667 6.66667H5.6C5.97337 6.66667 6.16005 6.66667 6.30266 6.594C6.4281 6.53009 6.53009 6.4281 6.594 6.30266C6.66667 6.16005 6.66667 5.97337 6.66667 5.6V3.06667C6.66667 2.6933 6.66667 2.50661 6.594 2.36401C6.53009 2.23856 6.4281 2.13658 6.30266 2.07266C6.16005 2 5.97337 2 5.6 2H3.06667C2.6933 2 2.50661 2 2.36401 2.07266C2.23856 2.13658 2.13658 2.23856 2.07266 2.36401C2 2.50661 2 2.6933 2 3.06667V5.6C2 5.97337 2 6.16005 2.07266 6.30266C2.13658 6.4281 2.23856 6.53009 2.36401 6.594C2.50661 6.66667 2.6933 6.66667 3.06667 6.66667ZM3.06667 14H5.6C5.97337 14 6.16005 14 6.30266 13.9273C6.4281 13.8634 6.53009 13.7614 6.594 13.636C6.66667 13.4934 6.66667 13.3067 6.66667 12.9333V10.4C6.66667 10.0266 6.66667 9.83995 6.594 9.69734C6.53009 9.5719 6.4281 9.46991 6.30266 9.406C6.16005 9.33333 5.97337 9.33333 5.6 9.33333H3.06667C2.6933 9.33333 2.50661 9.33333 2.36401 9.406C2.23856 9.46991 2.13658 9.5719 2.07266 9.69734C2 9.83995 2 10.0266 2 10.4V12.9333C2 13.3067 2 13.4934 2.07266 13.636C2.13658 13.7614 2.23856 13.8634 2.36401 13.9273C2.50661 14 2.6933 14 3.06667 14Z" stroke="#1D2939" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
||||
|
After Width: | Height: | Size: 1.9 KiB |
4
web/app/components/develop/secret-key/assets/qrcode.svg
Normal file
4
web/app/components/develop/secret-key/assets/qrcode.svg
Normal file
@@ -0,0 +1,4 @@
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M4.33333 4.33333H4.34M11.6667 4.33333H11.6733M4.33333 11.6667H4.34M8.66667 8.66667H8.67333M11.6667 11.6667H11.6733M11.3333 14H14V11.3333M9.33333 11V14M14 9.33333H11M10.4 6.66667H12.9333C13.3067 6.66667 13.4934 6.66667 13.636 6.594C13.7614 6.53009 13.8634 6.4281 13.9273 6.30266C14 6.16005 14 5.97337 14 5.6V3.06667C14 2.6933 14 2.50661 13.9273 2.36401C13.8634 2.23856 13.7614 2.13658 13.636 2.07266C13.4934 2 13.3067 2 12.9333 2H10.4C10.0266 2 9.83995 2 9.69734 2.07266C9.5719 2.13658 9.46991 2.23856 9.406 2.36401C9.33333 2.50661 9.33333 2.6933 9.33333 3.06667V5.6C9.33333 5.97337 9.33333 6.16005 9.406 6.30266C9.46991 6.4281 9.5719 6.53009 9.69734 6.594C9.83995 6.66667 10.0266 6.66667 10.4 6.66667ZM3.06667 6.66667H5.6C5.97337 6.66667 6.16005 6.66667 6.30266 6.594C6.4281 6.53009 6.53009 6.4281 6.594 6.30266C6.66667 6.16005 6.66667 5.97337 6.66667 5.6V3.06667C6.66667 2.6933 6.66667 2.50661 6.594 2.36401C6.53009 2.23856 6.4281 2.13658 6.30266 2.07266C6.16005 2 5.97337 2 5.6 2H3.06667C2.6933 2 2.50661 2 2.36401 2.07266C2.23856 2.13658 2.13658 2.23856 2.07266 2.36401C2 2.50661 2 2.6933 2 3.06667V5.6C2 5.97337 2 6.16005 2.07266 6.30266C2.13658 6.4281 2.23856 6.53009 2.36401 6.594C2.50661 6.66667 2.6933 6.66667 3.06667 6.66667ZM3.06667 14H5.6C5.97337 14 6.16005 14 6.30266 13.9273C6.4281 13.8634 6.53009 13.7614 6.594 13.636C6.66667 13.4934 6.66667 13.3067 6.66667 12.9333V10.4C6.66667 10.0266 6.66667 9.83995 6.594 9.69734C6.53009 9.5719 6.4281 9.46991 6.30266 9.406C6.16005 9.33333 5.97337 9.33333 5.6 9.33333H3.06667C2.6933 9.33333 2.50661 9.33333 2.36401 9.406C2.23856 9.46991 2.13658 9.5719 2.07266 9.69734C2 9.83995 2 10.0266 2 10.4V12.9333C2 13.3067 2 13.4934 2.07266 13.636C2.13658 13.7614 2.23856 13.8634 2.36401 13.9273C2.50661 14 2.6933 14 3.06667 14Z" stroke="#667085" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"/>
|
||||
</svg>
|
||||
|
||||
|
After Width: | Height: | Size: 1.9 KiB |
@@ -0,0 +1,70 @@
|
||||
import { ProviderEnum } from '../declarations'
|
||||
import type { ProviderConfig } from '../declarations'
|
||||
import { BaichuanTextCn } from '@/app/components/base/icons/src/image/llm'
|
||||
import {
|
||||
Baichuan,
|
||||
BaichuanText,
|
||||
} from '@/app/components/base/icons/src/public/llm'
|
||||
|
||||
const config: ProviderConfig = {
|
||||
selector: {
|
||||
name: {
|
||||
'en': 'BAICHUAN AI',
|
||||
'zh-Hans': '百川智能',
|
||||
},
|
||||
icon: <Baichuan className='w-full h-full' />,
|
||||
},
|
||||
item: {
|
||||
key: ProviderEnum.baichuan,
|
||||
titleIcon: {
|
||||
'en': <BaichuanText className='w-[124px] h-6' />,
|
||||
'zh-Hans': <BaichuanTextCn className='w-[100px] h-6' />,
|
||||
},
|
||||
},
|
||||
modal: {
|
||||
key: ProviderEnum.baichuan,
|
||||
title: {
|
||||
'en': 'BAICHUAN AI',
|
||||
'zh-Hans': '百川智能',
|
||||
},
|
||||
icon: <Baichuan className='w-6 h-6' />,
|
||||
link: {
|
||||
href: 'https://platform.baichuan-ai.com/console/apikey',
|
||||
label: {
|
||||
'en': 'Get your API key from BAICHUAN AI',
|
||||
'zh-Hans': '从百川智能获取 API Key',
|
||||
},
|
||||
},
|
||||
validateKeys: ['api_key', 'secret_key'],
|
||||
fields: [
|
||||
{
|
||||
type: 'text',
|
||||
key: 'api_key',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'API Key',
|
||||
'zh-Hans': 'API Key',
|
||||
},
|
||||
placeholder: {
|
||||
'en': 'Enter your API key here',
|
||||
'zh-Hans': '在此输入您的 API Key',
|
||||
},
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
key: 'secret_key',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'Secret Key',
|
||||
'zh-Hans': 'Secret Key',
|
||||
},
|
||||
placeholder: {
|
||||
'en': 'Enter your Secret key here',
|
||||
'zh-Hans': '在此输入您的 Secret Key',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
}
|
||||
|
||||
export default config
|
||||
@@ -12,6 +12,7 @@ import xinference from './xinference'
|
||||
import openllm from './openllm'
|
||||
import localai from './localai'
|
||||
import zhipuai from './zhipuai'
|
||||
import baichuan from './baichuan'
|
||||
|
||||
export default {
|
||||
openai,
|
||||
@@ -28,4 +29,5 @@ export default {
|
||||
openllm,
|
||||
localai,
|
||||
zhipuai,
|
||||
baichuan,
|
||||
}
|
||||
|
||||
@@ -56,6 +56,31 @@ const config: ProviderConfig = {
|
||||
'server_url',
|
||||
],
|
||||
fields: [
|
||||
{
|
||||
type: 'radio',
|
||||
key: 'model_type',
|
||||
required: true,
|
||||
label: {
|
||||
'en': 'Model Type',
|
||||
'zh-Hans': '模型类型',
|
||||
},
|
||||
options: [
|
||||
{
|
||||
key: 'text-generation',
|
||||
label: {
|
||||
'en': 'Text Generation',
|
||||
'zh-Hans': '文本生成',
|
||||
},
|
||||
},
|
||||
{
|
||||
key: 'embeddings',
|
||||
label: {
|
||||
'en': 'Embeddings',
|
||||
'zh-Hans': 'Embeddings',
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
type: 'text',
|
||||
key: 'model_name',
|
||||
|
||||
@@ -43,6 +43,7 @@ export enum ProviderEnum {
|
||||
'openllm' = 'openllm',
|
||||
'localai' = 'localai',
|
||||
'zhipuai' = 'zhipuai',
|
||||
'baichuan' = 'baichuan',
|
||||
}
|
||||
|
||||
export type ProviderConfigItem = {
|
||||
|
||||
@@ -79,6 +79,7 @@ const ModelPage = () => {
|
||||
config.replicate,
|
||||
config.huggingface_hub,
|
||||
config.zhipuai,
|
||||
config.baichuan,
|
||||
config.spark,
|
||||
config.minimax,
|
||||
config.tongyi,
|
||||
@@ -93,6 +94,7 @@ const ModelPage = () => {
|
||||
modelList = [
|
||||
config.huggingface_hub,
|
||||
config.zhipuai,
|
||||
config.baichuan,
|
||||
config.spark,
|
||||
config.minimax,
|
||||
config.azure_openai,
|
||||
|
||||
@@ -61,6 +61,11 @@ const translation = {
|
||||
copied: 'Copied',
|
||||
copy: 'Copy',
|
||||
},
|
||||
qrcode: {
|
||||
title: 'QR code to share',
|
||||
scan: 'Scan Share Application',
|
||||
download: 'Download QR Code',
|
||||
},
|
||||
customize: {
|
||||
way: 'way',
|
||||
entry: 'Customize',
|
||||
|
||||
@@ -61,6 +61,11 @@ const translation = {
|
||||
copied: '已复制',
|
||||
copy: '复制',
|
||||
},
|
||||
qrcode: {
|
||||
title: '二维码分享',
|
||||
scan: '扫码分享应用',
|
||||
download: '下载二维码',
|
||||
},
|
||||
customize: {
|
||||
way: '方法',
|
||||
entry: '定制化',
|
||||
|
||||
@@ -183,15 +183,22 @@ export type DocumentListResponse = {
|
||||
limit: number
|
||||
}
|
||||
|
||||
export type CreateDocumentReq = {
|
||||
export type DocumentReq = {
|
||||
original_document_id?: string
|
||||
indexing_technique?: string
|
||||
doc_form: 'text_model' | 'qa_model'
|
||||
doc_language: string
|
||||
data_source: DataSource
|
||||
process_rule: ProcessRule
|
||||
}
|
||||
|
||||
export type CreateDocumentReq = DocumentReq & {
|
||||
data_source: DataSource
|
||||
}
|
||||
|
||||
export type IndexingEstimateParams = DocumentReq & Partial<DataSource> & {
|
||||
dataset_id: string
|
||||
}
|
||||
|
||||
export type DataSource = {
|
||||
type: DataSourceType
|
||||
info_list: {
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
{
|
||||
"name": "dify-web",
|
||||
"version": "0.3.24",
|
||||
"version": "0.3.25",
|
||||
"private": true,
|
||||
"scripts": {
|
||||
"dev": "next dev",
|
||||
@@ -46,6 +46,7 @@
|
||||
"mermaid": "10.4.0",
|
||||
"negotiator": "^0.6.3",
|
||||
"next": "13.3.1",
|
||||
"qrcode.react": "^3.1.0",
|
||||
"qs": "^6.11.1",
|
||||
"react": "^18.2.0",
|
||||
"react-18-input-autosize": "^3.0.0",
|
||||
|
||||
@@ -10,6 +10,7 @@ import type {
|
||||
FileIndexingEstimateResponse,
|
||||
HitTestingRecordsResponse,
|
||||
HitTestingResponse,
|
||||
IndexingEstimateParams,
|
||||
IndexingEstimateResponse,
|
||||
IndexingStatusBatchResponse,
|
||||
IndexingStatusResponse,
|
||||
@@ -189,7 +190,7 @@ export const fetchTestingRecords: Fetcher<HitTestingRecordsResponse, { datasetId
|
||||
return get<HitTestingRecordsResponse>(`/datasets/${datasetId}/queries`, { params })
|
||||
}
|
||||
|
||||
export const fetchFileIndexingEstimate: Fetcher<FileIndexingEstimateResponse, any> = (body: any) => {
|
||||
export const fetchFileIndexingEstimate: Fetcher<FileIndexingEstimateResponse, IndexingEstimateParams> = (body: IndexingEstimateParams) => {
|
||||
return post<FileIndexingEstimateResponse>('/datasets/indexing-estimate', { body })
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user