mirror of
https://github.com/langgenius/dify.git
synced 2026-02-25 18:55:08 +00:00
Compare commits
11 Commits
dependabot
...
workflow-l
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
27ee32088d | ||
|
|
6dcca77820 | ||
|
|
8ff55da8b7 | ||
|
|
3d74218034 | ||
|
|
b94a338636 | ||
|
|
74a32a7715 | ||
|
|
dd68b3608e | ||
|
|
e7c82f1158 | ||
|
|
d7cdbd6cca | ||
|
|
e74d3791dc | ||
|
|
90bab9c8a3 |
@@ -91,7 +91,6 @@ forbidden_modules =
|
||||
core.logging
|
||||
core.mcp
|
||||
core.memory
|
||||
core.model_manager
|
||||
core.moderation
|
||||
core.ops
|
||||
core.plugin
|
||||
@@ -121,6 +120,7 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_manager
|
||||
core.workflow.nodes.llm.protocols -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> models.model
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
|
||||
@@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
|
||||
|
||||
# check if model supports stream tool call
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
features = model_schema.features if model_schema and model_schema.features else []
|
||||
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
|
||||
self.files = application_generate_entity.files if ModelFeature.VISION in features else []
|
||||
|
||||
@@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
iteration_step += 1
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
model=model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
|
||||
@@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
model=model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
|
||||
@@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
)
|
||||
|
||||
yield LLMResultChunk(
|
||||
model=model_instance.model,
|
||||
model=model_instance.model_name,
|
||||
prompt_messages=result.prompt_messages,
|
||||
system_fingerprint=result.system_fingerprint,
|
||||
delta=LLMResultChunkDelta(
|
||||
@@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=model_instance.model,
|
||||
model=model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=llm_usage["usage"] or LLMUsage.empty_usage(),
|
||||
|
||||
@@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner):
|
||||
|
||||
# change function call strategy based on LLM model
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
|
||||
|
||||
1
api/core/app/llm/__init__.py
Normal file
1
api/core/app/llm/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""LLM-related application services."""
|
||||
103
api/core/app/llm/model_access.py
Normal file
103
api/core/app/llm/model_access.py
Normal file
@@ -0,0 +1,103 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.errors.error import ProviderTokenNotInitError
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
|
||||
|
||||
class DifyCredentialsProvider:
|
||||
tenant_id: str
|
||||
provider_manager: ProviderManager
|
||||
|
||||
def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.provider_manager = provider_manager or ProviderManager()
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
|
||||
provider_configuration = provider_configurations.get(provider_name)
|
||||
if not provider_configuration:
|
||||
raise ValueError(f"Provider {provider_name} does not exist.")
|
||||
|
||||
provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {model_name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
|
||||
if credentials is None:
|
||||
raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
|
||||
|
||||
return credentials
|
||||
|
||||
|
||||
class DifyModelFactory:
|
||||
tenant_id: str
|
||||
model_manager: ModelManager
|
||||
|
||||
def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
|
||||
self.tenant_id = tenant_id
|
||||
self.model_manager = model_manager or ModelManager()
|
||||
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
return self.model_manager.get_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
)
|
||||
|
||||
|
||||
def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
|
||||
return (
|
||||
DifyCredentialsProvider(tenant_id=tenant_id),
|
||||
DifyModelFactory(tenant_id=tenant_id),
|
||||
)
|
||||
|
||||
|
||||
def fetch_model_config(
|
||||
*,
|
||||
node_data_model: ModelConfig,
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
|
||||
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, final
|
||||
from typing_extensions import override
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.llm.model_access import build_dify_model_access
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
@@ -18,8 +19,13 @@ from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
|
||||
from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.template_transform.template_renderer import CodeExecutorJinja2TemplateRenderer
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
CodeExecutorJinja2TemplateRenderer,
|
||||
)
|
||||
from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -75,6 +81,8 @@ class DifyNodeFactory(NodeFactory):
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
"""
|
||||
@@ -140,6 +148,16 @@ class DifyNodeFactory(NodeFactory):
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.LLM:
|
||||
return LLMNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
@@ -158,6 +176,26 @@ class DifyNodeFactory(NodeFactory):
|
||||
unstructured_api_config=self._document_extractor_unstructured_api_config,
|
||||
)
|
||||
|
||||
if node_type == NodeType.QUESTION_CLASSIFIER:
|
||||
return QuestionClassifierNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
if node_type == NodeType.PARAMETER_EXTRACTOR:
|
||||
return ParameterExtractorNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -35,7 +35,7 @@ class ModelInstance:
|
||||
|
||||
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
|
||||
self.provider_model_bundle = provider_model_bundle
|
||||
self.model = model
|
||||
self.model_name = model
|
||||
self.provider = provider_model_bundle.configuration.provider.provider
|
||||
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
|
||||
self.model_type_instance = self.provider_model_bundle.model_type_instance
|
||||
@@ -163,7 +163,7 @@ class ModelInstance:
|
||||
Union[LLMResult, Generator],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=model_parameters,
|
||||
@@ -191,7 +191,7 @@ class ModelInstance:
|
||||
int,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
prompt_messages=prompt_messages,
|
||||
tools=tools,
|
||||
@@ -215,7 +215,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
user=user,
|
||||
@@ -243,7 +243,7 @@ class ModelInstance:
|
||||
EmbeddingResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
multimodel_documents=multimodel_documents,
|
||||
user=user,
|
||||
@@ -264,7 +264,7 @@ class ModelInstance:
|
||||
list[int],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.get_num_tokens,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
texts=texts,
|
||||
),
|
||||
@@ -294,7 +294,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@@ -328,7 +328,7 @@ class ModelInstance:
|
||||
RerankResult,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke_multimodal_rerank,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
query=query,
|
||||
docs=docs,
|
||||
@@ -352,7 +352,7 @@ class ModelInstance:
|
||||
bool,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
text=text,
|
||||
user=user,
|
||||
@@ -373,7 +373,7 @@ class ModelInstance:
|
||||
str,
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
file=file,
|
||||
user=user,
|
||||
@@ -396,7 +396,7 @@ class ModelInstance:
|
||||
Iterable[bytes],
|
||||
self._round_robin_invoke(
|
||||
function=self.model_type_instance.invoke,
|
||||
model=self.model,
|
||||
model=self.model_name,
|
||||
credentials=self.credentials,
|
||||
content_text=content_text,
|
||||
user=user,
|
||||
@@ -469,7 +469,7 @@ class ModelInstance:
|
||||
if not isinstance(self.model_type_instance, TTSModel):
|
||||
raise Exception("Model type instance is not TTSModel")
|
||||
return self.model_type_instance.get_tts_model_voices(
|
||||
model=self.model, credentials=self.credentials, language=language
|
||||
model=self.model_name, credentials=self.credentials, language=language
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -47,7 +47,9 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
self.history_messages,
|
||||
)
|
||||
if curr_message_tokens <= max_token_limit:
|
||||
return self.history_messages
|
||||
@@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform):
|
||||
# a message is start with UserPromptMessage
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
curr_message_tokens = model_type_instance.get_num_tokens(
|
||||
self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
|
||||
self.model_config.model,
|
||||
self.model_config.credentials,
|
||||
prompt_messages,
|
||||
)
|
||||
# if current message token is overflow, drop all the prompts in current message and break
|
||||
if curr_message_tokens > max_token_limit:
|
||||
|
||||
@@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings):
|
||||
hash = helper.generate_text_hash(texts[i])
|
||||
if hash not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model,
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=hash,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings):
|
||||
embedding = (
|
||||
db.session.query(Embedding)
|
||||
.filter_by(
|
||||
model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
@@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings):
|
||||
try:
|
||||
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
|
||||
model_schema = model_type_instance.get_model_schema(
|
||||
self._model_instance.model, self._model_instance.credentials
|
||||
self._model_instance.model_name, self._model_instance.credentials
|
||||
)
|
||||
max_chunks = (
|
||||
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
|
||||
@@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings):
|
||||
file_id = multimodel_documents[i]["file_id"]
|
||||
if file_id not in cache_embeddings:
|
||||
embedding_cache = Embedding(
|
||||
model_name=self._model_instance.model,
|
||||
model_name=self._model_instance.model_name,
|
||||
hash=file_id,
|
||||
provider_name=self._model_instance.provider,
|
||||
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
|
||||
@@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed query text."""
|
||||
# use doc embedding cache or store if not exists
|
||||
hash = helper.generate_text_hash(text)
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
@@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings):
|
||||
"""Embed multimodal documents."""
|
||||
# use doc embedding cache or store if not exists
|
||||
file_id = multimodel_document["file_id"]
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
|
||||
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
|
||||
embedding = redis_client.get(embedding_cache_key)
|
||||
if embedding:
|
||||
redis_client.expire(embedding_cache_key, 600)
|
||||
|
||||
@@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
|
||||
is_support_vision = model_manager.check_model_support_vision(
|
||||
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
|
||||
provider=self.rerank_model_instance.provider,
|
||||
model=self.rerank_model_instance.model,
|
||||
model=self.rerank_model_instance.model_name,
|
||||
model_type=ModelType.RERANK,
|
||||
)
|
||||
if not is_support_vision:
|
||||
|
||||
@@ -47,7 +47,7 @@ class ModelInvocationUtils:
|
||||
raise InvokeModelError("Model not found")
|
||||
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
|
||||
if not schema:
|
||||
raise InvokeModelError("No model schema found")
|
||||
|
||||
@@ -8,7 +8,7 @@ from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
@@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
|
||||
from core.workflow.enums import SystemVariableKey
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.nodes.llm.entities import ModelConfig
|
||||
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -24,49 +26,46 @@ from models.model import Conversation
|
||||
from models.provider import Provider, ProviderType
|
||||
from models.provider_ids import ModelProviderID
|
||||
|
||||
from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
|
||||
from .exc import InvalidVariableTypeError
|
||||
|
||||
|
||||
def fetch_model_config(
|
||||
tenant_id: str, node_data_model: ModelConfig
|
||||
*,
|
||||
node_data_model: ModelConfig,
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
if not node_data_model.mode:
|
||||
raise LLMModeRequiredError("LLM mode is required.")
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=node_data_model.provider,
|
||||
credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
|
||||
model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
|
||||
model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
|
||||
|
||||
# check model
|
||||
provider_model = model.provider_model_bundle.configuration.get_provider_model(
|
||||
model=node_data_model.name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
provider_model.raise_for_status()
|
||||
|
||||
# model config
|
||||
stop: list[str] = []
|
||||
if "stop" in node_data_model.completion_params:
|
||||
stop = node_data_model.completion_params.pop("stop")
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
return model, ModelConfigWithCredentialsEntity(
|
||||
model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
return model_instance, ModelConfigWithCredentialsEntity(
|
||||
provider=node_data_model.provider,
|
||||
model=node_data_model.name,
|
||||
model_schema=model_schema,
|
||||
mode=node_data_model.mode,
|
||||
provider_model_bundle=model.provider_model_bundle,
|
||||
credentials=model.credentials,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=credentials,
|
||||
parameters=node_data_model.completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
@@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
|
||||
if quota_unit == QuotaUnit.TOKENS:
|
||||
used_quota = usage.total_tokens
|
||||
elif quota_unit == QuotaUnit.CREDITS:
|
||||
used_quota = dify_config.get_model_credits(model_instance.model)
|
||||
used_quota = dify_config.get_model_credits(model_instance.model_name)
|
||||
else:
|
||||
used_quota = 1
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
@@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
SystemPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import (
|
||||
ModelFeature,
|
||||
ModelPropertyKey,
|
||||
ModelType,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
@@ -76,6 +72,7 @@ from core.workflow.node_events import (
|
||||
from core.workflow.nodes.base.entities import VariableSelector
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import VariablePool
|
||||
from extensions.ext_database import db
|
||||
from models.dataset import SegmentAttachmentBinding
|
||||
@@ -93,7 +90,6 @@ from .exc import (
|
||||
InvalidVariableTypeError,
|
||||
LLMNodeError,
|
||||
MemoryRolePrefixRequiredError,
|
||||
ModelNotExistError,
|
||||
NoPromptFoundError,
|
||||
TemplateTypeNotSupportError,
|
||||
VariableNotFoundError,
|
||||
@@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
_file_outputs: list[File]
|
||||
|
||||
_llm_file_saver: LLMFileSaver
|
||||
_credentials_provider: CredentialsProvider
|
||||
_model_factory: ModelFactory
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]):
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
credentials_provider: CredentialsProvider,
|
||||
model_factory: ModelFactory,
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
@@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_inputs["#context_files#"] = [file.model_dump() for file in context_files]
|
||||
|
||||
# fetch model config
|
||||
model_instance, model_config = LLMNode._fetch_model_config(
|
||||
model_instance, model_config = self._fetch_model_config(
|
||||
node_data_model=self.node_data.model,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
model_name = getattr(model_instance, "model_name", None)
|
||||
if not isinstance(model_name, str):
|
||||
model_name = model_config.model
|
||||
model_provider = getattr(model_instance, "provider", None)
|
||||
if not isinstance(model_provider, str):
|
||||
model_provider = model_config.provider
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
model_name,
|
||||
model_instance.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {model_name}")
|
||||
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
@@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=self.node_data.model.completion_params,
|
||||
stop=model_config.stop,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
@@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]):
|
||||
structured_output = event
|
||||
|
||||
process_data = {
|
||||
"model_mode": model_config.mode,
|
||||
"model_mode": self.node_data.model.mode,
|
||||
"prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
|
||||
model_mode=model_config.mode, prompt_messages=prompt_messages
|
||||
model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
|
||||
),
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
}
|
||||
|
||||
outputs = {
|
||||
@@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _fetch_model_config(
|
||||
self,
|
||||
*,
|
||||
node_data_model: ModelConfig,
|
||||
tenant_id: str,
|
||||
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
|
||||
model, model_config_with_cred = llm_utils.fetch_model_config(
|
||||
tenant_id=tenant_id, node_data_model=node_data_model
|
||||
node_data_model=node_data_model,
|
||||
credentials_provider=self._credentials_provider,
|
||||
model_factory=self._model_factory,
|
||||
)
|
||||
completion_params = model_config_with_cred.parameters
|
||||
|
||||
model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
|
||||
|
||||
model_config_with_cred.parameters = completion_params
|
||||
# NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
|
||||
node_data_model.completion_params = completion_params
|
||||
@@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
model_schema: AIModelEntity,
|
||||
model_parameters: Mapping[str, Any],
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
stop: Sequence[str] | None = None,
|
||||
memory_config: MemoryConfig | None = None,
|
||||
vision_enabled: bool = False,
|
||||
vision_detail: ImagePromptMessageContent.DETAIL,
|
||||
variable_pool: VariablePool,
|
||||
jinja2_variables: Sequence[VariableSelector],
|
||||
tenant_id: str,
|
||||
context_files: list[File] | None = None,
|
||||
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
@@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
memory_messages = _handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
# Extend prompt_messages with memory messages
|
||||
prompt_messages.extend(memory_messages)
|
||||
@@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
memory_text = _handle_memory_completion_mode(
|
||||
memory=memory,
|
||||
memory_config=memory_config,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
# Insert histories into the prompt
|
||||
prompt_content = prompt_messages[0].content
|
||||
@@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
prompt_message_content: list[PromptMessageContentUnionTypes] = []
|
||||
for content_item in prompt_message.content:
|
||||
# Skip content if features are not defined
|
||||
if not model_config.model_schema.features:
|
||||
if not model_schema.features:
|
||||
if content_item.type != PromptMessageContentType.TEXT:
|
||||
continue
|
||||
prompt_message_content.append(content_item)
|
||||
@@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]):
|
||||
if (
|
||||
(
|
||||
content_item.type == PromptMessageContentType.IMAGE
|
||||
and ModelFeature.VISION not in model_config.model_schema.features
|
||||
and ModelFeature.VISION not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.DOCUMENT
|
||||
and ModelFeature.DOCUMENT not in model_config.model_schema.features
|
||||
and ModelFeature.DOCUMENT not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.VIDEO
|
||||
and ModelFeature.VIDEO not in model_config.model_schema.features
|
||||
and ModelFeature.VIDEO not in model_schema.features
|
||||
)
|
||||
or (
|
||||
content_item.type == PromptMessageContentType.AUDIO
|
||||
and ModelFeature.AUDIO not in model_config.model_schema.features
|
||||
and ModelFeature.AUDIO not in model_schema.features
|
||||
)
|
||||
):
|
||||
continue
|
||||
@@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
model = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=model_config.provider,
|
||||
model=model_config.model,
|
||||
)
|
||||
model_schema = model.model_type_instance.get_model_schema(
|
||||
model=model_config.model,
|
||||
credentials=model.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ModelNotExistError(f"Model {model_config.model} not exist.")
|
||||
return filtered_prompt_messages, model_config.stop
|
||||
return filtered_prompt_messages, stop
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
@@ -1306,26 +1313,26 @@ def _render_jinja2_message(
|
||||
|
||||
|
||||
def _calculate_rest_token(
|
||||
*, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
*,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_instance: ModelInstance,
|
||||
model_schema: AIModelEntity,
|
||||
model_parameters: Mapping[str, Any],
|
||||
) -> int:
|
||||
rest_tokens = 2000
|
||||
|
||||
model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
|
||||
if model_context_tokens:
|
||||
model_instance = ModelInstance(
|
||||
provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
|
||||
)
|
||||
|
||||
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
max_tokens = 0
|
||||
for parameter_rule in model_config.model_schema.parameter_rules:
|
||||
for parameter_rule in model_schema.parameter_rules:
|
||||
if parameter_rule.name == "max_tokens" or (
|
||||
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
|
||||
):
|
||||
max_tokens = (
|
||||
model_config.parameters.get(parameter_rule.name)
|
||||
or model_config.parameters.get(str(parameter_rule.use_template))
|
||||
model_parameters.get(parameter_rule.name)
|
||||
or model_parameters.get(str(parameter_rule.use_template))
|
||||
or 0
|
||||
)
|
||||
|
||||
@@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
model_schema: AIModelEntity,
|
||||
model_parameters: Mapping[str, Any],
|
||||
) -> Sequence[PromptMessage]:
|
||||
memory_messages: Sequence[PromptMessage] = []
|
||||
# Get messages from memory for chat model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
rest_tokens = _calculate_rest_token(
|
||||
prompt_messages=[],
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
memory_messages = memory.get_history_prompt_messages(
|
||||
max_token_limit=rest_tokens,
|
||||
message_limit=memory_config.window.size if memory_config.window.enabled else None,
|
||||
@@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
model_instance: ModelInstance,
|
||||
model_schema: AIModelEntity,
|
||||
model_parameters: Mapping[str, Any],
|
||||
) -> str:
|
||||
memory_text = ""
|
||||
# Get history text from memory for completion model
|
||||
if memory and memory_config:
|
||||
rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
|
||||
rest_tokens = _calculate_rest_token(
|
||||
prompt_messages=[],
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=model_parameters,
|
||||
)
|
||||
if not memory_config.role_prefix:
|
||||
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
|
||||
memory_text = memory.get_history_prompt_text(
|
||||
|
||||
21
api/core/workflow/nodes/llm/protocols.py
Normal file
21
api/core/workflow/nodes/llm/protocols.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
|
||||
|
||||
class CredentialsProvider(Protocol):
|
||||
"""Port for loading runtime credentials for a provider/model pair."""
|
||||
|
||||
def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
|
||||
"""Return credentials for the target provider/model or raise a domain error."""
|
||||
...
|
||||
|
||||
|
||||
class ModelFactory(Protocol):
|
||||
"""Port for creating initialized LLM model instances for execution."""
|
||||
|
||||
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
|
||||
"""Create a model instance that is ready for schema lookup and invocation."""
|
||||
...
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
@@ -60,6 +60,11 @@ from .prompts import (
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
"""
|
||||
@@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
_model_instance: ModelInstance | None = None
|
||||
_model_config: ModelConfigWithCredentialsEntity | None = None
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
@@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"""
|
||||
if not self._model_instance or not self._model_config:
|
||||
self._model_instance, self._model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id, node_data_model=node_data_model
|
||||
node_data_model=node_data_model,
|
||||
credentials_provider=self._credentials_provider,
|
||||
model_factory=self._model_factory,
|
||||
)
|
||||
|
||||
return self._model_instance, self._model_config
|
||||
|
||||
@@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
|
||||
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
|
||||
from .entities import QuestionClassifierNodeData
|
||||
@@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
|
||||
_file_outputs: list["File"]
|
||||
_llm_file_saver: LLMFileSaver
|
||||
_credentials_provider: "CredentialsProvider"
|
||||
_model_factory: "ModelFactory"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
credentials_provider: "CredentialsProvider",
|
||||
model_factory: "ModelFactory",
|
||||
llm_file_saver: LLMFileSaver | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
@@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
# LLM file outputs, used for MultiModal outputs.
|
||||
self._file_outputs = []
|
||||
|
||||
self._credentials_provider = credentials_provider
|
||||
self._model_factory = model_factory
|
||||
|
||||
if llm_file_saver is None:
|
||||
llm_file_saver = FileSaverImpl(
|
||||
user_id=graph_init_params.user_id,
|
||||
@@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variables = {"query": query}
|
||||
# fetch model config
|
||||
model_instance, model_config = llm_utils.fetch_model_config(
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_model=node_data.model,
|
||||
credentials_provider=self._credentials_provider,
|
||||
model_factory=self._model_factory,
|
||||
)
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
model_instance.model_name,
|
||||
model_instance.credentials,
|
||||
)
|
||||
if not model_schema:
|
||||
raise ValueError(f"Model schema not found for {model_instance.model_name}")
|
||||
# fetch memory
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
@@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
prompt_template=prompt_template,
|
||||
sys_query="",
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
model_instance=model_instance,
|
||||
model_schema=model_schema,
|
||||
model_parameters=node_data.model.completion_params,
|
||||
stop=model_config.stop,
|
||||
sys_files=files,
|
||||
vision_enabled=node_data.vision.enabled,
|
||||
vision_detail=node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=[],
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
|
||||
@@ -1,8 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@@ -11,6 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
from core.workflow.file.models import File
|
||||
from core.workflow.graph import Graph
|
||||
@@ -168,7 +168,8 @@ class WorkflowEntry:
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
typed_node_config = cast(dict[str, object], node_config)
|
||||
node = cast(Any, node_factory).create_node(typed_node_config)
|
||||
node_cls = type(node)
|
||||
|
||||
try:
|
||||
@@ -256,7 +257,7 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
@@ -302,16 +303,15 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = {
|
||||
node_config: NodeConfigDict = {
|
||||
"id": node_id,
|
||||
"data": node_data,
|
||||
"data": cast(NodeConfigData, node_data),
|
||||
}
|
||||
node: Node = node_cls(
|
||||
id=str(uuid.uuid4()),
|
||||
config=node_config,
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
node = node_factory.create_node(node_config)
|
||||
|
||||
try:
|
||||
# variable selector to variable mapping
|
||||
|
||||
@@ -107,19 +107,19 @@ class AppService:
|
||||
|
||||
if model_instance:
|
||||
if (
|
||||
model_instance.model == default_model_config["model"]["name"]
|
||||
model_instance.model_name == default_model_config["model"]["name"]
|
||||
and model_instance.provider == default_model_config["model"]["provider"]
|
||||
):
|
||||
default_model_dict = default_model_config["model"]
|
||||
else:
|
||||
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
if model_schema is None:
|
||||
raise ValueError(f"model schema not found for model {model_instance.model}")
|
||||
raise ValueError(f"model schema not found for model {model_instance.model_name}")
|
||||
|
||||
default_model_dict = {
|
||||
"provider": model_instance.provider,
|
||||
"name": model_instance.model,
|
||||
"name": model_instance.model_name,
|
||||
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
|
||||
"completion_params": {},
|
||||
}
|
||||
|
||||
@@ -252,7 +252,7 @@ class DatasetService:
|
||||
dataset.updated_by = account.id
|
||||
dataset.tenant_id = tenant_id
|
||||
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.model if embedding_model else None
|
||||
dataset.embedding_model = embedding_model.model_name if embedding_model else None
|
||||
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
|
||||
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
|
||||
dataset.provider = provider
|
||||
@@ -384,7 +384,7 @@ class DatasetService:
|
||||
model=model,
|
||||
)
|
||||
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
|
||||
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
|
||||
model_schema = text_embedding_model.get_model_schema(model_instance.model_name, model_instance.credentials)
|
||||
if not model_schema:
|
||||
raise ValueError("Model schema not found")
|
||||
if model_schema.features and ModelFeature.VISION in model_schema.features:
|
||||
@@ -743,10 +743,12 @@ class DatasetService:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=data["embedding_model"],
|
||||
)
|
||||
filtered_data["embedding_model"] = embedding_model.model
|
||||
embedding_model_name = embedding_model.model_name
|
||||
filtered_data["embedding_model"] = embedding_model_name
|
||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
embedding_model.provider,
|
||||
embedding_model_name,
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
except LLMBadRequestError:
|
||||
@@ -876,10 +878,12 @@ class DatasetService:
|
||||
return
|
||||
|
||||
# Apply new embedding model settings
|
||||
filtered_data["embedding_model"] = embedding_model.model
|
||||
embedding_model_name = embedding_model.model_name
|
||||
filtered_data["embedding_model"] = embedding_model_name
|
||||
filtered_data["embedding_model_provider"] = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
embedding_model.provider,
|
||||
embedding_model_name,
|
||||
)
|
||||
filtered_data["collection_binding_id"] = dataset_collection_binding.id
|
||||
|
||||
@@ -955,10 +959,12 @@ class DatasetService:
|
||||
knowledge_configuration.embedding_model,
|
||||
)
|
||||
dataset.is_multimodal = is_multimodal
|
||||
dataset.embedding_model = embedding_model.model
|
||||
embedding_model_name = embedding_model.model_name
|
||||
dataset.embedding_model = embedding_model_name
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
embedding_model.provider,
|
||||
embedding_model_name,
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
elif knowledge_configuration.indexing_technique == "economy":
|
||||
@@ -989,10 +995,12 @@ class DatasetService:
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model=knowledge_configuration.embedding_model,
|
||||
)
|
||||
dataset.embedding_model = embedding_model.model
|
||||
embedding_model_name = embedding_model.model_name
|
||||
dataset.embedding_model = embedding_model_name
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
embedding_model.provider,
|
||||
embedding_model_name,
|
||||
)
|
||||
is_multimodal = DatasetService.check_is_multimodal_model(
|
||||
current_user.current_tenant_id,
|
||||
@@ -1049,11 +1057,13 @@ class DatasetService:
|
||||
skip_embedding_update = True
|
||||
if not skip_embedding_update:
|
||||
if embedding_model:
|
||||
dataset.embedding_model = embedding_model.model
|
||||
embedding_model_name = embedding_model.model_name
|
||||
dataset.embedding_model = embedding_model_name
|
||||
dataset.embedding_model_provider = embedding_model.provider
|
||||
dataset_collection_binding = (
|
||||
DatasetCollectionBindingService.get_dataset_collection_binding(
|
||||
embedding_model.provider, embedding_model.model
|
||||
embedding_model.provider,
|
||||
embedding_model_name,
|
||||
)
|
||||
)
|
||||
dataset.collection_binding_id = dataset_collection_binding.id
|
||||
@@ -1884,7 +1894,7 @@ class DocumentService:
|
||||
embedding_model = model_manager.get_default_model_instance(
|
||||
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
dataset_embedding_model = embedding_model.model
|
||||
dataset_embedding_model = embedding_model.model_name
|
||||
dataset_embedding_model_provider = embedding_model.provider
|
||||
dataset.embedding_model = dataset_embedding_model
|
||||
dataset.embedding_model_provider = dataset_embedding_model_provider
|
||||
|
||||
@@ -80,6 +80,8 @@ def init_llm_node(config: dict) -> LLMNode:
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(),
|
||||
model_factory=MagicMock(),
|
||||
)
|
||||
|
||||
return node
|
||||
@@ -115,7 +117,7 @@ def test_execute_llm():
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config to avoid database calls
|
||||
def mock_fetch_model_config(**_kwargs):
|
||||
def mock_fetch_model_config(*_args, **_kwargs):
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
@@ -227,7 +229,7 @@ def test_execute_llm_with_jinja2():
|
||||
db.session.close = MagicMock()
|
||||
|
||||
# Mock the _fetch_model_config method
|
||||
def mock_fetch_model_config(**_kwargs):
|
||||
def mock_fetch_model_config(*_args, **_kwargs):
|
||||
from decimal import Decimal
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
|
||||
@@ -9,6 +9,7 @@ from core.model_runtime.entities import AssistantPromptMessage
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@@ -84,6 +85,8 @@ def init_parameter_extractor_node(config: dict):
|
||||
config=config,
|
||||
graph_init_params=init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=MagicMock(spec=CredentialsProvider),
|
||||
model_factory=MagicMock(spec=ModelFactory),
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments:
|
||||
Mock: Configured ModelInstance with text embedding capabilities
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
@@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery:
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
return model_instance
|
||||
@@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching:
|
||||
"""
|
||||
# Arrange
|
||||
model_instance_ada = Mock()
|
||||
model_instance_ada.model = "text-embedding-ada-002"
|
||||
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||
model_instance_ada.provider = "openai"
|
||||
|
||||
# Mock model type instance for ada
|
||||
@@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching:
|
||||
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
|
||||
|
||||
model_instance_3_small = Mock()
|
||||
model_instance_3_small.model = "text-embedding-3-small"
|
||||
model_instance_3_small.model_name = "text-embedding-3-small"
|
||||
model_instance_3_small.provider = "openai"
|
||||
|
||||
# Mock model type instance for 3-small
|
||||
@@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching:
|
||||
"""
|
||||
# Arrange
|
||||
model_instance_openai = Mock()
|
||||
model_instance_openai.model = "text-embedding-ada-002"
|
||||
model_instance_openai.model_name = "text-embedding-ada-002"
|
||||
model_instance_openai.provider = "openai"
|
||||
|
||||
model_instance_cohere = Mock()
|
||||
model_instance_cohere.model = "embed-english-v3.0"
|
||||
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||
model_instance_cohere.provider = "cohere"
|
||||
|
||||
cache_openai = CacheEmbedding(model_instance_openai)
|
||||
@@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation:
|
||||
def mock_model_instance(self):
|
||||
"""Create a mock ModelInstance for testing."""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
model_instance.credentials = {"api_key": "test-key"}
|
||||
|
||||
@@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation:
|
||||
"""
|
||||
# Arrange - OpenAI ada-002 (1536 dimensions)
|
||||
model_instance_ada = Mock()
|
||||
model_instance_ada.model = "text-embedding-ada-002"
|
||||
model_instance_ada.model_name = "text-embedding-ada-002"
|
||||
model_instance_ada.provider = "openai"
|
||||
|
||||
# Mock model type instance for ada
|
||||
@@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation:
|
||||
|
||||
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
|
||||
model_instance_cohere = Mock()
|
||||
model_instance_cohere.model = "embed-english-v3.0"
|
||||
model_instance_cohere.model_name = "embed-english-v3.0"
|
||||
model_instance_cohere.provider = "cohere"
|
||||
|
||||
# Mock model type instance for cohere
|
||||
@@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases:
|
||||
- MAX_CHUNKS: 10
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
|
||||
model_type_instance = Mock()
|
||||
@@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance:
|
||||
- MAX_CHUNKS: 10
|
||||
"""
|
||||
model_instance = Mock()
|
||||
model_instance.model = "text-embedding-ada-002"
|
||||
model_instance.model_name = "text-embedding-ada-002"
|
||||
model_instance.provider = "openai"
|
||||
|
||||
model_type_instance = Mock()
|
||||
|
||||
@@ -34,7 +34,7 @@ def create_mock_model_instance():
|
||||
mock_instance.provider_model_bundle.configuration = Mock()
|
||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||
mock_instance.provider = "test-provider"
|
||||
mock_instance.model = "test-model"
|
||||
mock_instance.model_name = "test-model"
|
||||
return mock_instance
|
||||
|
||||
|
||||
@@ -65,7 +65,7 @@ class TestRerankModelRunner:
|
||||
mock_instance.provider_model_bundle.configuration = Mock()
|
||||
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
|
||||
mock_instance.provider = "test-provider"
|
||||
mock_instance.model = "test-model"
|
||||
mock_instance.model_name = "test-model"
|
||||
return mock_instance
|
||||
|
||||
@pytest.fixture
|
||||
|
||||
@@ -199,11 +199,32 @@ def test_mock_config_builder():
|
||||
|
||||
def test_mock_factory_node_type_detection():
|
||||
"""Test that MockNodeFactory correctly identifies nodes to mock."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None, # Will be set by test
|
||||
graph_runtime_state=None, # Will be set by test
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
@@ -288,7 +309,11 @@ def test_workflow_without_auto_mock():
|
||||
|
||||
def test_register_custom_mock_node():
|
||||
"""Test registering a custom mock implementation for a node type."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
from .test_mock_factory import MockNodeFactory
|
||||
|
||||
@@ -298,9 +323,25 @@ def test_register_custom_mock_node():
|
||||
# Custom mock implementation
|
||||
pass
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import datetime
|
||||
import time
|
||||
from collections.abc import Iterable
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -82,7 +82,7 @@ def _build_branching_graph(
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
@@ -101,6 +101,8 @@ def _build_branching_graph(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
import datetime
|
||||
import time
|
||||
from unittest import mock
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
@@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
|
||||
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
|
||||
llm_data = LLMNodeData(
|
||||
title=title,
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
|
||||
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
prompt_template=[
|
||||
LLMNodeChatModelMessage(
|
||||
text=prompt_text,
|
||||
@@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import time
|
||||
from unittest import mock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMMode
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
@@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=mock_config,
|
||||
credentials_provider=mock.Mock(),
|
||||
model_factory=mock.Mock(),
|
||||
)
|
||||
return llm_node
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks node
|
||||
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
@@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
NodeType.CODE: MockCodeNode,
|
||||
}
|
||||
|
||||
def create_node(self, node_config: dict[str, Any]) -> Node:
|
||||
def create_node(self, node_config: Mapping[str, Any]) -> Node:
|
||||
"""
|
||||
Create a node instance, using mock implementations for third-party service nodes.
|
||||
|
||||
@@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
|
||||
mock_config=self.mock_config,
|
||||
http_request_config=self._http_request_config,
|
||||
)
|
||||
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
mock_config=self.mock_config,
|
||||
credentials_provider=self._llm_credentials_provider,
|
||||
model_factory=self._llm_model_factory,
|
||||
)
|
||||
else:
|
||||
mock_instance = mock_class(
|
||||
id=node_id,
|
||||
|
||||
@@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo
|
||||
|
||||
def test_mock_factory_registers_iteration_node():
|
||||
"""Test that MockNodeFactory has iteration node registered."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
# Create a MockNodeFactory instance
|
||||
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
# Check that iteration node is registered
|
||||
assert NodeType.ITERATION in factory._mock_node_types
|
||||
|
||||
@@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
|
||||
import time
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Optional
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
@@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
|
||||
from core.workflow.nodes.http_request import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.llm import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
|
||||
from core.workflow.nodes.question_classifier import QuestionClassifierNode
|
||||
from core.workflow.nodes.template_transform import TemplateTransformNode
|
||||
@@ -42,6 +44,10 @@ class MockNodeMixin:
|
||||
mock_config: Optional["MockConfig"] = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
if isinstance(self, (LLMNode, QuestionClassifierNode)):
|
||||
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
|
||||
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
|
||||
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
|
||||
@@ -101,11 +101,32 @@ def test_node_mock_config():
|
||||
|
||||
def test_mock_factory_detection():
|
||||
"""Test MockNodeFactory node type detection."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
print("Testing MockNodeFactory detection...")
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
@@ -133,11 +154,32 @@ def test_mock_factory_detection():
|
||||
|
||||
def test_mock_factory_registration():
|
||||
"""Test registering and unregistering mock node types."""
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
print("Testing MockNodeFactory registration...")
|
||||
|
||||
graph_init_params = GraphInitParams(
|
||||
tenant_id="test",
|
||||
app_id="test",
|
||||
workflow_id="test",
|
||||
graph_config={},
|
||||
user_id="test",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
call_depth=0,
|
||||
)
|
||||
graph_runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
|
||||
start_at=0,
|
||||
total_tokens=0,
|
||||
node_run_steps=0,
|
||||
)
|
||||
factory = MockNodeFactory(
|
||||
graph_init_params=None,
|
||||
graph_runtime_state=None,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
mock_config=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
|
||||
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
|
||||
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
|
||||
from core.model_runtime.entities.common_entities import I18nObject
|
||||
@@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
|
||||
)
|
||||
from core.workflow.nodes.llm.file_saver import LLMFileSaver
|
||||
from core.workflow.nodes.llm.node import LLMNode
|
||||
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from models.enums import UserFrom
|
||||
@@ -100,6 +102,8 @@ def llm_node(
|
||||
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
|
||||
) -> LLMNode:
|
||||
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@@ -109,13 +113,29 @@ def llm_node(
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def model_config():
|
||||
def model_config(monkeypatch):
|
||||
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
|
||||
|
||||
def mock_plugin_model_providers(_self):
|
||||
providers = MockModelClass().fetch_model_providers("test")
|
||||
for provider in providers:
|
||||
provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
|
||||
return providers
|
||||
|
||||
monkeypatch.setattr(
|
||||
ModelProviderFactory,
|
||||
"get_plugin_model_providers",
|
||||
mock_plugin_model_providers,
|
||||
)
|
||||
|
||||
# Create actual provider and model type instances
|
||||
model_provider_factory = ModelProviderFactory(tenant_id="test")
|
||||
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
|
||||
@@ -125,7 +145,7 @@ def model_config():
|
||||
provider_model_bundle = ProviderModelBundle(
|
||||
configuration=ProviderConfiguration(
|
||||
tenant_id="1",
|
||||
provider=provider_instance,
|
||||
provider=provider_instance.declaration,
|
||||
preferred_provider_type=ProviderType.CUSTOM,
|
||||
using_provider_type=ProviderType.CUSTOM,
|
||||
system_configuration=SystemConfiguration(enabled=False),
|
||||
@@ -153,6 +173,89 @@ def model_config():
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
|
||||
provider_model_bundle = model_config.provider_model_bundle
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
provider_model = mock.MagicMock()
|
||||
|
||||
model_instance = mock.MagicMock(
|
||||
model_type_instance=model_type_instance,
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
)
|
||||
|
||||
mock_credentials_provider.fetch.return_value = {"api_key": "test"}
|
||||
mock_model_factory.init_model_instance.return_value = model_instance
|
||||
|
||||
with (
|
||||
mock.patch.object(
|
||||
provider_model_bundle.configuration.__class__,
|
||||
"get_provider_model",
|
||||
return_value=provider_model,
|
||||
),
|
||||
mock.patch.object(
|
||||
model_type_instance.__class__,
|
||||
"get_model_schema",
|
||||
return_value=model_config.model_schema,
|
||||
),
|
||||
):
|
||||
fetch_model_config(
|
||||
node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
)
|
||||
|
||||
mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||
mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
|
||||
provider_model.raise_for_status.assert_called_once()
|
||||
|
||||
|
||||
def test_dify_model_access_adapters_call_managers():
|
||||
mock_provider_manager = mock.MagicMock()
|
||||
mock_model_manager = mock.MagicMock()
|
||||
mock_configurations = mock.MagicMock()
|
||||
mock_provider_configuration = mock.MagicMock()
|
||||
mock_provider_model = mock.MagicMock()
|
||||
|
||||
mock_configurations.get.return_value = mock_provider_configuration
|
||||
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
|
||||
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
|
||||
|
||||
credentials_provider = DifyCredentialsProvider(
|
||||
tenant_id="tenant",
|
||||
provider_manager=mock_provider_manager,
|
||||
)
|
||||
model_factory = DifyModelFactory(
|
||||
tenant_id="tenant",
|
||||
model_manager=mock_model_manager,
|
||||
)
|
||||
|
||||
mock_provider_manager.get_configurations.return_value = mock_configurations
|
||||
|
||||
credentials_provider.fetch("openai", "gpt-3.5-turbo")
|
||||
model_factory.init_model_instance("openai", "gpt-3.5-turbo")
|
||||
|
||||
mock_provider_manager.get_configurations.assert_called_once_with("tenant")
|
||||
mock_configurations.get.assert_called_once_with("openai")
|
||||
mock_provider_configuration.get_provider_model.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_provider_configuration.get_current_credentials.assert_called_once_with(
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
mock_provider_model.raise_for_status.assert_called_once()
|
||||
mock_model_manager.get_model_instance.assert_called_once_with(
|
||||
tenant_id="tenant",
|
||||
provider="openai",
|
||||
model_type=ModelType.LLM,
|
||||
model="gpt-3.5-turbo",
|
||||
)
|
||||
|
||||
|
||||
def test_fetch_files_with_file_segment():
|
||||
file = File(
|
||||
id="1",
|
||||
@@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):
|
||||
@pytest.fixture
|
||||
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
|
||||
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
|
||||
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
|
||||
mock_model_factory = mock.MagicMock(spec=ModelFactory)
|
||||
node_config = {
|
||||
"id": "1",
|
||||
"data": llm_node_data.model_dump(),
|
||||
@@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
|
||||
config=node_config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
credentials_provider=mock_credentials_provider,
|
||||
model_factory=mock_model_factory,
|
||||
llm_file_saver=mock_file_saver,
|
||||
)
|
||||
return node, mock_file_saver
|
||||
|
||||
@@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
|
||||
|
||||
# Mock embedding model
|
||||
mock_embedding_model = Mock()
|
||||
mock_embedding_model.model = "text-embedding-ada-002"
|
||||
mock_embedding_model.model_name = "text-embedding-ada-002"
|
||||
mock_embedding_model.provider = "openai"
|
||||
mock_embedding_model.credentials = {}
|
||||
|
||||
mock_model_schema = Mock()
|
||||
mock_model_schema.features = []
|
||||
|
||||
mock_text_embedding_model = Mock()
|
||||
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
|
||||
mock_embedding_model.model_type_instance = mock_text_embedding_model
|
||||
|
||||
mock_model_instance = Mock()
|
||||
mock_model_instance.get_model_instance.return_value = mock_embedding_model
|
||||
|
||||
@@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
|
||||
Mock: Embedding model mock with model and provider attributes
|
||||
"""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
|
||||
@@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
|
||||
# Assert
|
||||
assert result.indexing_technique == "high_quality"
|
||||
assert result.embedding_model_provider == embedding_model.provider
|
||||
assert result.embedding_model == embedding_model.model
|
||||
assert result.embedding_model == embedding_model.model_name
|
||||
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
|
||||
)
|
||||
|
||||
@@ -65,7 +65,7 @@ class DatasetUpdateTestDataFactory:
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.model_name = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user