Mirror of https://github.com/langgenius/dify.git
Compare commits
26 Commits
| SHA1 |
|---|
| 91ff07fcf7 |
| bb7af56e69 |
| 77f9e8ce0f |
| 5ca4c4a44d |
| a44022c388 |
| 6333cf43a8 |
| 91ee62d1ab |
| ede69b4659 |
| 61aaeff413 |
| 4e1cd75f6f |
| a8ff2e95da |
| 4d502ea44d |
| 66b3588897 |
| 9134849744 |
| fcf8512956 |
| ae975b10e9 |
| b43f1441a9 |
| 5a2aa83030 |
| 4de27d0404 |
| c6d59681ff |
| 3b668c0bb1 |
| 4aed1fe8a8 |
| 2381264a3f |
| 4562e83b24 |
| 7be77c19f5 |
| 82247c0f14 |
@@ -91,7 +91,7 @@ After running, you can access the Dify dashboard in your browser at [http://loca

 ### Helm Chart

-A big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
+Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
 You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.

 ### Configuration
@@ -87,7 +87,7 @@ class Config:
         # ------------------------
         # General Configurations.
         # ------------------------
-        self.CURRENT_VERSION = "0.4.0"
+        self.CURRENT_VERSION = "0.4.2"
         self.COMMIT_SHA = get_env('COMMIT_SHA')
         self.EDITION = "SELF_HOSTED"
         self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -197,6 +197,7 @@ class Config:
         # qdrant settings
         self.QDRANT_URL = get_env('QDRANT_URL')
         self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
+        self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')

         # milvus / zilliz setting
         self.MILVUS_HOST = get_env('MILVUS_HOST')
@@ -1,6 +1,8 @@
 import logging

 from flask_login import current_user

+from core.model_runtime.errors.invoke import InvokeError
+from libs.login import login_required
 from flask_restful import Resource, reqparse, marshal
 from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
@@ -8,7 +10,7 @@ from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
 import services
 from controllers.console import api
 from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
-    ProviderModelCurrentlyNotSupportError
+    ProviderModelCurrentlyNotSupportError, CompletionRequestError
 from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
 from controllers.console.setup import setup_required
 from controllers.console.wraps import account_initialization_required
@@ -69,6 +71,8 @@ class HitTestingApi(Resource):
            raise ProviderNotInitializeError(
                f"No Embedding Model or Reranking Model available. Please configure a valid provider "
                f"in the Settings -> Model Provider.")
+        except InvokeError as e:
+            raise CompletionRequestError(e.description)
        except ValueError as e:
            raise ValueError(str(e))
        except Exception as e:
@@ -237,8 +237,8 @@ class AgentApplicationRunner(AppRunner):
         all_message_tokens = 0
         all_answer_tokens = 0
         for agent_thought in agent_thoughts:
-            all_message_tokens += agent_thought.message_tokens
-            all_answer_tokens += agent_thought.answer_tokens
+            all_message_tokens += agent_thought.message_token
+            all_answer_tokens += agent_thought.answer_token

         model_type_instance = model_config.provider_model_bundle.model_type_instance
         model_type_instance = cast(LargeLanguageModel, model_type_instance)
@@ -376,7 +376,8 @@ class ApplicationManager:
                 and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
             'enabled']:
             agent_dict = copy_app_model_config_dict.get('agent_mode')
-            if agent_dict['strategy'] in ['router', 'react_router']:
+            agent_strategy = agent_dict.get('strategy', 'router')
+            if agent_strategy in ['router', 'react_router']:
                 dataset_ids = []
                 for tool in agent_dict.get('tools', []):
                     key = list(tool.keys())[0]
@@ -402,7 +403,7 @@ class ApplicationManager:
                             retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
                                 dataset_configs['retrieval_model']
                             ),
-                            single_strategy=agent_dict['strategy']
+                            single_strategy=agent_strategy
                         )
                     )
             else:
@@ -419,7 +420,7 @@ class ApplicationManager:
                         )
                     )
             else:
-                if agent_dict['strategy'] == 'react':
+                if agent_strategy == 'react':
                     strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
                 else:
                     strategy = AgentEntity.Strategy.FUNCTION_CALLING
@@ -472,7 +473,7 @@ class ApplicationManager:
         more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
         if more_like_this_dict:
             if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
-                properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
+                properties['more_like_this'] = True

         # speech to text
         speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')
@@ -18,6 +18,7 @@ from models.dataset import Dataset, DatasetCollectionBinding
 class QdrantConfig(BaseModel):
     endpoint: str
     api_key: Optional[str]
+    timeout: float = 20
     root_path: Optional[str]

     def to_qdrant_params(self):
@@ -33,6 +34,7 @@ class QdrantConfig(BaseModel):
             return {
                 'url': self.endpoint,
                 'api_key': self.api_key,
+                'timeout': self.timeout
             }
@@ -49,7 +49,8 @@ class VectorIndex:
                 config=QdrantConfig(
                     endpoint=config.get('QDRANT_URL'),
                     api_key=config.get('QDRANT_API_KEY'),
-                    root_path=current_app.root_path
+                    root_path=current_app.root_path,
+                    timeout=config.get('QDRANT_CLIENT_TIMEOUT')
                 ),
                 embeddings=embeddings
             )
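Taken together, the three Qdrant hunks above thread one setting end to end: `QDRANT_CLIENT_TIMEOUT` is read from the environment into `Config`, becomes the `timeout` field on `QdrantConfig` (defaulting to 20 seconds), and finally rides along in `to_qdrant_params()`. A minimal runnable sketch of that flow, assuming pydantic and a `get_env` helper like the one in the hunks:

```python
# Sketch of the timeout plumbing added above. `get_env` and the trimmed-down
# QdrantConfig are assumptions modeled on the hunks, not the full dify files.
import os
from typing import Optional

from pydantic import BaseModel


def get_env(key: str, default: Optional[str] = None) -> Optional[str]:
    return os.environ.get(key, default)


class QdrantConfig(BaseModel):
    endpoint: str
    api_key: Optional[str] = None
    timeout: float = 20  # default matches the hunk above
    root_path: Optional[str] = None

    def to_qdrant_params(self):
        # the timeout now rides along with the connection parameters
        return {
            'url': self.endpoint,
            'api_key': self.api_key,
            'timeout': self.timeout,
        }


config = QdrantConfig(
    endpoint=get_env('QDRANT_URL', 'http://localhost:6333'),
    api_key=get_env('QDRANT_API_KEY'),
    timeout=float(get_env('QDRANT_CLIENT_TIMEOUT', '20')),
)
print(config.to_qdrant_params())
```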
@@ -5,12 +5,12 @@ import re
 import threading
 import time
 import uuid
-from typing import Optional, List, cast
+from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any

 from flask import current_app, Flask
 from flask_login import current_user
 from langchain.schema import Document
-from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
+from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
 from sqlalchemy.orm.exc import ObjectDeletedError

 from core.data_loader.file_extractor import FileExtractor
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
 from core.model_runtime.entities.model_entities import ModelType, PriceType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
-from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
+from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
 from extensions.ext_database import db
 from extensions.ext_redis import redis_client
 from extensions.ext_storage import storage
@@ -502,7 +503,8 @@ class IndexingRunner:
         if separator:
             separator = separator.replace('\\n', '\n')

-            character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=segmentation["max_tokens"],
                 chunk_overlap=0,
                 fixed_separator=separator,
@@ -510,7 +512,7 @@ class IndexingRunner:
             )
         else:
             # Automatic segmentation
-            character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
+            character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
                 chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
                 chunk_overlap=0,
                 separators=["\n\n", "。", ".", " ", ""]
@@ -18,7 +18,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
         'default': 0.0,
         'min': 0.0,
         'max': 1.0,
-        'precision': 1,
+        'precision': 2,
     },
     DefaultParameterName.TOP_P: {
         'label': {
@@ -34,7 +34,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
         'default': 1.0,
         'min': 0.0,
         'max': 1.0,
-        'precision': 1,
+        'precision': 2,
     },
     DefaultParameterName.PRESENCE_PENALTY: {
         'label': {
@@ -50,7 +50,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
         'default': 0.0,
         'min': 0.0,
         'max': 1.0,
-        'precision': 1,
+        'precision': 2,
     },
     DefaultParameterName.FREQUENCY_PENALTY: {
         'label': {
@@ -66,7 +66,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
         'default': 0.0,
         'min': 0.0,
         'max': 1.0,
-        'precision': 1,
+        'precision': 2,
     },
     DefaultParameterName.MAX_TOKENS: {
         'label': {
@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
         system_fingerprint = None
         real_model = model

-        for chunk in result:
-            try:
+        try:
+            for chunk in result:
                 yield chunk

                 self._trigger_new_chunk_callbacks(
@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):

                 if chunk.system_fingerprint:
                     system_fingerprint = chunk.system_fingerprint
-            except Exception as e:
-                raise self._transform_invoke_error(e)
+        except Exception as e:
+            raise self._transform_invoke_error(e)

         self._trigger_after_invoke_callbacks(
             model=model,
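The reordering above is subtle but meaningful: a `try` inside the loop body only guards the body, while exceptions raised by the provider's generator surface at the `for` statement itself when the iterator is advanced. Wrapping the whole loop is what lets `_transform_invoke_error` see them. A toy illustration of the difference (names here are illustrative, not dify's API):

```python
def flaky_stream():
    # stand-in for a provider stream that fails mid-way
    yield "first chunk"
    raise RuntimeError("provider dropped the connection")


def consume():
    try:
        # the RuntimeError surfaces here, when the iterator is advanced,
        # so only a try wrapping the whole loop can catch and transform it
        for chunk in flaky_stream():
            print("got:", chunk)
    except RuntimeError as e:
        raise ValueError(f"transformed invoke error: {e}") from e


try:
    consume()
except ValueError as e:
    print(e)  # transformed invoke error: provider dropped the connection
```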
@@ -252,6 +252,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
         :param messages: List of PromptMessage to combine.
         :return: Combined string with necessary human_prompt and ai_prompt tags.
         """
+        if not messages:
+            return ''
+
         messages = messages.copy()  # don't mutate the original list
         if not isinstance(messages[-1], AssistantPromptMessage):
             messages.append(AssistantPromptMessage(content=""))
@@ -448,6 +448,46 @@ LLM_BASE_MODELS = [
                 currency='USD',
             )
         )
-    )
+    ),
+    AzureBaseModel(
+        base_model_name='text-davinci-003',
+        entity=AIModelEntity(
+            model='fake-deployment-name',
+            label=I18nObject(
+                en_US='fake-deployment-name-label',
+            ),
+            model_type=ModelType.LLM,
+            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
+            model_properties={
+                'mode': LLMMode.COMPLETION.value,
+                'context_size': 4096,
+            },
+            parameter_rules=[
+                ParameterRule(
+                    name='temperature',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
+                ),
+                ParameterRule(
+                    name='top_p',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
+                ),
+                ParameterRule(
+                    name='presence_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
+                ),
+                ParameterRule(
+                    name='frequency_penalty',
+                    **PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
+                ),
+                _get_max_tokens(default=512, min_val=1, max_val=4096),
+            ],
+            pricing=PriceConfig(
+                input=0.02,
+                output=0.02,
+                unit=0.001,
+                currency='USD',
+            )
+        )
+    )
 ]
@@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
 import google.generativeai as genai
 import google.api_core.exceptions as exceptions
 import google.generativeai.client as client
+from google.generativeai.types import HarmCategory, HarmBlockThreshold

 from google.generativeai.types import GenerateContentResponse, ContentType
 from google.generativeai.types.content_types import to_part
@@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
             last_msg = prompt_messages[-1]
             content = self._format_message_to_glm_content(last_msg)
             history.append(content)
-        else:
+        else:
             for msg in prompt_messages:  # makes message roles strictly alternating
                 content = self._format_message_to_glm_content(msg)
                 if history and history[-1]["role"] == content["role"]:
@@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
            new_custom_client = new_client_manager.make_client("generative")

            google_model._client = new_custom_client

+        safety_settings={
+            HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
+            HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
+        }
+
        response = google_model.generate_content(
            contents=history,
            generation_config=genai.types.GenerationConfig(
                **config_kwargs
            ),
-            stream=stream
+            stream=stream,
+            safety_settings=safety_settings
        )

        if stream:
@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
                content=response.text
            )
-

            # calculate num tokens
            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
@@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
        for chunk in response:
            content = chunk.text
            index += 1

            assistant_prompt_message = AssistantPromptMessage(
                content=content if content else '',
            )

            if not response._done:

                # transform assistant message to prompt message
@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
                content=chunk.token.text
            )

-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if chunk.details:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage,
+                        finish_reason=chunk.details.finish_reason,
+                    ),
+                )
+            else:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                    ),
+                )

    def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
        if isinstance(response, str):
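The pattern above mirrors the text-generation-inference streaming protocol: intermediate chunks carry only text, and the final chunk carries `details` with a finish reason, so token usage is computed once at the end rather than on every chunk. A generic sketch of the same gating (the chunk shape is an assumption modeled on the hunk, not the real client types):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class Details:
    finish_reason: str


@dataclass
class Chunk:
    text: str
    details: Optional[Details] = None  # only present on the final chunk


def stream_with_usage(chunks):
    for index, chunk in enumerate(chunks):
        if chunk.details:
            # final chunk: attach usage and finish_reason exactly once
            yield {'index': index, 'text': chunk.text,
                   'usage': {'completion_tokens': index + 1},
                   'finish_reason': chunk.details.finish_reason}
        else:
            yield {'index': index, 'text': chunk.text}


for event in stream_with_usage([Chunk('Hel'), Chunk('lo'), Chunk('!', Details('stop'))]):
    print(event)
```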
@@ -68,7 +68,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):

        for i in _iter:
            # call embedding model
-            embeddings, embedding_used_tokens = self._embedding_invoke(
+            embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                model=model,
                client=client,
                texts=tokens[i: i + max_chunks],
@@ -76,7 +76,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
            )

            used_tokens += embedding_used_tokens
-            batched_embeddings += [data for data in embeddings]
+            batched_embeddings += embeddings_batch

        results: list[list[list[float]]] = [[] for _ in range(len(texts))]
        num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
@@ -87,7 +87,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
        for i in range(len(texts)):
            _result = results[i]
            if len(_result) == 0:
-                embeddings, embedding_used_tokens = self._embedding_invoke(
+                embeddings_batch, embedding_used_tokens = self._embedding_invoke(
                    model=model,
                    client=client,
                    texts=[""],
@@ -95,7 +95,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
                )

                used_tokens += embedding_used_tokens
-                average = embeddings[0]
+                average = embeddings_batch[0]
            else:
                average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
            embeddings[i] = (average / np.linalg.norm(average)).tolist()
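For context, this code embeds long inputs by splitting them into token batches, then recombines each text's chunk embeddings with a token-weighted average and re-normalizes to unit length; the rename to `embeddings_batch` stops the per-call result from shadowing the `embeddings` output list. A condensed sketch of the combine step:

```python
import numpy as np

# Two chunk embeddings for one text, with the token count of each chunk.
chunk_embeddings = np.array([[0.6, 0.8], [1.0, 0.0]])
chunk_token_counts = [30, 10]  # the longer chunk dominates the average

# Token-weighted mean, then L2-normalize back to unit length,
# matching the np.average / np.linalg.norm lines in the hunk.
average = np.average(chunk_embeddings, axis=0, weights=chunk_token_counts)
unit = average / np.linalg.norm(average)

print(unit, np.linalg.norm(unit))  # the norm comes out as 1.0
```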
@@ -1,19 +1,21 @@
 import logging
 from decimal import Decimal
+from urllib.parse import urljoin

 import requests
 import json

 from typing import Optional, Generator, Union, List, cast

-from sympy import comp
-
 from core.model_runtime.entities.common_entities import I18nObject
 from core.model_runtime.utils import helper

-from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
-    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
-from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
+from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
+    AssistantPromptMessage, PromptMessageContent, \
+    PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
+    ToolPromptMessage
+from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
+    DefaultParameterName, \
     ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
 from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
 from core.model_runtime.errors.invoke import InvokeError
@@ -72,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        :return:
        """
        return self._num_tokens_from_messages(model, prompt_messages, tools)

    def validate_credentials(self, model: str, credentials: dict) -> None:
        """
        Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@@ -91,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            headers["Authorization"] = f"Bearer {api_key}"

        endpoint_url = credentials['endpoint_url']
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

        # prepare the payload for a simple ping to the model
        data = {
@@ -107,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                    "content": "ping"
                },
            ]
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
        elif completion_type is LLMMode.COMPLETION:
            data['prompt'] = 'ping'
+            endpoint_url = urljoin(endpoint_url, 'completions')
        else:
            raise ValueError("Unsupported completion type for model configuration.")

        # send a post request to validate the credentials
        response = requests.post(
            endpoint_url,
@@ -121,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            )

            if response.status_code != 200:
-                raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if (completion_type is LLMMode.CHAT
+                    and ('object' not in json_result or json_result['object'] != 'chat.completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
+            elif (completion_type is LLMMode.COMPLETION
+                    and ('object' not in json_result or json_result['object'] != 'text_completion')):
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response object, must be \'text_completion\'')
+        except CredentialsValidateFailedError:
+            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
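The validation path above first normalizes the configured base URL to end with `/`, then builds the final endpoint with `urljoin`. The order matters, because `urljoin` discards the last path segment of a base URL that lacks a trailing slash. A quick standard-library illustration:

```python
from urllib.parse import urljoin

# Without a trailing slash, the last segment is treated as a "file"
# and replaced by the joined path:
print(urljoin('https://api.openai.com/v1', 'chat/completions'))
# -> https://api.openai.com/chat/completions  (the v1 prefix is lost)

# With the trailing slash the hunks above guarantee, the join nests correctly:
print(urljoin('https://api.openai.com/v1/', 'chat/completions'))
# -> https://api.openai.com/v1/chat/completions
```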
@@ -136,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            model_type=ModelType.LLM,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
-                ModelPropertyKey.MODE: 'chat'
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
+                ModelPropertyKey.MODE: credentials.get('mode'),
            },
            parameter_rules=[
                ParameterRule(
@@ -199,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        return entity

-    # validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
-    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
-                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
-                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
+    def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
+                  tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
+                  stream: bool = True, \
+                  user: Optional[str] = None) -> Union[LLMResult, Generator]:
        """
        Invoke llm completion model
@@ -225,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            headers["Authorization"] = f"Bearer {api_key}"

        endpoint_url = credentials["endpoint_url"]
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'

        data = {
            "model": model,
            "stream": stream,
@@ -235,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        completion_type = LLMMode.value_of(credentials['mode'])

        if completion_type is LLMMode.CHAT:
+            endpoint_url = urljoin(endpoint_url, 'chat/completions')
            data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
        elif completion_type == LLMMode.COMPLETION:
+            endpoint_url = urljoin(endpoint_url, 'completions')
            data['prompt'] = prompt_messages[0].content
        else:
            raise ValueError("Unsupported completion type for model configuration.")
@@ -247,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            data["tool_choice"] = "auto"

            for tool in tools:
-                formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
+                formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))

            data["tools"] = formatted_tools

        if stop:
@@ -256,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        if user:
            data["user"] = user

        response = requests.post(
            endpoint_url,
            headers=headers,
@@ -277,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):

        return self._handle_generate_response(model, credentials, response, prompt_messages)

-    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
-                                        prompt_messages: list[PromptMessage]) -> Generator:
+    def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
+                                         prompt_messages: list[PromptMessage]) -> Generator:
        """
        Handle llm stream response
@@ -315,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                if chunk:
                    decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()

+                    chunk_json = None
                    try:
                        chunk_json = json.loads(decoded_chunk)
                    # stream ended
                    except json.JSONDecodeError as e:
                        yield create_final_llm_result_chunk(
                            index=chunk_index + 1,
                            message=AssistantPromptMessage(content=""),
                            finish_reason="Non-JSON encountered."
                        )

-                    if len(chunk_json['choices']) == 0:
+                    if not chunk_json or len(chunk_json['choices']) == 0:
                        continue

-                    delta = chunk_json['choices'][0]['delta']
-                    chunk_index = chunk_json['choices'][0]['index']
+                    choice = chunk_json['choices'][0]
+                    chunk_index = choice['index'] if 'index' in choice else chunk_index

-                    if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
-                        continue
-
-                    assistant_message_tool_calls = delta.get('tool_calls', None)
-                    # assistant_message_function_call = delta.delta.function_call
-
-                    # extract tool calls from response
-                    if assistant_message_tool_calls:
-                        tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
-                        # function_call = self._extract_response_function_call(assistant_message_function_call)
-                        # tool_calls = [function_call] if function_call else []
-
-                    # transform assistant message to prompt message
-                    assistant_prompt_message = AssistantPromptMessage(
-                        content=delta.get('content', ''),
-                        tool_calls=tool_calls if assistant_message_tool_calls else []
-                    )
-
-                    full_assistant_content += delta.get('content', '')
+                    if 'delta' in choice:
+                        delta = choice['delta']
+                        if delta.get('content') is None or delta.get('content') == '':
+                            continue
+
+                        assistant_message_tool_calls = delta.get('tool_calls', None)
+                        # assistant_message_function_call = delta.delta.function_call
+
+                        # extract tool calls from response
+                        if assistant_message_tool_calls:
+                            tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
+                            # function_call = self._extract_response_function_call(assistant_message_function_call)
+                            # tool_calls = [function_call] if function_call else []
+
+                        # transform assistant message to prompt message
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=delta.get('content', ''),
+                            tool_calls=tool_calls if assistant_message_tool_calls else []
+                        )
+
+                        full_assistant_content += delta.get('content', '')
+                    elif 'text' in choice:
+                        if choice.get('text') is None or choice.get('text') == '':
+                            continue
+
+                        # transform assistant message to prompt message
+                        assistant_prompt_message = AssistantPromptMessage(
+                            content=choice.get('text', '')
+                        )
+
+                        full_assistant_content += choice.get('text', '')
+                    else:
+                        continue

                    # check payload indicator for completion
                    if chunk_json['choices'][0].get('finish_reason') is not None:
                        yield create_final_llm_result_chunk(
                            index=chunk_index,
                            message=assistant_prompt_message,
                            finish_reason=chunk_json['choices'][0]['finish_reason']
                        )
                    else:
                        yield LLMResultChunk(
                            model=model,
@@ -375,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                    message=AssistantPromptMessage(content=""),
                    finish_reason="End of stream."
                )

+            chunk_index += 1
+
    def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
                                  prompt_messages: list[PromptMessage]) -> LLMResult:

        response_json = response.json()

        completion_type = LLMMode.value_of(credentials['mode'])
@@ -457,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            message = cast(AssistantPromptMessage, message)
            message_dict = {"role": "assistant", "content": message.content}
            if message.tool_calls:
-                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
-                                              message.tool_calls]
+                message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
+                                              in
+                                              message.tool_calls]
                # function_call = message.tool_calls[0]
                # message_dict["function_call"] = {
@@ -486,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
            message_dict["name"] = message.name

        return message_dict

    def _num_tokens_from_string(self, model: str, text: str,
                                tools: Optional[list[PromptMessageTool]] = None) -> int:
        """
@@ -509,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
        """
        Approximate num tokens with GPT2 tokenizer.
        """

        tokens_per_message = 3
        tokens_per_name = 1

        num_tokens = 0
        messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
        for message in messages_dict:
@@ -601,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
                    num_tokens += self._get_num_tokens_by_gpt2(required_field)

        return num_tokens

    def _extract_response_tool_calls(self,
                                     response_tool_calls: list[dict]) \
            -> list[AssistantPromptMessage.ToolCall]:
@@ -33,8 +33,8 @@ model_credential_schema:
      type: text-input
      required: true
      placeholder:
-        zh_Hans: 在此输入您的 API endpoint URL
-        en_US: Enter your API endpoint URL
+        zh_Hans: Base URL, eg. https://api.openai.com/v1
+        en_US: Base URL, eg. https://api.openai.com/v1
  - variable: mode
    show_on:
      - variable: __model_type
@@ -1,6 +1,7 @@
 import time
 from decimal import Decimal
 from typing import Optional
+from urllib.parse import urljoin
 import requests
 import json
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
        endpoint_url = urljoin(endpoint_url, 'embeddings')

        extra_model_kwargs = {}
        if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
        if api_key:
            headers["Authorization"] = f"Bearer {api_key}"

-        endpoint_url = credentials['endpoint_url']
+        endpoint_url = credentials.get('endpoint_url')
+        if not endpoint_url.endswith('/'):
+            endpoint_url += '/'
+
        endpoint_url = urljoin(endpoint_url, 'embeddings')

        payload = {
            'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
            )

            if response.status_code != 200:
-                raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed with status code {response.status_code}')
+
+            try:
+                json_result = response.json()
+            except json.JSONDecodeError as e:
+                raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
+
+            if 'model' not in json_result:
+                raise CredentialsValidateFailedError(
+                    f'Credentials validation failed: invalid response')
+        except CredentialsValidateFailedError:
+            raise
        except Exception as ex:
            raise CredentialsValidateFailedError(str(ex))
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
            model_type=ModelType.TEXT_EMBEDDING,
            fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
            model_properties={
-                ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
+                ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
                ModelPropertyKey.MAX_CHUNKS: 1,
            },
            parameter_rules=[],
@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
        )

        for key, value in input_properties:
-            if key not in ['system_prompt', 'prompt']:
+            if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
                value_type = value.get('type')

                if not value_type:
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
        index = -1
        current_completion: str = ""
        stop_condition_reached = False
+
+        prediction_output_length = 10000
+        is_prediction_output_finished = False
+
        for output in prediction.output_iterator():
            current_completion += output

+            if not is_prediction_output_finished and prediction.status == 'succeeded':
+                prediction_output_length = len(prediction.output) - 1
+                is_prediction_output_finished = True
+
            if stop:
                for s in stop:
                    if s in current_completion:
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
                content=output if output else ''
            )

-            prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
-            completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
-
-            usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
-
-            yield LLMResultChunk(
-                model=model,
-                prompt_messages=prompt_messages,
-                delta=LLMResultChunkDelta(
-                    index=index,
-                    message=assistant_prompt_message,
-                    usage=usage,
-                ),
-            )
+            if index < prediction_output_length:
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message
+                    )
+                )
+            else:
+                prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
+                completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
+
+                usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
+
+                yield LLMResultChunk(
+                    model=model,
+                    prompt_messages=prompt_messages,
+                    delta=LLMResultChunkDelta(
+                        index=index,
+                        message=assistant_prompt_message,
+                        usage=usage
+                    )
+                )

    def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
                                  prompt_messages: list[PromptMessage]) -> LLMResult:
@@ -33,10 +33,13 @@ class XinferenceHelper:

    @staticmethod
    def _clean_cache() -> None:
-        with cache_lock:
-            for model_uid, model in cache.items():
-                if model['expires'] < time():
-                    del cache[model_uid]
+        try:
+            with cache_lock:
+                expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
+                for model_uid in expired_keys:
+                    del cache[model_uid]
+        except RuntimeError as e:
+            pass

    @staticmethod
    def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
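The cache fix above works around Python's rule that a dict may not change size while it is being iterated: snapshotting the expired keys first, then deleting, avoids the `RuntimeError`. A standalone sketch of the failure mode and the fix:

```python
import time

cache = {'model-a': {'expires': time.time() - 1},
         'model-b': {'expires': time.time() + 60}}

# Deleting inside the iteration raises:
# RuntimeError: dictionary changed size during iteration
try:
    for uid, entry in cache.items():
        if entry['expires'] < time.time():
            del cache[uid]
except RuntimeError as e:
    print('broken:', e)

# Snapshot the expired keys first, then delete - the pattern the hunk adopts.
# (Here 'model-a' was already removed before the error, so the list is empty.)
expired = [uid for uid, entry in cache.items() if entry['expires'] < time.time()]
for uid in expired:
    del cache[uid]
print('remaining:', list(cache))
```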
@@ -31,7 +31,7 @@ model_credential_schema:
      label:
        zh_Hans: 服务器URL
        en_US: Server url
-      type: text-input
+      type: secret-input
      required: true
      placeholder:
        zh_Hans: 在此输入Xinference的服务器地址,如 https://example.com/xxx
@@ -117,7 +117,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):

        params = {
            'model': model,
-            'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
+            'prompt': [{
+                'role': prompt_message.role.value if prompt_message.role.value != 'system' else 'user',
+                'content': prompt_message.content
+            } for prompt_message in prompt_messages],
            **model_parameters
        }
@@ -24,7 +24,7 @@ provider_credential_schema:
  - variable: api_key
    label:
      en_US: APIKey
-    type: text-input
+    type: secret-input
    required: true
    placeholder:
      zh_Hans: 在此输入您的 APIKey
@@ -3,6 +3,8 @@ from collections import defaultdict
 from json import JSONDecodeError
 from typing import Optional

+from sqlalchemy.exc import IntegrityError
+
 from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
 from core.entities.provider_configuration import ProviderConfigurations, ProviderConfiguration, ProviderModelBundle
 from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
@@ -380,17 +382,28 @@ class ProviderManager:
                if quota.quota_type == ProviderQuotaType.TRIAL:
                    # Init trial provider records if not exists
                    if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
-                        provider_record = Provider(
-                            tenant_id=tenant_id,
-                            provider_name=provider_name,
-                            provider_type=ProviderType.SYSTEM.value,
-                            quota_type=ProviderQuotaType.TRIAL.value,
-                            quota_limit=quota.quota_limit,
-                            quota_used=0,
-                            is_valid=True
-                        )
-                        db.session.add(provider_record)
-                        db.session.commit()
+                        try:
+                            provider_record = Provider(
+                                tenant_id=tenant_id,
+                                provider_name=provider_name,
+                                provider_type=ProviderType.SYSTEM.value,
+                                quota_type=ProviderQuotaType.TRIAL.value,
+                                quota_limit=quota.quota_limit,
+                                quota_used=0,
+                                is_valid=True
+                            )
+                            db.session.add(provider_record)
+                            db.session.commit()
+                        except IntegrityError:
+                            db.session.rollback()
+                            provider_record = db.session.query(Provider) \
+                                .filter(
+                                Provider.tenant_id == tenant_id,
+                                Provider.provider_name == provider_name,
+                                Provider.provider_type == ProviderType.SYSTEM.value,
+                                Provider.quota_type == ProviderQuotaType.TRIAL.value,
+                                Provider.is_valid == True
+                            ).first()

                        provider_name_to_provider_records_dict[provider_name].append(provider_record)
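This is the classic insert-or-fetch pattern for racing writers: attempt the INSERT, and on a unique-constraint violation roll back and SELECT the row the other writer created. A self-contained sketch with SQLAlchemy and an in-memory SQLite database (the model and constraint here are illustrative stand-ins for dify's `Provider` table):

```python
from sqlalchemy import Column, Integer, String, UniqueConstraint, create_engine
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import declarative_base, sessionmaker

Base = declarative_base()


class Provider(Base):
    __tablename__ = 'providers'
    id = Column(Integer, primary_key=True)
    tenant_id = Column(String, nullable=False)
    provider_name = Column(String, nullable=False)
    # the unique constraint is what turns a duplicate insert into IntegrityError
    __table_args__ = (UniqueConstraint('tenant_id', 'provider_name'),)


engine = create_engine('sqlite://')
Base.metadata.create_all(engine)
session = sessionmaker(bind=engine)()


def get_or_create(tenant_id: str, provider_name: str) -> Provider:
    try:
        record = Provider(tenant_id=tenant_id, provider_name=provider_name)
        session.add(record)
        session.commit()  # raises IntegrityError if another writer got there first
        return record
    except IntegrityError:
        session.rollback()  # discard the losing insert, then read the winner's row
        return session.query(Provider).filter_by(
            tenant_id=tenant_id, provider_name=provider_name).first()


print(get_or_create('t1', 'openai').id)  # inserts
print(get_or_create('t1', 'openai').id)  # same row, via the except path
```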
@@ -433,17 +446,27 @@ class ProviderManager:
        custom_provider_configuration = None
        if custom_provider_record:
            try:
-                provider_credentials = json.loads(custom_provider_record.encrypted_config)
+                # fix origin data
+                if (custom_provider_record.encrypted_config
+                        and not custom_provider_record.encrypted_config.startswith("{")):
+                    provider_credentials = {
+                        "openai_api_key": custom_provider_record.encrypted_config
+                    }
+                else:
+                    provider_credentials = json.loads(custom_provider_record.encrypted_config)
            except JSONDecodeError:
                provider_credentials = {}

            for variable in provider_credential_secret_variables:
                if variable in provider_credentials:
-                    provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                        provider_credentials.get(variable),
-                        decoding_rsa_key,
-                        decoding_cipher_rsa
-                    )
+                    try:
+                        provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                            provider_credentials.get(variable),
+                            decoding_rsa_key,
+                            decoding_cipher_rsa
+                        )
+                    except ValueError:
+                        pass

            custom_provider_configuration = CustomProviderConfiguration(
                credentials=provider_credentials
@@ -468,11 +491,14 @@ class ProviderManager:

            for variable in model_credential_secret_variables:
                if variable in provider_model_credentials:
-                    provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                        provider_model_credentials.get(variable),
-                        decoding_rsa_key,
-                        decoding_cipher_rsa
-                    )
+                    try:
+                        provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                            provider_model_credentials.get(variable),
+                            decoding_rsa_key,
+                            decoding_cipher_rsa
+                        )
+                    except ValueError:
+                        pass

            custom_model_configurations.append(
                CustomModelConfiguration(
@@ -564,11 +590,14 @@ class ProviderManager:

            for variable in provider_credential_secret_variables:
                if variable in provider_credentials:
-                    provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
-                        provider_credentials.get(variable),
-                        decoding_rsa_key,
-                        decoding_cipher_rsa
-                    )
+                    try:
+                        provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
+                            provider_credentials.get(variable),
+                            decoding_rsa_key,
+                            decoding_cipher_rsa
+                        )
+                    except ValueError:
+                        pass

            current_using_credentials = provider_credentials
        else:
@@ -7,10 +7,38 @@ from typing import (
     Optional,
 )

-from langchain.text_splitter import RecursiveCharacterTextSplitter
+from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
+
+from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer

-class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+
+class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
+    """
+    This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
+    """
+
+    @classmethod
+    def from_gpt2_encoder(
+            cls: Type[TS],
+            encoding_name: str = "gpt2",
+            model_name: Optional[str] = None,
+            allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
+            disallowed_special: Union[Literal["all"], Collection[str]] = "all",
+            **kwargs: Any,
+    ):
+        def _token_encoder(text: str) -> int:
+            return GPT2Tokenizer.get_num_tokens(text)
+
+        if issubclass(cls, TokenTextSplitter):
+            extra_kwargs = {
+                "encoding_name": encoding_name,
+                "model_name": model_name,
+                "allowed_special": allowed_special,
+                "disallowed_special": disallowed_special,
+            }
+            kwargs = {**kwargs, **extra_kwargs}
+
+        return cls(length_function=_token_encoder, **kwargs)
+
+
+class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
     def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
         """Create a new TextSplitter."""
         super().__init__(**kwargs)
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
         if _good_splits:
             merged_text = self._merge_splits(_good_splits, separator)
             final_chunks.extend(merged_text)
-        return final_chunks
+        return final_chunks
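The new `EnhanceRecursiveCharacterTextSplitter` swaps tiktoken for a local GPT-2 tokenizer simply by passing a token-counting `length_function` into the splitter, so `chunk_size` is measured in tokens rather than characters. A hedged sketch of the same idea using Hugging Face `transformers` as a stand-in for dify's own `GPT2Tokenizer` wrapper (the `gpt2` files are fetched from the hub on first use):

```python
from langchain.text_splitter import RecursiveCharacterTextSplitter
from transformers import GPT2TokenizerFast

# Measure length in GPT-2 tokens instead of characters, so chunk_size
# approximates the model's token budget without depending on tiktoken.
_tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")


def gpt2_token_len(text: str) -> int:
    return len(_tokenizer.encode(text))


splitter = RecursiveCharacterTextSplitter(
    chunk_size=500,   # interpreted in GPT-2 tokens via length_function
    chunk_overlap=0,
    separators=["\n\n", "。", ".", " ", ""],  # same list as the hunk above
    length_function=gpt2_token_len,
)

chunks = splitter.split_text("Your long document text goes here..." * 100)
print(len(chunks), "chunks")
```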
@@ -2,7 +2,7 @@ version: '3.1'
 services:
   # API service
   api:
-    image: langgenius/dify-api:0.4.0
+    image: langgenius/dify-api:0.4.2
     restart: always
     environment:
       # Startup mode, 'api' starts the API server.
@@ -128,7 +128,7 @@ services:
   # worker service
   # The Celery worker for processing the queue.
   worker:
-    image: langgenius/dify-api:0.4.0
+    image: langgenius/dify-api:0.4.2
     restart: always
     environment:
       # Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -196,7 +196,7 @@ services:

   # Frontend web application.
   web:
-    image: langgenius/dify-web:0.4.0
+    image: langgenius/dify-web:0.4.2
     restart: always
     environment:
       EDITION: SELF_HOSTED
@@ -10,7 +10,7 @@ First, install the dependencies:
 ```bash
 npm install
 # or
-yarn
+yarn install --frozen-lockfile
 ```

 Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements:
@@ -287,7 +287,6 @@ const Metadata: FC<IMetadataProps> = ({ docDetail, loading, onUpdate }) => {
   }

   const onSave = async () => {
-    console.log('metadataParams:', metadataParams)
     setSaveLoading(true)
     const [e] = await asyncRunSafe<CommonResponse>(modifyDocMetadata({
       datasetId,
@@ -30,7 +30,6 @@ const WorkplaceSelector = () => {
   const currentWorkspace = workspaces.find(v => v.current)

   const handleSwitchWorkspace = async (tenant_id: string) => {
-    console.log(tenant_id, currentWorkspace?.id)
     try {
       if (currentWorkspace?.id === tenant_id)
         return
@@ -248,7 +248,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
         ...(isAdvancedMode ? [stopParameerRule] : []),
       ].map(parameter => (
         <ParameterItem
-          key={parameter.name}
+          key={`${modelId}-${parameter.name}`}
           className='mb-4'
           parameterRule={parameter}
           value={completionParams[parameter.name]}
@@ -63,8 +63,13 @@ const ParameterItem: FC<ParameterItemProps> = ({

   const handleChange = (v: ParameterValue) => {
     setLocalValue(v)
-    if (!isNullOrUndefined(value) && onChange)
-      onChange(v)
+
+    if (onChange) {
+      if (parameterRule.name === 'stop')
+        onChange(v)
+      else if (!isNullOrUndefined(value))
+        onChange(v)
+    }
   }

   const handleNumberInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {
@@ -1,6 +1,6 @@
 {
   "name": "dify-web",
-  "version": "0.4.0",
+  "version": "0.4.2",
   "private": true,
   "scripts": {
     "dev": "next dev",
@@ -136,7 +136,6 @@ const handleStream = (response: Response, onData: IOnData, onCompleted?: IOnComp
               onThought?.(bufferObj as ThoughtItem)
             }
             else if (bufferObj.event === 'message_end') {
-              console.log(bufferObj)
               onMessageEnd?.(bufferObj as MessageEnd)
             }
             else if (bufferObj.event === 'message_replace') {