Compare commits


26 Commits
0.4.0 ... 0.4.2

Author SHA1 Message Date
takatost  91ff07fcf7  bump version to 0.4.2 (#1898)  2024-01-04 01:35:07 +08:00
takatost  bb7af56e69  fix: zhipuai history format wrong (#1897)  2024-01-04 01:30:23 +08:00
Chenhe Gu  77f9e8ce0f  add example api url endpoint in placeholder (#1887)  2024-01-04 01:16:51 +08:00
    Co-authored-by: takatost <takatost@gmail.com>
Jyong  5ca4c4a44d  add qdrant client timeout limit (#1894)  2024-01-03 22:23:04 +08:00
    Co-authored-by: jyong <jyong@dify.ai>
Boris Polonsky  a44022c388  Grammar fix (#1892)  2024-01-03 22:13:12 +08:00
takatost  6333cf43a8  fix: anthropic messages empty raise errors (#1893)  2024-01-03 22:12:14 +08:00
Garfield Dai  91ee62d1ab  fix: huggingface and replicate. (#1888)  2024-01-03 18:29:44 +08:00
takatost  ede69b4659  fix: gemini block error (#1877)  2024-01-03 17:45:15 +08:00
    Co-authored-by: chenhe <guchenhe@gmail.com>
waltcow  61aaeff413  Fix variable name in AgentApplicationRunner (#1884)  2024-01-03 17:44:41 +08:00
zxhlyh  4e1cd75f6f  fix: model parameter stop sequence (#1885)  2024-01-03 17:15:29 +08:00
zxhlyh  a8ff2e95da  fix: model parameter modal initial value (#1883)  2024-01-03 17:10:37 +08:00
crazywoola  4d502ea44d  fix: openai embedding list out of bound (#1879)  2024-01-03 15:30:22 +08:00
Bowen Liang  66b3588897  doc: Respect and prevent updating existed yarn lockfile when installing dependencies (#1871)  2024-01-03 15:27:19 +08:00
Yeuoly  9134849744  fix: remove tiktoken from text splitter (#1876)  2024-01-03 13:02:56 +08:00
Garfield Dai  fcf8512956  fix: more like this. (#1875)  2024-01-03 12:51:19 +08:00
takatost  ae975b10e9  fix: openai origin credential not start with { (#1874)  2024-01-03 12:10:43 +08:00
Yeuoly  b43f1441a9  Fix/model runtime (#1873)  2024-01-03 11:36:57 +08:00
takatost  5a2aa83030  fix: ciphertext error (#1872)  2024-01-03 11:20:46 +08:00
takatost  4de27d0404  bump version to 0.4.1 (#1870)  2024-01-03 10:01:37 +08:00
Yeuoly  c6d59681ff  fix: xinference secret server_url (#1869)  2024-01-03 10:01:11 +08:00
takatost  3b668c0bb1  fix: IntegrityError import wrong (#1868)  2024-01-03 09:35:03 +08:00
takatost  4aed1fe8a8  fix: Azure text-davinci-003 missing (#1867)  2024-01-03 09:27:09 +08:00
takatost  2381264a3f  fix: provider create cause IntegrityError (#1866)  2024-01-03 09:12:53 +08:00
takatost  4562e83b24  fix: hit testing throws errors cause internal server error (#1865)  2024-01-03 08:57:39 +08:00
takatost  7be77c19f5  fix: default model parameter precision (#1864)  2024-01-03 08:52:22 +08:00
takatost  82247c0f14  fix: agent strategy missing in app model config (#1863)  2024-01-03 08:43:51 +08:00
33 changed files with 379 additions and 163 deletions

View File

@@ -91,7 +91,7 @@ After running, you can access the Dify dashboard in your browser at [http://loca
### Helm Chart
A big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
Big thanks to @BorisPolonsky for providing us with a [Helm Chart](https://helm.sh/) version, which allows Dify to be deployed on Kubernetes.
You can go to https://github.com/BorisPolonsky/dify-helm for deployment information.
### Configuration

View File

@@ -87,7 +87,7 @@ class Config:
# ------------------------
# General Configurations.
# ------------------------
self.CURRENT_VERSION = "0.4.0"
self.CURRENT_VERSION = "0.4.2"
self.COMMIT_SHA = get_env('COMMIT_SHA')
self.EDITION = "SELF_HOSTED"
self.DEPLOY_ENV = get_env('DEPLOY_ENV')
@@ -197,6 +197,7 @@ class Config:
# qdrant settings
self.QDRANT_URL = get_env('QDRANT_URL')
self.QDRANT_API_KEY = get_env('QDRANT_API_KEY')
self.QDRANT_CLIENT_TIMEOUT = get_env('QDRANT_CLIENT_TIMEOUT')
# milvus / zilliz setting
self.MILVUS_HOST = get_env('MILVUS_HOST')

View File

@@ -1,6 +1,8 @@
import logging
from flask_login import current_user
from core.model_runtime.errors.invoke import InvokeError
from libs.login import login_required
from flask_restful import Resource, reqparse, marshal
from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
@@ -8,7 +10,7 @@ from werkzeug.exceptions import InternalServerError, NotFound, Forbidden
import services
from controllers.console import api
from controllers.console.app.error import ProviderNotInitializeError, ProviderQuotaExceededError, \
ProviderModelCurrentlyNotSupportError
ProviderModelCurrentlyNotSupportError, CompletionRequestError
from controllers.console.datasets.error import HighQualityDatasetOnlyError, DatasetNotInitializedError
from controllers.console.setup import setup_required
from controllers.console.wraps import account_initialization_required
@@ -69,6 +71,8 @@ class HitTestingApi(Resource):
raise ProviderNotInitializeError(
f"No Embedding Model or Reranking Model available. Please configure a valid provider "
f"in the Settings -> Model Provider.")
except InvokeError as e:
raise CompletionRequestError(e.description)
except ValueError as e:
raise ValueError(str(e))
except Exception as e:

View File

@@ -237,8 +237,8 @@ class AgentApplicationRunner(AppRunner):
all_message_tokens = 0
all_answer_tokens = 0
for agent_thought in agent_thoughts:
all_message_tokens += agent_thought.message_tokens
all_answer_tokens += agent_thought.answer_tokens
all_message_tokens += agent_thought.message_token
all_answer_tokens += agent_thought.answer_token
model_type_instance = model_config.provider_model_bundle.model_type_instance
model_type_instance = cast(LargeLanguageModel, model_type_instance)

View File

@@ -376,7 +376,8 @@ class ApplicationManager:
and 'enabled' in copy_app_model_config_dict['agent_mode'] and copy_app_model_config_dict['agent_mode'][
'enabled']:
agent_dict = copy_app_model_config_dict.get('agent_mode')
if agent_dict['strategy'] in ['router', 'react_router']:
agent_strategy = agent_dict.get('strategy', 'router')
if agent_strategy in ['router', 'react_router']:
dataset_ids = []
for tool in agent_dict.get('tools', []):
key = list(tool.keys())[0]
@@ -402,7 +403,7 @@ class ApplicationManager:
retrieve_strategy=DatasetRetrieveConfigEntity.RetrieveStrategy.value_of(
dataset_configs['retrieval_model']
),
single_strategy=agent_dict['strategy']
single_strategy=agent_strategy
)
)
else:
@@ -419,7 +420,7 @@ class ApplicationManager:
)
)
else:
if agent_dict['strategy'] == 'react':
if agent_strategy == 'react':
strategy = AgentEntity.Strategy.CHAIN_OF_THOUGHT
else:
strategy = AgentEntity.Strategy.FUNCTION_CALLING
@@ -472,7 +473,7 @@ class ApplicationManager:
more_like_this_dict = copy_app_model_config_dict.get('more_like_this')
if more_like_this_dict:
if 'enabled' in more_like_this_dict and more_like_this_dict['enabled']:
properties['more_like_this'] = copy_app_model_config_dict.get('opening_statement')
properties['more_like_this'] = True
# speech to text
speech_to_text_dict = copy_app_model_config_dict.get('speech_to_text')

View File

@@ -18,6 +18,7 @@ from models.dataset import Dataset, DatasetCollectionBinding
class QdrantConfig(BaseModel):
endpoint: str
api_key: Optional[str]
timeout: float = 20
root_path: Optional[str]
def to_qdrant_params(self):
@@ -33,6 +34,7 @@ class QdrantConfig(BaseModel):
return {
'url': self.endpoint,
'api_key': self.api_key,
'timeout': self.timeout
}

View File

@@ -49,7 +49,8 @@ class VectorIndex:
config=QdrantConfig(
endpoint=config.get('QDRANT_URL'),
api_key=config.get('QDRANT_API_KEY'),
root_path=current_app.root_path
root_path=current_app.root_path,
timeout=config.get('QDRANT_CLIENT_TIMEOUT')
),
embeddings=embeddings
)
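For reference, a minimal standalone sketch of how the new timeout setting flows from the environment into the Qdrant client (the variable names come from the diffs above; the 20-second fallback mirrors the QdrantConfig default, and the client construction here is illustrative rather than Dify's actual wiring):

```python
import os

from qdrant_client import QdrantClient

# QDRANT_CLIENT_TIMEOUT arrives as a string from the environment; QdrantConfig's
# `timeout: float = 20` field coerces it, and to_qdrant_params() passes it on
# as the client's request timeout in seconds.
client = QdrantClient(
    url=os.environ["QDRANT_URL"],
    api_key=os.environ.get("QDRANT_API_KEY"),
    timeout=float(os.environ.get("QDRANT_CLIENT_TIMEOUT", 20)),
)
```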

View File

@@ -5,12 +5,12 @@ import re
import threading
import time
import uuid
from typing import Optional, List, cast
from typing import Optional, List, cast, Type, Union, Literal, AbstractSet, Collection, Any
from flask import current_app, Flask
from flask_login import current_user
from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter, TextSplitter
from langchain.text_splitter import TextSplitter, TS, TokenTextSplitter
from sqlalchemy.orm.exc import ObjectDeletedError
from core.data_loader.file_extractor import FileExtractor
@@ -23,7 +23,8 @@ from core.errors.error import ProviderTokenNotInitError
from core.model_runtime.entities.model_entities import ModelType, PriceType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.model_runtime.model_providers.__base.text_embedding_model import TextEmbeddingModel
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
from core.spiltter.fixed_text_splitter import FixedRecursiveCharacterTextSplitter, EnhanceRecursiveCharacterTextSplitter
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
@@ -502,7 +503,8 @@ class IndexingRunner:
if separator:
separator = separator.replace('\\n', '\n')
character_splitter = FixedRecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = FixedRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=segmentation["max_tokens"],
chunk_overlap=0,
fixed_separator=separator,
@@ -510,7 +512,7 @@ class IndexingRunner:
)
else:
# Automatic segmentation
character_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder(
character_splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
chunk_size=DatasetProcessRule.AUTOMATIC_RULES['segmentation']['max_tokens'],
chunk_overlap=0,
separators=["\n\n", "。", ".", " ", ""]

View File

@@ -18,7 +18,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 1,
'precision': 2,
},
DefaultParameterName.TOP_P: {
'label': {
@@ -34,7 +34,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
'default': 1.0,
'min': 0.0,
'max': 1.0,
'precision': 1,
'precision': 2,
},
DefaultParameterName.PRESENCE_PENALTY: {
'label': {
@@ -50,7 +50,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 1,
'precision': 2,
},
DefaultParameterName.FREQUENCY_PENALTY: {
'label': {
@@ -66,7 +66,7 @@ PARAMETER_RULE_TEMPLATE: Dict[DefaultParameterName, dict] = {
'default': 0.0,
'min': 0.0,
'max': 1.0,
'precision': 1,
'precision': 2,
},
DefaultParameterName.MAX_TOKENS: {
'label': {

View File

@@ -132,8 +132,8 @@ class LargeLanguageModel(AIModel):
system_fingerprint = None
real_model = model
for chunk in result:
try:
try:
for chunk in result:
yield chunk
self._trigger_new_chunk_callbacks(
@@ -156,8 +156,8 @@ class LargeLanguageModel(AIModel):
if chunk.system_fingerprint:
system_fingerprint = chunk.system_fingerprint
except Exception as e:
raise self._transform_invoke_error(e)
except Exception as e:
raise self._transform_invoke_error(e)
self._trigger_after_invoke_callbacks(
model=model,

View File

@@ -252,6 +252,9 @@ class AnthropicLargeLanguageModel(LargeLanguageModel):
:param messages: List of PromptMessage to combine.
:return: Combined string with necessary human_prompt and ai_prompt tags.
"""
if not messages:
return ''
messages = messages.copy() # don't mutate the original list
if not isinstance(messages[-1], AssistantPromptMessage):
messages.append(AssistantPromptMessage(content=""))

View File

@@ -448,6 +448,46 @@ LLM_BASE_MODELS = [
currency='USD',
)
)
),
AzureBaseModel(
base_model_name='text-davinci-003',
entity=AIModelEntity(
model='fake-deployment-name',
label=I18nObject(
en_US='fake-deployment-name-label',
),
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
'mode': LLMMode.COMPLETION.value,
'context_size': 4096,
},
parameter_rules=[
ParameterRule(
name='temperature',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TEMPERATURE],
),
ParameterRule(
name='top_p',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.TOP_P],
),
ParameterRule(
name='presence_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.PRESENCE_PENALTY],
),
ParameterRule(
name='frequency_penalty',
**PARAMETER_RULE_TEMPLATE[DefaultParameterName.FREQUENCY_PENALTY],
),
_get_max_tokens(default=512, min_val=1, max_val=4096),
],
pricing=PriceConfig(
input=0.02,
output=0.02,
unit=0.001,
currency='USD',
)
)
)
]

View File

@@ -3,6 +3,7 @@ from typing import Optional, Generator, Union, List
import google.generativeai as genai
import google.api_core.exceptions as exceptions
import google.generativeai.client as client
from google.generativeai.types import HarmCategory, HarmBlockThreshold
from google.generativeai.types import GenerateContentResponse, ContentType
from google.generativeai.types.content_types import to_part
@@ -124,7 +125,7 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
last_msg = prompt_messages[-1]
content = self._format_message_to_glm_content(last_msg)
history.append(content)
else:
else:
for msg in prompt_messages: # makes message roles strictly alternating
content = self._format_message_to_glm_content(msg)
if history and history[-1]["role"] == content["role"]:
@@ -139,13 +140,21 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
new_custom_client = new_client_manager.make_client("generative")
google_model._client = new_custom_client
safety_settings={
HarmCategory.HARM_CATEGORY_HARASSMENT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_HATE_SPEECH: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_SEXUALLY_EXPLICIT: HarmBlockThreshold.BLOCK_NONE,
HarmCategory.HARM_CATEGORY_DANGEROUS_CONTENT: HarmBlockThreshold.BLOCK_NONE,
}
response = google_model.generate_content(
contents=history,
generation_config=genai.types.GenerationConfig(
**config_kwargs
),
stream=stream
stream=stream,
safety_settings=safety_settings
)
if stream:
@@ -169,7 +178,6 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
content=response.text
)
# calculate num tokens
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
@@ -202,11 +210,11 @@ class GoogleLargeLanguageModel(LargeLanguageModel):
for chunk in response:
content = chunk.text
index += 1
assistant_prompt_message = AssistantPromptMessage(
content=content if content else '',
)
if not response._done:
# transform assistant message to prompt message

View File

@@ -154,20 +154,31 @@ class HuggingfaceHubLargeLanguageModel(_CommonHuggingfaceHub, LargeLanguageModel
content=chunk.token.text
)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if chunk.details:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
),
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
finish_reason=chunk.details.finish_reason,
),
)
else:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
),
)
def _handle_generate_response(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], response: any) -> LLMResult:
if isinstance(response, str):

View File

@@ -68,7 +68,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
for i in _iter:
# call embedding model
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=tokens[i: i + max_chunks],
@@ -76,7 +76,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
batched_embeddings += [data for data in embeddings]
batched_embeddings += embeddings_batch
results: list[list[list[float]]] = [[] for _ in range(len(texts))]
num_tokens_in_batch: list[list[int]] = [[] for _ in range(len(texts))]
@@ -87,7 +87,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
for i in range(len(texts)):
_result = results[i]
if len(_result) == 0:
embeddings, embedding_used_tokens = self._embedding_invoke(
embeddings_batch, embedding_used_tokens = self._embedding_invoke(
model=model,
client=client,
texts=[""],
@@ -95,7 +95,7 @@ class OpenAITextEmbeddingModel(_CommonOpenAI, TextEmbeddingModel):
)
used_tokens += embedding_used_tokens
average = embeddings[0]
average = embeddings_batch[0]
else:
average = np.average(_result, axis=0, weights=num_tokens_in_batch[i])
embeddings[i] = (average / np.linalg.norm(average)).tolist()
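The renames above matter because the per-batch return value previously reused the name of the outer embeddings list, so the later indexing step could run off the end of a short batch result. A rough, hypothetical illustration of that shadowing bug (the `embed` helper, `texts`, and `batches` are stand-ins, not Dify code):

```python
def embed(batch):                      # hypothetical stand-in for the embedding call
    return [[0.0, 0.0, 0.0] for _ in batch]

texts = ["first", "second", "third"]
batches = [texts[i:i + 2] for i in range(0, len(texts), 2)]

embeddings = [None] * len(texts)       # final per-text results, indexed later
for batch in batches:
    embeddings = embed(batch)          # bug: rebinds and clobbers the outer list

embeddings[2]                          # IndexError: list index out of range,
                                       # the "list out of bound" symptom fixed above
```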

View File

@@ -1,19 +1,21 @@
import logging
from decimal import Decimal
from urllib.parse import urljoin
import requests
import json
from typing import Optional, Generator, Union, List, cast
from sympy import comp
from core.model_runtime.entities.common_entities import I18nObject
from core.model_runtime.utils import helper
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, DefaultParameterName, \
from core.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessage, \
AssistantPromptMessage, PromptMessageContent, \
PromptMessageContentType, PromptMessageFunction, PromptMessageTool, UserPromptMessage, SystemPromptMessage, \
ToolPromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelType, PriceConfig, ParameterRule, \
DefaultParameterName, \
ParameterType, ModelPropertyKey, FetchFrom, AIModelEntity
from core.model_runtime.entities.llm_entities import LLMMode, LLMResult, LLMResultChunk, LLMResultChunkDelta
from core.model_runtime.errors.invoke import InvokeError
@@ -72,7 +74,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
:return:
"""
return self._num_tokens_from_messages(model, prompt_messages, tools)
def validate_credentials(self, model: str, credentials: dict) -> None:
"""
Validate model credentials using requests to ensure compatibility with all providers following OpenAI's API standard.
@@ -91,6 +93,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials['endpoint_url']
if not endpoint_url.endswith('/'):
endpoint_url += '/'
# prepare the payload for a simple ping to the model
data = {
@@ -107,11 +111,13 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"content": "ping"
},
]
endpoint_url = urljoin(endpoint_url, 'chat/completions')
elif completion_type is LLMMode.COMPLETION:
data['prompt'] = 'ping'
endpoint_url = urljoin(endpoint_url, 'completions')
else:
raise ValueError("Unsupported completion type for model configuration.")
# send a post request to validate the credentials
response = requests.post(
endpoint_url,
@@ -121,8 +127,24 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f'Credentials validation failed with status code {response.status_code}: {response.text}')
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if (completion_type is LLMMode.CHAT
and ('object' not in json_result or json_result['object'] != 'chat.completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'chat.completion\'')
elif (completion_type is LLMMode.COMPLETION
and ('object' not in json_result or json_result['object'] != 'text_completion')):
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response object, must be \'text_completion\'')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(f'An error occurred during credentials validation: {str(ex)}')
@@ -136,8 +158,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
model_type=ModelType.LLM,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.MODE: 'chat'
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MODE: credentials.get('mode'),
},
parameter_rules=[
ParameterRule(
@@ -199,11 +221,11 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return entity
# validate_credentials method has been rewritten to use the requests library for compatibility with all providers following OpenAI's API standard.
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None, stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
def _generate(self, model: str, credentials: dict, prompt_messages: list[PromptMessage], model_parameters: dict,
tools: Optional[list[PromptMessageTool]] = None, stop: Optional[List[str]] = None,
stream: bool = True, \
user: Optional[str] = None) -> Union[LLMResult, Generator]:
"""
Invoke llm completion model
@@ -225,7 +247,9 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials["endpoint_url"]
if not endpoint_url.endswith('/'):
endpoint_url += '/'
data = {
"model": model,
"stream": stream,
@@ -235,8 +259,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
completion_type = LLMMode.value_of(credentials['mode'])
if completion_type is LLMMode.CHAT:
endpoint_url = urljoin(endpoint_url, 'chat/completions')
data['messages'] = [self._convert_prompt_message_to_dict(m) for m in prompt_messages]
elif completion_type == LLMMode.COMPLETION:
endpoint_url = urljoin(endpoint_url, 'completions')
data['prompt'] = prompt_messages[0].content
else:
raise ValueError("Unsupported completion type for model configuration.")
@@ -247,8 +273,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
data["tool_choice"] = "auto"
for tool in tools:
formatted_tools.append( helper.dump_model(PromptMessageFunction(function=tool)))
formatted_tools.append(helper.dump_model(PromptMessageFunction(function=tool)))
data["tools"] = formatted_tools
if stop:
@@ -256,7 +282,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if user:
data["user"] = user
response = requests.post(
endpoint_url,
headers=headers,
@@ -277,8 +303,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
return self._handle_generate_response(model, credentials, response, prompt_messages)
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
def _handle_generate_stream_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> Generator:
"""
Handle llm stream response
@@ -315,51 +341,64 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
if chunk:
decoded_chunk = chunk.decode('utf-8').strip().lstrip('data: ').lstrip()
chunk_json = None
try:
chunk_json = json.loads(decoded_chunk)
# stream ended
except json.JSONDecodeError as e:
yield create_final_llm_result_chunk(
index=chunk_index + 1,
index=chunk_index + 1,
message=AssistantPromptMessage(content=""),
finish_reason="Non-JSON encountered."
)
if len(chunk_json['choices']) == 0:
if not chunk_json or len(chunk_json['choices']) == 0:
continue
delta = chunk_json['choices'][0]['delta']
chunk_index = chunk_json['choices'][0]['index']
choice = chunk_json['choices'][0]
chunk_index = choice['index'] if 'index' in choice else chunk_index
if delta.get('finish_reason') is None and (delta.get('content') is None or delta.get('content') == ''):
if 'delta' in choice:
delta = choice['delta']
if delta.get('content') is None or delta.get('content') == '':
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
elif 'text' in choice:
if choice.get('text') is None or choice.get('text') == '':
continue
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=choice.get('text', '')
)
full_assistant_content += choice.get('text', '')
else:
continue
assistant_message_tool_calls = delta.get('tool_calls', None)
# assistant_message_function_call = delta.delta.function_call
# extract tool calls from response
if assistant_message_tool_calls:
tool_calls = self._extract_response_tool_calls(assistant_message_tool_calls)
# function_call = self._extract_response_function_call(assistant_message_function_call)
# tool_calls = [function_call] if function_call else []
# transform assistant message to prompt message
assistant_prompt_message = AssistantPromptMessage(
content=delta.get('content', ''),
tool_calls=tool_calls if assistant_message_tool_calls else []
)
full_assistant_content += delta.get('content', '')
# check payload indicator for completion
if chunk_json['choices'][0].get('finish_reason') is not None:
yield create_final_llm_result_chunk(
index=chunk_index,
message=assistant_prompt_message,
finish_reason=chunk_json['choices'][0]['finish_reason']
)
else:
yield LLMResultChunk(
model=model,
@@ -375,10 +414,12 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message=AssistantPromptMessage(content=""),
finish_reason="End of stream."
)
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
chunk_index += 1
def _handle_generate_response(self, model: str, credentials: dict, response: requests.Response,
prompt_messages: list[PromptMessage]) -> LLMResult:
response_json = response.json()
completion_type = LLMMode.value_of(credentials['mode'])
@@ -457,7 +498,8 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message = cast(AssistantPromptMessage, message)
message_dict = {"role": "assistant", "content": message.content}
if message.tool_calls:
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call in
message_dict["tool_calls"] = [helper.dump_model(PromptMessageFunction(function=tool_call)) for tool_call
in
message.tool_calls]
# function_call = message.tool_calls[0]
# message_dict["function_call"] = {
@@ -486,7 +528,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
message_dict["name"] = message.name
return message_dict
def _num_tokens_from_string(self, model: str, text: str,
tools: Optional[list[PromptMessageTool]] = None) -> int:
"""
@@ -509,10 +551,10 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
"""
Approximate num tokens with GPT2 tokenizer.
"""
tokens_per_message = 3
tokens_per_name = 1
num_tokens = 0
messages_dict = [self._convert_prompt_message_to_dict(m) for m in messages]
for message in messages_dict:
@@ -601,7 +643,7 @@ class OAIAPICompatLargeLanguageModel(_CommonOAI_API_Compat, LargeLanguageModel):
num_tokens += self._get_num_tokens_by_gpt2(required_field)
return num_tokens
def _extract_response_tool_calls(self,
response_tool_calls: list[dict]) \
-> list[AssistantPromptMessage.ToolCall]:
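The trailing-slash handling added before the `urljoin` calls is the heart of the endpoint fix: without it, `urljoin` replaces the last path segment of the base URL instead of appending to it. A quick standalone illustration (the base URL is just an example value):

```python
from urllib.parse import urljoin

# Without a trailing slash, the final segment of the base URL is dropped:
urljoin("https://example.com/v1", "chat/completions")
# -> 'https://example.com/chat/completions'

# With the slash enforced first, the base path is preserved:
urljoin("https://example.com/v1/", "chat/completions")
# -> 'https://example.com/v1/chat/completions'
```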

View File

@@ -33,8 +33,8 @@ model_credential_schema:
type: text-input
required: true
placeholder:
zh_Hans: 在此输入您的 API endpoint URL
en_US: Enter your API endpoint URL
zh_Hans: Base URL, eg. https://api.openai.com/v1
en_US: Base URL, eg. https://api.openai.com/v1
- variable: mode
show_on:
- variable: __model_type

View File

@@ -1,6 +1,7 @@
import time
from decimal import Decimal
from typing import Optional
from urllib.parse import urljoin
import requests
import json
@@ -42,8 +43,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
extra_model_kwargs = {}
if user:
@@ -144,8 +148,11 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
if api_key:
headers["Authorization"] = f"Bearer {api_key}"
endpoint_url = credentials.get('endpoint_url')
if not endpoint_url.endswith('/'):
endpoint_url += '/'
endpoint_url = credentials['endpoint_url']
endpoint_url = urljoin(endpoint_url, 'embeddings')
payload = {
'input': 'ping',
@@ -160,8 +167,19 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
)
if response.status_code != 200:
raise CredentialsValidateFailedError(f"Invalid response status: {response.status_code}")
raise CredentialsValidateFailedError(
f'Credentials validation failed with status code {response.status_code}')
try:
json_result = response.json()
except json.JSONDecodeError as e:
raise CredentialsValidateFailedError(f'Credentials validation failed: JSON decode error')
if 'model' not in json_result:
raise CredentialsValidateFailedError(
f'Credentials validation failed: invalid response')
except CredentialsValidateFailedError:
raise
except Exception as ex:
raise CredentialsValidateFailedError(str(ex))
@@ -175,7 +193,7 @@ class OAICompatEmbeddingModel(_CommonOAI_API_Compat, TextEmbeddingModel):
model_type=ModelType.TEXT_EMBEDDING,
fetch_from=FetchFrom.CUSTOMIZABLE_MODEL,
model_properties={
ModelPropertyKey.CONTEXT_SIZE: credentials.get('context_size'),
ModelPropertyKey.CONTEXT_SIZE: int(credentials.get('context_size')),
ModelPropertyKey.MAX_CHUNKS: 1,
},
parameter_rules=[],

View File

@@ -116,7 +116,7 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
)
for key, value in input_properties:
if key not in ['system_prompt', 'prompt']:
if key not in ['system_prompt', 'prompt'] and 'stop' not in key:
value_type = value.get('type')
if not value_type:
@@ -151,9 +151,17 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
index = -1
current_completion: str = ""
stop_condition_reached = False
prediction_output_length = 10000
is_prediction_output_finished = False
for output in prediction.output_iterator():
current_completion += output
if not is_prediction_output_finished and prediction.status == 'succeeded':
prediction_output_length = len(prediction.output) - 1
is_prediction_output_finished = True
if stop:
for s in stop:
if s in current_completion:
@@ -172,20 +180,30 @@ class ReplicateLargeLanguageModel(_CommonReplicate, LargeLanguageModel):
content=output if output else ''
)
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
if index < prediction_output_length:
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message
)
)
else:
prompt_tokens = self.get_num_tokens(model, credentials, prompt_messages)
completion_tokens = self.get_num_tokens(model, credentials, [assistant_prompt_message])
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
usage = self._calc_response_usage(model, credentials, prompt_tokens, completion_tokens)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage,
),
)
yield LLMResultChunk(
model=model,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=index,
message=assistant_prompt_message,
usage=usage
)
)
def _handle_generate_response(self, model: str, credentials: dict, prediction: Prediction, stop: list[str],
prompt_messages: list[PromptMessage]) -> LLMResult:

View File

@@ -33,10 +33,13 @@ class XinferenceHelper:
@staticmethod
def _clean_cache() -> None:
with cache_lock:
for model_uid, model in cache.items():
if model['expires'] < time():
try:
with cache_lock:
expired_keys = [model_uid for model_uid, model in cache.items() if model['expires'] < time()]
for model_uid in expired_keys:
del cache[model_uid]
except RuntimeError as e:
pass
@staticmethod
def _get_xinference_extra_parameter(server_url: str, model_uid: str) -> XinferenceModelExtraParameter:
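The rewritten `_clean_cache` gathers expired keys before deleting and tolerates a `RuntimeError`, which suggests it is guarding against mutating the cache dict while iterating it. A small plain-Python illustration of that failure mode (not Dify code):

```python
cache = {"model-a": {"expires": 0}, "model-b": {"expires": 0}}

# Deleting entries while iterating the same dict raises
# "RuntimeError: dictionary changed size during iteration" on CPython,
# which is why the fix collects expired keys into a list first.
try:
    for model_uid in cache:
        del cache[model_uid]
except RuntimeError as exc:
    print(exc)
```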

View File

@@ -31,7 +31,7 @@ model_credential_schema:
label:
zh_Hans: 服务器URL
en_US: Server url
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入Xinference的服务器地址如 https://example.com/xxx

View File

@@ -117,7 +117,10 @@ class ZhipuAILargeLanguageModel(_CommonZhipuaiAI, LargeLanguageModel):
params = {
'model': model,
'prompt': [{ 'role': prompt_message.role.value, 'content': prompt_message.content } for prompt_message in prompt_messages],
'prompt': [{
'role': prompt_message.role.value if prompt_message.role.value != 'system' else 'user',
'content': prompt_message.content
} for prompt_message in prompt_messages],
**model_parameters
}
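The change above rewrites any `system` role to `user` when building the ZhipuAI prompt payload, presumably because the provider's history format rejects system entries. A simplified sketch of the remapping, using plain dicts in place of the real PromptMessage objects:

```python
prompt_messages = [
    {"role": "system", "content": "You are a helpful assistant."},
    {"role": "user", "content": "Hello"},
]

# Mirror of the list comprehension in the diff: system entries become user entries.
prompt = [
    {
        "role": m["role"] if m["role"] != "system" else "user",
        "content": m["content"],
    }
    for m in prompt_messages
]
# prompt == [{'role': 'user', 'content': 'You are a helpful assistant.'},
#            {'role': 'user', 'content': 'Hello'}]
```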

View File

@@ -24,7 +24,7 @@ provider_credential_schema:
- variable: api_key
label:
en_US: APIKey
type: text-input
type: secret-input
required: true
placeholder:
zh_Hans: 在此输入您的 APIKey

View File

@@ -3,6 +3,8 @@ from collections import defaultdict
from json import JSONDecodeError
from typing import Optional
from sqlalchemy.exc import IntegrityError
from core.entities.model_entities import DefaultModelEntity, DefaultModelProviderEntity
from core.entities.provider_configuration import ProviderConfigurations, ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, CustomProviderConfiguration, CustomModelConfiguration, \
@@ -380,17 +382,28 @@ class ProviderManager:
if quota.quota_type == ProviderQuotaType.TRIAL:
# Init trial provider records if not exists
if ProviderQuotaType.TRIAL not in provider_quota_to_provider_record_dict:
provider_record = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit,
quota_used=0,
is_valid=True
)
db.session.add(provider_record)
db.session.commit()
try:
provider_record = Provider(
tenant_id=tenant_id,
provider_name=provider_name,
provider_type=ProviderType.SYSTEM.value,
quota_type=ProviderQuotaType.TRIAL.value,
quota_limit=quota.quota_limit,
quota_used=0,
is_valid=True
)
db.session.add(provider_record)
db.session.commit()
except IntegrityError:
db.session.rollback()
provider_record = db.session.query(Provider) \
.filter(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == ProviderQuotaType.TRIAL.value,
Provider.is_valid == True
).first()
provider_name_to_provider_records_dict[provider_name].append(provider_record)
@@ -433,17 +446,27 @@ class ProviderManager:
custom_provider_configuration = None
if custom_provider_record:
try:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
# fix origin data
if (custom_provider_record.encrypted_config
and not custom_provider_record.encrypted_config.startswith("{")):
provider_credentials = {
"openai_api_key": custom_provider_record.encrypted_config
}
else:
provider_credentials = json.loads(custom_provider_record.encrypted_config)
except JSONDecodeError:
provider_credentials = {}
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
custom_provider_configuration = CustomProviderConfiguration(
credentials=provider_credentials
@@ -468,11 +491,14 @@ class ProviderManager:
for variable in model_credential_secret_variables:
if variable in provider_model_credentials:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
try:
provider_model_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_model_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
custom_model_configurations.append(
CustomModelConfiguration(
@@ -564,11 +590,14 @@ class ProviderManager:
for variable in provider_credential_secret_variables:
if variable in provider_credentials:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
try:
provider_credentials[variable] = encrypter.decrypt_token_with_decoding(
provider_credentials.get(variable),
decoding_rsa_key,
decoding_cipher_rsa
)
except ValueError:
pass
current_using_credentials = provider_credentials
else:

View File

@@ -7,10 +7,38 @@ from typing import (
Optional,
)
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.text_splitter import RecursiveCharacterTextSplitter, TokenTextSplitter, TS, Type, Union, AbstractSet, Literal, Collection
from core.model_runtime.model_providers.__base.tokenizers.gpt2_tokenzier import GPT2Tokenizer
class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
"""
This class is used to implement from_gpt2_encoder, to prevent using of tiktoken
"""
@classmethod
def from_gpt2_encoder(
cls: Type[TS],
encoding_name: str = "gpt2",
model_name: Optional[str] = None,
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
**kwargs: Any,
):
def _token_encoder(text: str) -> int:
return GPT2Tokenizer.get_num_tokens(text)
if issubclass(cls, TokenTextSplitter):
extra_kwargs = {
"encoding_name": encoding_name,
"model_name": model_name,
"allowed_special": allowed_special,
"disallowed_special": disallowed_special,
}
kwargs = {**kwargs, **extra_kwargs}
return cls(length_function=_token_encoder, **kwargs)
class FixedRecursiveCharacterTextSplitter(EnhanceRecursiveCharacterTextSplitter):
def __init__(self, fixed_separator: str = "\n\n", separators: Optional[List[str]] = None, **kwargs: Any):
"""Create a new TextSplitter."""
super().__init__(**kwargs)
@@ -65,4 +93,4 @@ class FixedRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
if _good_splits:
merged_text = self._merge_splits(_good_splits, separator)
final_chunks.extend(merged_text)
return final_chunks
return final_chunks
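For orientation, a hedged usage sketch of the new splitter, mirroring the automatic-segmentation call site in the indexing runner diff above (chunk size, separators, and the sample text are illustrative):

```python
from core.spiltter.fixed_text_splitter import EnhanceRecursiveCharacterTextSplitter

long_text = "First paragraph about the dataset.\n\nSecond paragraph about indexing."

# from_gpt2_encoder wires GPT2Tokenizer.get_num_tokens in as the length function,
# so chunk sizes are measured in tokens without pulling in tiktoken.
splitter = EnhanceRecursiveCharacterTextSplitter.from_gpt2_encoder(
    chunk_size=500,
    chunk_overlap=0,
    separators=["\n\n", "。", ".", " ", ""],
)
chunks = splitter.split_text(long_text)
```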

View File

@@ -2,7 +2,7 @@ version: '3.1'
services:
# API service
api:
image: langgenius/dify-api:0.4.0
image: langgenius/dify-api:0.4.2
restart: always
environment:
# Startup mode, 'api' starts the API server.
@@ -128,7 +128,7 @@ services:
# worker service
# The Celery worker for processing the queue.
worker:
image: langgenius/dify-api:0.4.0
image: langgenius/dify-api:0.4.2
restart: always
environment:
# Startup mode, 'worker' starts the Celery worker for processing the queue.
@@ -196,7 +196,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:0.4.0
image: langgenius/dify-web:0.4.2
restart: always
environment:
EDITION: SELF_HOSTED

View File

@@ -10,7 +10,7 @@ First, install the dependencies:
```bash
npm install
# or
yarn
yarn install --frozen-lockfile
```
Then, configure the environment variables. Create a file named `.env.local` in the current directory and copy the contents from `.env.example`. Modify the values of these environment variables according to your requirements:

View File

@@ -287,7 +287,6 @@ const Metadata: FC<IMetadataProps> = ({ docDetail, loading, onUpdate }) => {
}
const onSave = async () => {
console.log('metadataParams:', metadataParams)
setSaveLoading(true)
const [e] = await asyncRunSafe<CommonResponse>(modifyDocMetadata({
datasetId,

View File

@@ -30,7 +30,6 @@ const WorkplaceSelector = () => {
const currentWorkspace = workspaces.find(v => v.current)
const handleSwitchWorkspace = async (tenant_id: string) => {
console.log(tenant_id, currentWorkspace?.id)
try {
if (currentWorkspace?.id === tenant_id)
return

View File

@@ -248,7 +248,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
...(isAdvancedMode ? [stopParameerRule] : []),
].map(parameter => (
<ParameterItem
key={parameter.name}
key={`${modelId}-${parameter.name}`}
className='mb-4'
parameterRule={parameter}
value={completionParams[parameter.name]}

View File

@@ -63,8 +63,13 @@ const ParameterItem: FC<ParameterItemProps> = ({
const handleChange = (v: ParameterValue) => {
setLocalValue(v)
if (!isNullOrUndefined(value) && onChange)
onChange(v)
if (onChange) {
if (parameterRule.name === 'stop')
onChange(v)
else if (!isNullOrUndefined(value))
onChange(v)
}
}
const handleNumberInputChange = (e: React.ChangeEvent<HTMLInputElement>) => {

View File

@@ -1,6 +1,6 @@
{
"name": "dify-web",
"version": "0.4.0",
"version": "0.4.2",
"private": true,
"scripts": {
"dev": "next dev",

View File

@@ -136,7 +136,6 @@ const handleStream = (response: Response, onData: IOnData, onCompleted?: IOnComp
onThought?.(bufferObj as ThoughtItem)
}
else if (bufferObj.event === 'message_end') {
console.log(bufferObj)
onMessageEnd?.(bufferObj as MessageEnd)
}
else if (bufferObj.event === 'message_replace') {