Compare commits

...

11 Commits

Author SHA1 Message Date
-LAN-
27ee32088d chore: code format 2026-02-25 17:00:03 +08:00
-LAN-
6dcca77820 refactor: use model_name directly in dataset service 2026-02-25 16:59:28 +08:00
-LAN-
8ff55da8b7 refactor(model-instance): use model_name consistently 2026-02-25 16:59:27 +08:00
-LAN-
3d74218034 fix(rag): support legacy model attr in mock model instances 2026-02-25 16:59:27 +08:00
-LAN-
b94a338636 chore: keep redis broadcast subscription unchanged 2026-02-25 16:59:26 +08:00
-LAN-
74a32a7715 fix(type-check): resolve style job typing failures 2026-02-25 16:59:26 +08:00
-LAN-
dd68b3608e refactor(node_factory): Make factory simpler
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-25 16:59:26 +08:00
-LAN-
e7c82f1158 refactor(model_schema): rename model to model_name
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-25 16:59:25 +08:00
-LAN-
d7cdbd6cca feat(llm_node): use model instance more
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-25 16:59:25 +08:00
-LAN-
e74d3791dc chore: rename and comments 2026-02-25 16:59:24 +08:00
-LAN-
90bab9c8a3 feat(llm): Extract CredentialsProvider and ModelFactory
Signed-off-by: -LAN- <laipz8200@outlook.com>
2026-02-25 16:59:23 +08:00
38 changed files with 675 additions and 178 deletions

View File

@@ -91,7 +91,6 @@ forbidden_modules =
     core.logging
     core.mcp
     core.memory
-    core.model_manager
    core.moderation
     core.ops
     core.plugin
@@ -121,6 +120,7 @@ ignore_imports =
     core.workflow.nodes.llm.llm_utils -> configs
     core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
     core.workflow.nodes.llm.llm_utils -> core.model_manager
+    core.workflow.nodes.llm.protocols -> core.model_manager
     core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
     core.workflow.nodes.llm.llm_utils -> models.model
     core.workflow.nodes.llm.llm_utils -> models.provider

View File

@@ -112,7 +112,7 @@ class BaseAgentRunner(AppRunner):
# check if model supports stream tool call
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
features = model_schema.features if model_schema and model_schema.features else []
self.stream_tool_call = ModelFeature.STREAM_TOOL_CALL in features
self.files = application_generate_entity.files if ModelFeature.VISION in features else []

View File

@@ -245,7 +245,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
iteration_step += 1
yield LLMResultChunk(
model=model_instance.model,
model=model_instance.model_name,
prompt_messages=prompt_messages,
delta=LLMResultChunkDelta(
index=0, message=AssistantPromptMessage(content=final_answer), usage=llm_usage["usage"]
@@ -268,7 +268,7 @@ class CotAgentRunner(BaseAgentRunner, ABC):
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
model=model_instance.model_name,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),

View File

@@ -178,7 +178,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
)
yield LLMResultChunk(
model=model_instance.model,
model=model_instance.model_name,
prompt_messages=result.prompt_messages,
system_fingerprint=result.system_fingerprint,
delta=LLMResultChunkDelta(
@@ -308,7 +308,7 @@ class FunctionCallAgentRunner(BaseAgentRunner):
self.queue_manager.publish(
QueueMessageEndEvent(
llm_result=LLMResult(
model=model_instance.model,
model=model_instance.model_name,
prompt_messages=prompt_messages,
message=AssistantPromptMessage(content=final_answer),
usage=llm_usage["usage"] or LLMUsage.empty_usage(),

View File

@@ -178,7 +178,7 @@ class AgentChatAppRunner(AppRunner):
# change function call strategy based on LLM model
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")

View File

@@ -0,0 +1 @@
"""LLM-related application services."""

View File

@@ -0,0 +1,103 @@
from __future__ import annotations

from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.errors.error import ProviderTokenNotInitError
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.provider_manager import ProviderManager
from core.workflow.nodes.llm.entities import ModelConfig
from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory


class DifyCredentialsProvider:
    tenant_id: str
    provider_manager: ProviderManager

    def __init__(self, tenant_id: str, provider_manager: ProviderManager | None = None) -> None:
        self.tenant_id = tenant_id
        self.provider_manager = provider_manager or ProviderManager()

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        provider_configurations = self.provider_manager.get_configurations(self.tenant_id)
        provider_configuration = provider_configurations.get(provider_name)
        if not provider_configuration:
            raise ValueError(f"Provider {provider_name} does not exist.")
        provider_model = provider_configuration.get_provider_model(model_type=ModelType.LLM, model=model_name)
        if provider_model is None:
            raise ModelNotExistError(f"Model {model_name} not exist.")
        provider_model.raise_for_status()
        credentials = provider_configuration.get_current_credentials(model_type=ModelType.LLM, model=model_name)
        if credentials is None:
            raise ProviderTokenNotInitError(f"Model {model_name} credentials is not initialized.")
        return credentials


class DifyModelFactory:
    tenant_id: str
    model_manager: ModelManager

    def __init__(self, tenant_id: str, model_manager: ModelManager | None = None) -> None:
        self.tenant_id = tenant_id
        self.model_manager = model_manager or ModelManager()

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        return self.model_manager.get_model_instance(
            tenant_id=self.tenant_id,
            provider=provider_name,
            model_type=ModelType.LLM,
            model=model_name,
        )


def build_dify_model_access(tenant_id: str) -> tuple[CredentialsProvider, ModelFactory]:
    return (
        DifyCredentialsProvider(tenant_id=tenant_id),
        DifyModelFactory(tenant_id=tenant_id),
    )


def fetch_model_config(
    *,
    node_data_model: ModelConfig,
    credentials_provider: CredentialsProvider,
    model_factory: ModelFactory,
) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
    if not node_data_model.mode:
        raise LLMModeRequiredError("LLM mode is required.")

    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)

    provider_model_bundle = model_instance.provider_model_bundle
    provider_model = provider_model_bundle.configuration.get_provider_model(
        model=node_data_model.name,
        model_type=ModelType.LLM,
    )
    if provider_model is None:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
    provider_model.raise_for_status()

    stop: list[str] = []
    if "stop" in node_data_model.completion_params:
        stop = node_data_model.completion_params.pop("stop")

    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    if not model_schema:
        raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

    return model_instance, ModelConfigWithCredentialsEntity(
        provider=node_data_model.provider,
        model=node_data_model.name,
        model_schema=model_schema,
        mode=node_data_model.mode,
        provider_model_bundle=provider_model_bundle,
        credentials=credentials,
        parameters=node_data_model.completion_params,
        stop=stop,
    )
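
A minimal usage sketch of the new module above, assuming a wired tenant (the tenant id and model values below are invented; `ModelConfig` takes the same fields the tests later in this diff use):

```python
from core.app.llm.model_access import build_dify_model_access, fetch_model_config
from core.workflow.nodes.llm.entities import ModelConfig

# Wire the production provider/factory pair for one tenant (id is made up).
credentials_provider, model_factory = build_dify_model_access(tenant_id="tenant-123")

# Resolve a concrete model instance plus its credentialed config entity.
model_instance, model_config = fetch_model_config(
    node_data_model=ModelConfig(provider="openai", name="gpt-4o", mode="chat", completion_params={}),
    credentials_provider=credentials_provider,
    model_factory=model_factory,
)
```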

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, final
 from typing_extensions import override

 from configs import dify_config
+from core.app.llm.model_access import build_dify_model_access
 from core.helper.code_executor.code_executor import CodeExecutor
 from core.helper.code_executor.code_node_provider import CodeNodeProvider
 from core.helper.ssrf_proxy import ssrf_proxy
@@ -18,8 +19,13 @@ from core.workflow.nodes.code.limits import CodeNodeLimits
 from core.workflow.nodes.document_extractor import DocumentExtractorNode, UnstructuredApiConfig
 from core.workflow.nodes.http_request import HttpRequestNode, build_http_request_config
 from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
+from core.workflow.nodes.llm.node import LLMNode
 from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
-from core.workflow.nodes.template_transform.template_renderer import CodeExecutorJinja2TemplateRenderer
+from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
+from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
+from core.workflow.nodes.template_transform.template_renderer import (
+    CodeExecutorJinja2TemplateRenderer,
+)
 from core.workflow.nodes.template_transform.template_transform_node import TemplateTransformNode

 if TYPE_CHECKING:
@@ -75,6 +81,8 @@ class DifyNodeFactory(NodeFactory):
             ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
         )
+        self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(graph_init_params.tenant_id)
@override
def create_node(self, node_config: NodeConfigDict) -> Node:
"""
@@ -140,6 +148,16 @@ class DifyNodeFactory(NodeFactory):
                 file_manager=self._http_request_file_manager,
             )

+        if node_type == NodeType.LLM:
+            return LLMNode(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                credentials_provider=self._llm_credentials_provider,
+                model_factory=self._llm_model_factory,
+            )
+
         if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
             return KnowledgeRetrievalNode(
                 id=node_id,
@@ -158,6 +176,26 @@ class DifyNodeFactory(NodeFactory):
                 unstructured_api_config=self._document_extractor_unstructured_api_config,
             )

+        if node_type == NodeType.QUESTION_CLASSIFIER:
+            return QuestionClassifierNode(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                credentials_provider=self._llm_credentials_provider,
+                model_factory=self._llm_model_factory,
+            )
+
+        if node_type == NodeType.PARAMETER_EXTRACTOR:
+            return ParameterExtractorNode(
+                id=node_id,
+                config=node_config,
+                graph_init_params=self.graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+                credentials_provider=self._llm_credentials_provider,
+                model_factory=self._llm_model_factory,
+            )
+
         return node_class(
             id=node_id,
             config=node_config,
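
Because all three LLM-backed node types now receive their collaborators through the constructor, tests can satisfy the protocols with plain objects instead of patching `ModelManager` globals. A sketch under that assumption (the fake classes below are hypothetical, not part of the diff):

```python
from typing import Any

class FakeCredentialsProvider:
    """Hypothetical test double: returns canned credentials, no DB access."""

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return {"api_key": "test-key"}

class FakeModelFactory:
    """Hypothetical test double: would hand back a stubbed ModelInstance."""

    def init_model_instance(self, provider_name: str, model_name: str) -> Any:
        raise NotImplementedError("return a stub ModelInstance here")
```

Structural typing means neither class needs to inherit from the protocols to be accepted by `LLMNode`.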

View File

@@ -35,7 +35,7 @@ class ModelInstance:
def __init__(self, provider_model_bundle: ProviderModelBundle, model: str):
self.provider_model_bundle = provider_model_bundle
self.model = model
self.model_name = model
self.provider = provider_model_bundle.configuration.provider.provider
self.credentials = self._fetch_credentials_from_bundle(provider_model_bundle, model)
self.model_type_instance = self.provider_model_bundle.model_type_instance
@@ -163,7 +163,7 @@ class ModelInstance:
Union[LLMResult, Generator],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
model_parameters=model_parameters,
@@ -191,7 +191,7 @@ class ModelInstance:
int,
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
model=self.model_name,
credentials=self.credentials,
prompt_messages=prompt_messages,
tools=tools,
@@ -215,7 +215,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
texts=texts,
user=user,
@@ -243,7 +243,7 @@ class ModelInstance:
EmbeddingResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
multimodel_documents=multimodel_documents,
user=user,
@@ -264,7 +264,7 @@ class ModelInstance:
list[int],
self._round_robin_invoke(
function=self.model_type_instance.get_num_tokens,
model=self.model,
model=self.model_name,
credentials=self.credentials,
texts=texts,
),
@@ -294,7 +294,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
@@ -328,7 +328,7 @@ class ModelInstance:
RerankResult,
self._round_robin_invoke(
function=self.model_type_instance.invoke_multimodal_rerank,
model=self.model,
model=self.model_name,
credentials=self.credentials,
query=query,
docs=docs,
@@ -352,7 +352,7 @@ class ModelInstance:
bool,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
text=text,
user=user,
@@ -373,7 +373,7 @@ class ModelInstance:
str,
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
file=file,
user=user,
@@ -396,7 +396,7 @@ class ModelInstance:
Iterable[bytes],
self._round_robin_invoke(
function=self.model_type_instance.invoke,
model=self.model,
model=self.model_name,
credentials=self.credentials,
content_text=content_text,
user=user,
@@ -469,7 +469,7 @@ class ModelInstance:
if not isinstance(self.model_type_instance, TTSModel):
raise Exception("Model type instance is not TTSModel")
return self.model_type_instance.get_tts_model_voices(
model=self.model, credentials=self.credentials, language=language
model=self.model_name, credentials=self.credentials, language=language
)
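
The rename from `model` to `model_name` is purely mechanical inside `ModelInstance`, but external callers and mocks that still read `.model` would break. A deprecation alias is one general way to bridge such a rename; this is a sketch of the pattern, not something this diff adds (the node code later in this diff instead falls back via `getattr(model_instance, "model_name", None)`):

```python
import warnings

class RenamedField:
    """Sketch of a rename with a deprecation shim (illustrative only)."""

    def __init__(self, model: str):
        self.model_name = model

    @property
    def model(self) -> str:
        warnings.warn("`.model` is deprecated; use `.model_name`", DeprecationWarning, stacklevel=2)
        return self.model_name
```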

View File

@@ -47,7 +47,9 @@ class AgentHistoryPromptTransform(PromptTransform):
model_type_instance = cast(LargeLanguageModel, model_type_instance)
curr_message_tokens = model_type_instance.get_num_tokens(
-            self.memory.model_instance.model, self.memory.model_instance.credentials, self.history_messages
+            self.model_config.model,
+            self.model_config.credentials,
+            self.history_messages,
)
if curr_message_tokens <= max_token_limit:
return self.history_messages
@@ -63,7 +65,9 @@ class AgentHistoryPromptTransform(PromptTransform):
# a message is start with UserPromptMessage
if isinstance(prompt_message, UserPromptMessage):
curr_message_tokens = model_type_instance.get_num_tokens(
-                self.memory.model_instance.model, self.memory.model_instance.credentials, prompt_messages
+                self.model_config.model,
+                self.model_config.credentials,
+                prompt_messages,
)
# if current message token is overflow, drop all the prompts in current message and break
if curr_message_tokens > max_token_limit:
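
The loop above recounts tokens as it walks the history and drops whole user-started message groups from the oldest end until the window fits. A simplified, self-contained sketch of that budget check (the callable and limit are placeholders):

```python
from collections.abc import Callable, Sequence

def trim_history(
    messages: Sequence[str],
    count_tokens: Callable[[Sequence[str]], int],
    max_token_limit: int,
) -> list[str]:
    """Drop the oldest messages until the running token count fits the limit."""
    trimmed = list(messages)
    while trimmed and count_tokens(trimmed) > max_token_limit:
        trimmed.pop(0)  # oldest first, mirroring the windowing idea above
    return trimmed
```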

View File

@@ -35,7 +35,9 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
-                    model_name=self._model_instance.model, hash=hash, provider_name=self._model_instance.provider
+                    model_name=self._model_instance.model_name,
+                    hash=hash,
+                    provider_name=self._model_instance.provider,
)
.first()
)
@@ -52,7 +54,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
self._model_instance.model_name, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@@ -87,7 +89,7 @@ class CacheEmbedding(Embeddings):
hash = helper.generate_text_hash(texts[i])
if hash not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
model_name=self._model_instance.model_name,
hash=hash,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -114,7 +116,9 @@ class CacheEmbedding(Embeddings):
embedding = (
db.session.query(Embedding)
.filter_by(
-                    model_name=self._model_instance.model, hash=file_id, provider_name=self._model_instance.provider
+                    model_name=self._model_instance.model_name,
+                    hash=file_id,
+                    provider_name=self._model_instance.provider,
)
.first()
)
@@ -131,7 +135,7 @@ class CacheEmbedding(Embeddings):
try:
model_type_instance = cast(TextEmbeddingModel, self._model_instance.model_type_instance)
model_schema = model_type_instance.get_model_schema(
self._model_instance.model, self._model_instance.credentials
self._model_instance.model_name, self._model_instance.credentials
)
max_chunks = (
model_schema.model_properties[ModelPropertyKey.MAX_CHUNKS]
@@ -168,7 +172,7 @@ class CacheEmbedding(Embeddings):
file_id = multimodel_documents[i]["file_id"]
if file_id not in cache_embeddings:
embedding_cache = Embedding(
model_name=self._model_instance.model,
model_name=self._model_instance.model_name,
hash=file_id,
provider_name=self._model_instance.provider,
embedding=pickle.dumps(n_embedding, protocol=pickle.HIGHEST_PROTOCOL),
@@ -190,7 +194,7 @@ class CacheEmbedding(Embeddings):
"""Embed query text."""
# use doc embedding cache or store if not exists
hash = helper.generate_text_hash(text)
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{hash}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{hash}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
@@ -233,7 +237,7 @@ class CacheEmbedding(Embeddings):
"""Embed multimodal documents."""
# use doc embedding cache or store if not exists
file_id = multimodel_document["file_id"]
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model}_{file_id}"
embedding_cache_key = f"{self._model_instance.provider}_{self._model_instance.model_name}_{file_id}"
embedding = redis_client.get(embedding_cache_key)
if embedding:
redis_client.expire(embedding_cache_key, 600)
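
Both Redis cache keys above now embed `model_name` rather than the removed `model` attribute, so the key shape is unchanged; only the attribute it reads moved. For reference, the key layout (values invented):

```python
provider = "openai"
model_name = "text-embedding-ada-002"
text_hash = "9f86d081884c7d65"  # stand-in for helper.generate_text_hash(text)

# provider + model + content hash; switching models yields a disjoint key space.
embedding_cache_key = f"{provider}_{model_name}_{text_hash}"
```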

View File

@@ -38,7 +38,7 @@ class RerankModelRunner(BaseRerankRunner):
is_support_vision = model_manager.check_model_support_vision(
tenant_id=self.rerank_model_instance.provider_model_bundle.configuration.tenant_id,
provider=self.rerank_model_instance.provider,
model=self.rerank_model_instance.model,
model=self.rerank_model_instance.model_name,
model_type=ModelType.RERANK,
)
if not is_support_vision:

View File

@@ -47,7 +47,7 @@ class ModelInvocationUtils:
raise InvokeModelError("Model not found")
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if not schema:
raise InvokeModelError("No model schema found")

View File

@@ -8,7 +8,7 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
@@ -17,6 +17,8 @@ from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegme
from core.workflow.enums import SystemVariableKey
from core.workflow.file.models import File
from core.workflow.nodes.llm.entities import ModelConfig
+from core.workflow.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -24,49 +26,46 @@ from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
-from .exc import InvalidVariableTypeError, LLMModeRequiredError, ModelNotExistError
+from .exc import InvalidVariableTypeError


 def fetch_model_config(
-    tenant_id: str, node_data_model: ModelConfig
+    *,
+    node_data_model: ModelConfig,
+    credentials_provider: CredentialsProvider,
+    model_factory: ModelFactory,
 ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
     if not node_data_model.mode:
         raise LLMModeRequiredError("LLM mode is required.")

-    model = ModelManager().get_model_instance(
-        tenant_id=tenant_id,
-        model_type=ModelType.LLM,
-        provider=node_data_model.provider,
-    model.model_type_instance = cast(LargeLanguageModel, model.model_type_instance)
-    # check model
-    provider_model = model.provider_model_bundle.configuration.get_provider_model(
-        model=node_data_model.name, model_type=ModelType.LLM
-    )
+    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
+    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)
+    provider_model_bundle = model_instance.provider_model_bundle
+    provider_model = provider_model_bundle.configuration.get_provider_model(
+        model=node_data_model.name,
+        model_type=ModelType.LLM,
+    )
     if provider_model is None:
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
     provider_model.raise_for_status()

     # model config
     stop: list[str] = []
     if "stop" in node_data_model.completion_params:
         stop = node_data_model.completion_params.pop("stop")

-    model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
+    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
     if not model_schema:
         raise ModelNotExistError(f"Model {node_data_model.name} not exist.")

-    return model, ModelConfigWithCredentialsEntity(
+    model_instance.model_type_instance = cast(LargeLanguageModel, model_instance.model_type_instance)
+    return model_instance, ModelConfigWithCredentialsEntity(
         provider=node_data_model.provider,
         model=node_data_model.name,
         model_schema=model_schema,
         mode=node_data_model.mode,
-        provider_model_bundle=model.provider_model_bundle,
-        credentials=model.credentials,
+        provider_model_bundle=provider_model_bundle,
+        credentials=credentials,
         parameters=node_data_model.completion_params,
         stop=stop,
     )
@@ -131,7 +130,7 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model)
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
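
The quota branch at the end of this file only changes which attribute names the model, but the surrounding arithmetic is worth spelling out. A self-contained sketch of the same decision (enum values and credit weights invented):

```python
from enum import Enum

class QuotaUnit(Enum):  # mirrors the shape used above; values illustrative
    TOKENS = "tokens"
    CREDITS = "credits"
    TIMES = "times"

def used_quota_for(quota_unit: QuotaUnit, total_tokens: int, model_credits: int) -> int:
    if quota_unit is QuotaUnit.TOKENS:
        return total_tokens   # token-metered plans charge raw usage
    if quota_unit is QuotaUnit.CREDITS:
        return model_credits  # credit plans charge a per-model weight
    return 1                  # otherwise each call costs one unit
```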

View File

@@ -16,7 +16,7 @@ from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
@@ -38,11 +38,7 @@ from core.model_runtime.entities.message_entities import (
SystemPromptMessage,
UserPromptMessage,
)
-from core.model_runtime.entities.model_entities import (
-    ModelFeature,
-    ModelPropertyKey,
-    ModelType,
-)
+from core.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import CompletionModelPromptTemplate, MemoryConfig
from core.prompt.utils.prompt_message_util import PromptMessageUtil
@@ -76,6 +72,7 @@ from core.workflow.node_events import (
from core.workflow.nodes.base.entities import VariableSelector
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import VariablePool
from extensions.ext_database import db
from models.dataset import SegmentAttachmentBinding
@@ -93,7 +90,6 @@ from .exc import (
     InvalidVariableTypeError,
     LLMNodeError,
     MemoryRolePrefixRequiredError,
-    ModelNotExistError,
     NoPromptFoundError,
     TemplateTypeNotSupportError,
     VariableNotFoundError,
@@ -118,6 +114,8 @@ class LLMNode(Node[LLMNodeData]):
     _file_outputs: list[File]
     _llm_file_saver: LLMFileSaver
+    _credentials_provider: CredentialsProvider
+    _model_factory: ModelFactory

     def __init__(
         self,
@@ -126,6 +124,8 @@ class LLMNode(Node[LLMNodeData]):
         graph_init_params: GraphInitParams,
         graph_runtime_state: GraphRuntimeState,
         *,
+        credentials_provider: CredentialsProvider,
+        model_factory: ModelFactory,
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -137,6 +137,9 @@ class LLMNode(Node[LLMNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
+
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
                 user_id=graph_init_params.user_id,
@@ -199,10 +202,21 @@ class LLMNode(Node[LLMNodeData]):
             node_inputs["#context_files#"] = [file.model_dump() for file in context_files]

         # fetch model config
-        model_instance, model_config = LLMNode._fetch_model_config(
+        model_instance, model_config = self._fetch_model_config(
             node_data_model=self.node_data.model,
-            tenant_id=self.tenant_id,
         )
+        model_name = getattr(model_instance, "model_name", None)
+        if not isinstance(model_name, str):
+            model_name = model_config.model
+        model_provider = getattr(model_instance, "provider", None)
+        if not isinstance(model_provider, str):
+            model_provider = model_config.provider
+        model_schema = model_instance.model_type_instance.get_model_schema(
+            model_name,
+            model_instance.credentials,
+        )
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {model_name}")
# fetch memory
memory = llm_utils.fetch_memory(
@@ -225,14 +239,16 @@ class LLMNode(Node[LLMNodeData]):
             sys_files=files,
             context=context,
             memory=memory,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=self.node_data.model.completion_params,
+            stop=model_config.stop,
             prompt_template=self.node_data.prompt_template,
             memory_config=self.node_data.memory,
             vision_enabled=self.node_data.vision.enabled,
             vision_detail=self.node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-            tenant_id=self.tenant_id,
             context_files=context_files,
         )
@@ -286,14 +302,14 @@ class LLMNode(Node[LLMNodeData]):
             structured_output = event

         process_data = {
-            "model_mode": model_config.mode,
+            "model_mode": self.node_data.model.mode,
             "prompts": PromptMessageUtil.prompt_messages_to_prompt_for_saving(
-                model_mode=model_config.mode, prompt_messages=prompt_messages
+                model_mode=self.node_data.model.mode, prompt_messages=prompt_messages
             ),
             "usage": jsonable_encoder(usage),
             "finish_reason": finish_reason,
-            "model_provider": model_config.provider,
-            "model_name": model_config.model,
+            "model_provider": model_provider,
+            "model_name": model_name,
         }
outputs = {
@@ -755,21 +771,18 @@ class LLMNode(Node[LLMNodeData]):
         return None

-    @staticmethod
     def _fetch_model_config(
+        self,
         *,
         node_data_model: ModelConfig,
-        tenant_id: str,
     ) -> tuple[ModelInstance, ModelConfigWithCredentialsEntity]:
         model, model_config_with_cred = llm_utils.fetch_model_config(
-            tenant_id=tenant_id, node_data_model=node_data_model
+            node_data_model=node_data_model,
+            credentials_provider=self._credentials_provider,
+            model_factory=self._model_factory,
         )
         completion_params = model_config_with_cred.parameters
-        model_schema = model.model_type_instance.get_model_schema(node_data_model.name, model.credentials)
-        if not model_schema:
-            raise ModelNotExistError(f"Model {node_data_model.name} not exist.")
-        model_config_with_cred.parameters = completion_params
         # NOTE(-LAN-): This line modify the `self.node_data.model`, which is used in `_invoke_llm()`.
         node_data_model.completion_params = completion_params
@@ -782,14 +795,16 @@ class LLMNode(Node[LLMNodeData]):
         sys_files: Sequence[File],
         context: str | None = None,
         memory: TokenBufferMemory | None = None,
-        model_config: ModelConfigWithCredentialsEntity,
+        model_instance: ModelInstance,
+        model_schema: AIModelEntity,
+        model_parameters: Mapping[str, Any],
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+        stop: Sequence[str] | None = None,
         memory_config: MemoryConfig | None = None,
         vision_enabled: bool = False,
         vision_detail: ImagePromptMessageContent.DETAIL,
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
-        tenant_id: str,
         context_files: list[File] | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
@@ -810,7 +825,9 @@ class LLMNode(Node[LLMNodeData]):
         memory_messages = _handle_memory_chat_mode(
             memory=memory,
             memory_config=memory_config,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
         )
         # Extend prompt_messages with memory messages
         prompt_messages.extend(memory_messages)
@@ -847,7 +864,9 @@ class LLMNode(Node[LLMNodeData]):
         memory_text = _handle_memory_completion_mode(
             memory=memory,
             memory_config=memory_config,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
         )
         # Insert histories into the prompt
         prompt_content = prompt_messages[0].content
@@ -924,7 +943,7 @@ class LLMNode(Node[LLMNodeData]):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_config.model_schema.features:
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
@@ -934,19 +953,19 @@ class LLMNode(Node[LLMNodeData]):
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_config.model_schema.features
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_config.model_schema.features
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_config.model_schema.features
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_config.model_schema.features
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
@@ -965,19 +984,7 @@ class LLMNode(Node[LLMNodeData]):
                 "Please ensure a prompt is properly configured before proceeding."
             )

-        model = ModelManager().get_model_instance(
-            tenant_id=tenant_id,
-            model_type=ModelType.LLM,
-            provider=model_config.provider,
-            model=model_config.model,
-        )
-        model_schema = model.model_type_instance.get_model_schema(
-            model=model_config.model,
-            credentials=model.credentials,
-        )
-        if not model_schema:
-            raise ModelNotExistError(f"Model {model_config.model} not exist.")
-        return filtered_prompt_messages, model_config.stop
+        return filtered_prompt_messages, stop
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -1306,26 +1313,26 @@ def _render_jinja2_message(

 def _calculate_rest_token(
-    *, prompt_messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
+    *,
+    prompt_messages: list[PromptMessage],
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> int:
     rest_tokens = 2000

-    model_context_tokens = model_config.model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    model_context_tokens = model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
     if model_context_tokens:
-        model_instance = ModelInstance(
-            provider_model_bundle=model_config.provider_model_bundle, model=model_config.model
-        )
         curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)

         max_tokens = 0
-        for parameter_rule in model_config.model_schema.parameter_rules:
+        for parameter_rule in model_schema.parameter_rules:
             if parameter_rule.name == "max_tokens" or (
                 parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
             ):
                 max_tokens = (
-                    model_config.parameters.get(parameter_rule.name)
-                    or model_config.parameters.get(str(parameter_rule.use_template))
+                    model_parameters.get(parameter_rule.name)
+                    or model_parameters.get(str(parameter_rule.use_template))
                     or 0
                 )
@@ -1339,12 +1346,19 @@ def _handle_memory_chat_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> Sequence[PromptMessage]:
     memory_messages: Sequence[PromptMessage] = []
     # Get messages from memory for chat model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
+        )
         memory_messages = memory.get_history_prompt_messages(
             max_token_limit=rest_tokens,
             message_limit=memory_config.window.size if memory_config.window.enabled else None,
@@ -1356,12 +1370,19 @@ def _handle_memory_completion_mode(
     *,
     memory: TokenBufferMemory | None,
     memory_config: MemoryConfig | None,
-    model_config: ModelConfigWithCredentialsEntity,
+    model_instance: ModelInstance,
+    model_schema: AIModelEntity,
+    model_parameters: Mapping[str, Any],
 ) -> str:
     memory_text = ""
     # Get history text from memory for completion model
     if memory and memory_config:
-        rest_tokens = _calculate_rest_token(prompt_messages=[], model_config=model_config)
+        rest_tokens = _calculate_rest_token(
+            prompt_messages=[],
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=model_parameters,
+        )
         if not memory_config.role_prefix:
             raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
         memory_text = memory.get_history_prompt_text(
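
`_calculate_rest_token` now works from the schema and parameters it is handed rather than re-resolving a `ModelInstance`. Assuming the budget is context size minus the reserved `max_tokens` minus the current prompt's tokens (consistent with the fields used above, though the subtraction itself is outside the shown hunk), a worked example with invented numbers:

```python
model_context_tokens = 8192  # ModelPropertyKey.CONTEXT_SIZE from the schema
curr_message_tokens = 1200   # model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 1024            # "max_tokens" resolved from model_parameters

rest_tokens = model_context_tokens - max_tokens - curr_message_tokens  # 5968
rest_tokens = max(rest_tokens, 0)  # never hand memory a negative budget
```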

View File

@@ -0,0 +1,21 @@
from __future__ import annotations

from typing import Any, Protocol

from core.model_manager import ModelInstance


class CredentialsProvider(Protocol):
    """Port for loading runtime credentials for a provider/model pair."""

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        """Return credentials for the target provider/model or raise a domain error."""
        ...


class ModelFactory(Protocol):
    """Port for creating initialized LLM model instances for execution."""

    def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
        """Create a model instance that is ready for schema lookup and invocation."""
        ...
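
Since these are `typing.Protocol` classes, implementations satisfy them structurally; no inheritance or registration is needed. A small self-contained sketch (the `StaticCredentials` double is hypothetical):

```python
from typing import Any, Protocol

class CredentialsProvider(Protocol):  # re-declared here so the sketch stands alone
    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]: ...

class StaticCredentials:
    """Hypothetical double: no base class, just a matching method signature."""

    def fetch(self, provider_name: str, model_name: str) -> dict[str, Any]:
        return {"api_key": "sk-test"}

creds: CredentialsProvider = StaticCredentials()  # accepted structurally by type checkers
```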

View File

@@ -3,7 +3,7 @@ import json
import logging
import uuid
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import TYPE_CHECKING, Any, cast
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
@@ -60,6 +60,11 @@ from .prompts import (
 logger = logging.getLogger(__name__)

+if TYPE_CHECKING:
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
+    from core.workflow.runtime import GraphRuntimeState
def extract_json(text):
"""
@@ -92,6 +97,27 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
     _model_instance: ModelInstance | None = None
     _model_config: ModelConfigWithCredentialsEntity | None = None

+    _credentials_provider: "CredentialsProvider"
+    _model_factory: "ModelFactory"
+
+    def __init__(
+        self,
+        id: str,
+        config: Mapping[str, Any],
+        graph_init_params: "GraphInitParams",
+        graph_runtime_state: "GraphRuntimeState",
+        *,
+        credentials_provider: "CredentialsProvider",
+        model_factory: "ModelFactory",
+    ) -> None:
+        super().__init__(
+            id=id,
+            config=config,
+            graph_init_params=graph_init_params,
+            graph_runtime_state=graph_runtime_state,
+        )
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
@@ -806,7 +832,9 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         """
         if not self._model_instance or not self._model_config:
             self._model_instance, self._model_config = llm_utils.fetch_model_config(
-                tenant_id=self.tenant_id, node_data_model=node_data_model
+                node_data_model=node_data_model,
+                credentials_provider=self._credentials_provider,
+                model_factory=self._model_factory,
             )
         return self._model_instance, self._model_config
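
The guard above memoizes the first successful lookup, so repeated calls within one node run reuse the same `ModelInstance` and config. The same pattern in isolation (names are illustrative):

```python
from collections.abc import Callable
from typing import Generic, TypeVar

T = TypeVar("T")

class Lazy(Generic[T]):
    """Compute a value on first access, then reuse it (illustrative)."""

    def __init__(self, loader: Callable[[], T]) -> None:
        self._loader = loader
        self._cached: T | None = None

    def get(self) -> T:
        if self._cached is None:
            self._cached = self._loader()  # fetched once, reused afterwards
        return self._cached
```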

View File

@@ -24,6 +24,7 @@ from core.workflow.nodes.base.node import Node
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.nodes.llm import LLMNode, LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, llm_utils
from core.workflow.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from libs.json_in_md_parser import parse_and_check_json_markdown
from .entities import QuestionClassifierNodeData
@@ -49,6 +50,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     _file_outputs: list["File"]
     _llm_file_saver: LLMFileSaver
+    _credentials_provider: "CredentialsProvider"
+    _model_factory: "ModelFactory"
def __init__(
self,
@@ -57,6 +60,8 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         graph_init_params: "GraphInitParams",
         graph_runtime_state: "GraphRuntimeState",
         *,
+        credentials_provider: "CredentialsProvider",
+        model_factory: "ModelFactory",
         llm_file_saver: LLMFileSaver | None = None,
     ):
         super().__init__(
@@ -68,6 +73,9 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # LLM file outputs, used for MultiModal outputs.
         self._file_outputs = []
+        self._credentials_provider = credentials_provider
+        self._model_factory = model_factory
+
         if llm_file_saver is None:
             llm_file_saver = FileSaverImpl(
                 user_id=graph_init_params.user_id,
@@ -89,9 +97,16 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         variables = {"query": query}
         # fetch model config
         model_instance, model_config = llm_utils.fetch_model_config(
-            tenant_id=self.tenant_id,
             node_data_model=node_data.model,
+            credentials_provider=self._credentials_provider,
+            model_factory=self._model_factory,
         )
+        model_schema = model_instance.model_type_instance.get_model_schema(
+            model_instance.model_name,
+            model_instance.credentials,
+        )
+        if not model_schema:
+            raise ValueError(f"Model schema not found for {model_instance.model_name}")
# fetch memory
memory = llm_utils.fetch_memory(
variable_pool=variable_pool,
@@ -133,13 +148,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             prompt_template=prompt_template,
             sys_query="",
             memory=memory,
-            model_config=model_config,
+            model_instance=model_instance,
+            model_schema=model_schema,
+            model_parameters=node_data.model.completion_params,
+            stop=model_config.stop,
             sys_files=files,
             vision_enabled=node_data.vision.enabled,
             vision_detail=node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=[],
-            tenant_id=self.tenant_id,
         )
         result_text = ""
result_text = ""

View File

@@ -1,8 +1,7 @@
import logging
import time
import uuid
from collections.abc import Generator, Mapping, Sequence
-from typing import Any
+from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -11,6 +10,7 @@ from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
+from core.workflow.entities.graph_config import NodeConfigData, NodeConfigDict
from core.workflow.errors import WorkflowNodeRunFailedError
from core.workflow.file.models import File
from core.workflow.graph import Graph
@@ -168,7 +168,8 @@ class WorkflowEntry:
                 graph_init_params=graph_init_params,
                 graph_runtime_state=graph_runtime_state,
             )
-            node = node_factory.create_node(node_config)
+            typed_node_config = cast(dict[str, object], node_config)
+            node = cast(Any, node_factory).create_node(typed_node_config)
             node_cls = type(node)

         try:
@@ -256,7 +257,7 @@ class WorkflowEntry:
@classmethod
def run_free_node(
-        cls, node_data: dict, node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
+        cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@@ -302,16 +303,15 @@ class WorkflowEntry:
         graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())

         # init workflow run state
-        node_config = {
+        node_config: NodeConfigDict = {
             "id": node_id,
-            "data": node_data,
+            "data": cast(NodeConfigData, node_data),
         }

-        node: Node = node_cls(
-            id=str(uuid.uuid4()),
-            config=node_config,
+        node_factory = DifyNodeFactory(
             graph_init_params=graph_init_params,
             graph_runtime_state=graph_runtime_state,
         )
+        node = node_factory.create_node(node_config)
try:
# variable selector to variable mapping

View File

@@ -107,19 +107,19 @@ class AppService:
if model_instance:
if (
model_instance.model == default_model_config["model"]["name"]
model_instance.model_name == default_model_config["model"]["name"]
and model_instance.provider == default_model_config["model"]["provider"]
):
default_model_dict = default_model_config["model"]
else:
llm_model = cast(LargeLanguageModel, model_instance.model_type_instance)
model_schema = llm_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = llm_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if model_schema is None:
raise ValueError(f"model schema not found for model {model_instance.model}")
raise ValueError(f"model schema not found for model {model_instance.model_name}")
default_model_dict = {
"provider": model_instance.provider,
"name": model_instance.model,
"name": model_instance.model_name,
"mode": model_schema.model_properties.get(ModelPropertyKey.MODE),
"completion_params": {},
}

View File

@@ -252,7 +252,7 @@ class DatasetService:
dataset.updated_by = account.id
dataset.tenant_id = tenant_id
dataset.embedding_model_provider = embedding_model.provider if embedding_model else None
dataset.embedding_model = embedding_model.model if embedding_model else None
dataset.embedding_model = embedding_model.model_name if embedding_model else None
dataset.retrieval_model = retrieval_model.model_dump() if retrieval_model else None
dataset.permission = permission or DatasetPermissionEnum.ONLY_ME
dataset.provider = provider
@@ -384,7 +384,7 @@ class DatasetService:
model=model,
)
text_embedding_model = cast(TextEmbeddingModel, model_instance.model_type_instance)
model_schema = text_embedding_model.get_model_schema(model_instance.model, model_instance.credentials)
model_schema = text_embedding_model.get_model_schema(model_instance.model_name, model_instance.credentials)
if not model_schema:
raise ValueError("Model schema not found")
if model_schema.features and ModelFeature.VISION in model_schema.features:
@@ -743,10 +743,12 @@ class DatasetService:
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=data["embedding_model"],
                 )
-                filtered_data["embedding_model"] = embedding_model.model
+                embedding_model_name = embedding_model.model_name
+                filtered_data["embedding_model"] = embedding_model_name
                 filtered_data["embedding_model_provider"] = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    embedding_model.provider,
+                    embedding_model_name,
                 )
                 filtered_data["collection_binding_id"] = dataset_collection_binding.id
             except LLMBadRequestError:
@@ -876,10 +878,12 @@ class DatasetService:
                     return

                 # Apply new embedding model settings
-                filtered_data["embedding_model"] = embedding_model.model
+                embedding_model_name = embedding_model.model_name
+                filtered_data["embedding_model"] = embedding_model_name
                 filtered_data["embedding_model_provider"] = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    embedding_model.provider,
+                    embedding_model_name,
                 )
                 filtered_data["collection_binding_id"] = dataset_collection_binding.id
@@ -955,10 +959,12 @@ class DatasetService:
                     knowledge_configuration.embedding_model,
                 )
                 dataset.is_multimodal = is_multimodal
-                dataset.embedding_model = embedding_model.model
+                embedding_model_name = embedding_model.model_name
+                dataset.embedding_model = embedding_model_name
                 dataset.embedding_model_provider = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    embedding_model.provider,
+                    embedding_model_name,
                 )
                 dataset.collection_binding_id = dataset_collection_binding.id
             elif knowledge_configuration.indexing_technique == "economy":
@@ -989,10 +995,12 @@ class DatasetService:
                     model_type=ModelType.TEXT_EMBEDDING,
                     model=knowledge_configuration.embedding_model,
                 )
-                dataset.embedding_model = embedding_model.model
+                embedding_model_name = embedding_model.model_name
+                dataset.embedding_model = embedding_model_name
                 dataset.embedding_model_provider = embedding_model.provider
                 dataset_collection_binding = DatasetCollectionBindingService.get_dataset_collection_binding(
-                    embedding_model.provider, embedding_model.model
+                    embedding_model.provider,
+                    embedding_model_name,
                 )
                 is_multimodal = DatasetService.check_is_multimodal_model(
                     current_user.current_tenant_id,
@@ -1049,11 +1057,13 @@ class DatasetService:
                         skip_embedding_update = True
                 if not skip_embedding_update:
                     if embedding_model:
-                        dataset.embedding_model = embedding_model.model
+                        embedding_model_name = embedding_model.model_name
+                        dataset.embedding_model = embedding_model_name
                         dataset.embedding_model_provider = embedding_model.provider
                         dataset_collection_binding = (
                             DatasetCollectionBindingService.get_dataset_collection_binding(
-                                embedding_model.provider, embedding_model.model
+                                embedding_model.provider,
+                                embedding_model_name,
                             )
                         )
                         dataset.collection_binding_id = dataset_collection_binding.id
@@ -1884,7 +1894,7 @@ class DocumentService:
embedding_model = model_manager.get_default_model_instance(
tenant_id=current_user.current_tenant_id, model_type=ModelType.TEXT_EMBEDDING
)
dataset_embedding_model = embedding_model.model
dataset_embedding_model = embedding_model.model_name
dataset_embedding_model_provider = embedding_model.provider
dataset.embedding_model = dataset_embedding_model
dataset.embedding_model_provider = dataset_embedding_model_provider

View File

@@ -80,6 +80,8 @@ def init_llm_node(config: dict) -> LLMNode:
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
+        credentials_provider=MagicMock(),
+        model_factory=MagicMock(),
)
return node
@@ -115,7 +117,7 @@ def test_execute_llm():
db.session.close = MagicMock()
# Mock the _fetch_model_config to avoid database calls
def mock_fetch_model_config(**_kwargs):
def mock_fetch_model_config(*_args, **_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock
@@ -227,7 +229,7 @@ def test_execute_llm_with_jinja2():
db.session.close = MagicMock()
# Mock the _fetch_model_config method
def mock_fetch_model_config(**_kwargs):
def mock_fetch_model_config(*_args, **_kwargs):
from decimal import Decimal
from unittest.mock import MagicMock

View File

@@ -9,6 +9,7 @@ from core.model_runtime.entities import AssistantPromptMessage
from core.workflow.entities import GraphInitParams
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.graph import Graph
+from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
@@ -84,6 +85,8 @@ def init_parameter_extractor_node(config: dict):
config=config,
graph_init_params=init_params,
graph_runtime_state=graph_runtime_state,
+        credentials_provider=MagicMock(spec=CredentialsProvider),
+        model_factory=MagicMock(spec=ModelFactory),
)
return node
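
Using `MagicMock(spec=...)` here keeps the doubles honest: a bare `MagicMock()` silently accepts any attribute, while a spec'd one rejects names the protocol does not define. A minimal standalone demonstration (not from the diff):

```python
from unittest.mock import MagicMock

loose = MagicMock()
loose.fetchh("openai", "gpt-4o")  # typo goes unnoticed

strict = MagicMock(spec=["fetch"])  # same idea as spec=CredentialsProvider
strict.fetch("openai", "gpt-4o")    # fine
# strict.fetchh(...) would raise AttributeError
```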

View File

@@ -82,7 +82,7 @@ class TestCacheEmbeddingDocuments:
Mock: Configured ModelInstance with text embedding capabilities
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
@@ -597,7 +597,7 @@ class TestCacheEmbeddingQuery:
def mock_model_instance(self):
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
return model_instance
@@ -830,7 +830,7 @@ class TestEmbeddingModelSwitching:
"""
# Arrange
model_instance_ada = Mock()
model_instance_ada.model = "text-embedding-ada-002"
model_instance_ada.model_name = "text-embedding-ada-002"
model_instance_ada.provider = "openai"
# Mock model type instance for ada
@@ -841,7 +841,7 @@ class TestEmbeddingModelSwitching:
model_type_instance_ada.get_model_schema.return_value = model_schema_ada
model_instance_3_small = Mock()
model_instance_3_small.model = "text-embedding-3-small"
model_instance_3_small.model_name = "text-embedding-3-small"
model_instance_3_small.provider = "openai"
# Mock model type instance for 3-small
@@ -914,11 +914,11 @@ class TestEmbeddingModelSwitching:
"""
# Arrange
model_instance_openai = Mock()
model_instance_openai.model = "text-embedding-ada-002"
model_instance_openai.model_name = "text-embedding-ada-002"
model_instance_openai.provider = "openai"
model_instance_cohere = Mock()
model_instance_cohere.model = "embed-english-v3.0"
model_instance_cohere.model_name = "embed-english-v3.0"
model_instance_cohere.provider = "cohere"
cache_openai = CacheEmbedding(model_instance_openai)
@@ -1001,7 +1001,7 @@ class TestEmbeddingDimensionValidation:
def mock_model_instance(self):
"""Create a mock ModelInstance for testing."""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_instance.credentials = {"api_key": "test-key"}
@@ -1123,7 +1123,7 @@ class TestEmbeddingDimensionValidation:
"""
# Arrange - OpenAI ada-002 (1536 dimensions)
model_instance_ada = Mock()
model_instance_ada.model = "text-embedding-ada-002"
model_instance_ada.model_name = "text-embedding-ada-002"
model_instance_ada.provider = "openai"
# Mock model type instance for ada
@@ -1156,7 +1156,7 @@ class TestEmbeddingDimensionValidation:
# Arrange - Cohere embed-english-v3.0 (1024 dimensions)
model_instance_cohere = Mock()
model_instance_cohere.model = "embed-english-v3.0"
model_instance_cohere.model_name = "embed-english-v3.0"
model_instance_cohere.provider = "cohere"
# Mock model type instance for cohere
@@ -1225,7 +1225,7 @@ class TestEmbeddingEdgeCases:
- MAX_CHUNKS: 10
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_type_instance = Mock()
@@ -1702,7 +1702,7 @@ class TestEmbeddingCachePerformance:
- MAX_CHUNKS: 10
"""
model_instance = Mock()
model_instance.model = "text-embedding-ada-002"
model_instance.model_name = "text-embedding-ada-002"
model_instance.provider = "openai"
model_type_instance = Mock()

View File

@@ -34,7 +34,7 @@ def create_mock_model_instance():
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
mock_instance.model_name = "test-model"
return mock_instance
@@ -65,7 +65,7 @@ class TestRerankModelRunner:
mock_instance.provider_model_bundle.configuration = Mock()
mock_instance.provider_model_bundle.configuration.tenant_id = "test-tenant-id"
mock_instance.provider = "test-provider"
mock_instance.model = "test-model"
mock_instance.model_name = "test-model"
return mock_instance
@pytest.fixture

View File

@@ -199,11 +199,32 @@
 def test_mock_factory_node_type_detection():
     """Test that MockNodeFactory correctly identifies nodes to mock."""
     from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom

     from .test_mock_factory import MockNodeFactory

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,  # Will be set by test
-        graph_runtime_state=None,  # Will be set by test
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )
@@ -288,7 +309,11 @@
 def test_register_custom_mock_node():
     """Test registering a custom mock implementation for a node type."""
     from core.app.entities.app_invoke_entities import InvokeFrom
+    from core.workflow.entities import GraphInitParams
     from core.workflow.nodes.template_transform import TemplateTransformNode
+    from core.workflow.runtime import GraphRuntimeState, VariablePool
+    from models.enums import UserFrom

     from .test_mock_factory import MockNodeFactory
@@ -298,9 +323,25 @@
             # Custom mock implementation
             pass

+    graph_init_params = GraphInitParams(
+        tenant_id="test",
+        app_id="test",
+        workflow_id="test",
+        graph_config={},
+        user_id="test",
+        user_from=UserFrom.ACCOUNT,
+        invoke_from=InvokeFrom.SERVICE_API,
+        call_depth=0,
+    )
+    graph_runtime_state = GraphRuntimeState(
+        variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
+        start_at=0,
+        total_tokens=0,
+        node_run_steps=0,
+    )
     factory = MockNodeFactory(
-        graph_init_params=None,
-        graph_runtime_state=None,
+        graph_init_params=graph_init_params,
+        graph_runtime_state=graph_runtime_state,
         mock_config=None,
     )

View File

@@ -1,9 +1,9 @@
import datetime
import time
from collections.abc import Iterable
from unittest import mock
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -82,7 +82,7 @@ def _build_branching_graph(
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
@@ -101,6 +101,8 @@ def _build_branching_graph(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
+        credentials_provider=mock.Mock(),
+        model_factory=mock.Mock(),
)
return llm_node

View File

@@ -1,8 +1,8 @@
import datetime
import time
from unittest import mock
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -78,7 +78,7 @@ def _build_llm_human_llm_graph(
def _create_llm_node(node_id: str, title: str, prompt_text: str) -> MockLLMNode:
llm_data = LLMNodeData(
title=title,
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode=LLMMode.CHAT, completion_params={}),
model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
prompt_template=[
LLMNodeChatModelMessage(
text=prompt_text,
@@ -97,6 +97,8 @@ def _build_llm_human_llm_graph(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
credentials_provider=mock.Mock(),
model_factory=mock.Mock(),
)
return llm_node

View File

@@ -1,4 +1,5 @@
import time
from unittest import mock
from core.model_runtime.entities.llm_entities import LLMMode
from core.model_runtime.entities.message_entities import PromptMessageRole
@@ -85,6 +86,8 @@ def _build_if_else_graph(branch_value: str, mock_config: MockConfig) -> tuple[Gr
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=mock_config,
credentials_provider=mock.Mock(),
model_factory=mock.Mock(),
)
return llm_node
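
The graph builders above now pass the plain string "chat" where they previously passed LLMMode.CHAT. That substitution is safe if LLMMode is a string-backed enum, which the change suggests but this diff does not show. A self-contained sketch of that assumption:

from enum import StrEnum


class LLMMode(StrEnum):  # stand-in mirroring the assumed shape of the real enum
    CHAT = "chat"
    COMPLETION = "completion"


# A str-backed enum compares equal to its raw value, so validated model
# configs can accept either spelling interchangeably.
assert LLMMode.CHAT == "chat"
assert LLMMode("chat") is LLMMode.CHAT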

View File

@@ -5,6 +5,7 @@ This module provides a MockNodeFactory that automatically detects and mocks nodes
requiring external services (LLM, Agent, Tool, Knowledge Retrieval, HTTP Request).
"""
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any
from core.app.workflow.node_factory import DifyNodeFactory
@@ -74,7 +75,7 @@ class MockNodeFactory(DifyNodeFactory):
NodeType.CODE: MockCodeNode,
}
def create_node(self, node_config: dict[str, Any]) -> Node:
def create_node(self, node_config: Mapping[str, Any]) -> Node:
"""
Create a node instance, using mock implementations for third-party service nodes.
@@ -123,6 +124,16 @@ class MockNodeFactory(DifyNodeFactory):
mock_config=self.mock_config,
http_request_config=self._http_request_config,
)
elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}:
mock_instance = mock_class(
id=node_id,
config=node_config,
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
mock_config=self.mock_config,
credentials_provider=self._llm_credentials_provider,
model_factory=self._llm_model_factory,
)
else:
mock_instance = mock_class(
id=node_id,
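
The new elif branch injects the two LLM ports only for model-backed node types. The protocol definitions themselves are not part of this diff; inferred from the call sites in the tests further down (fetch and init_model_instance, each taking a provider and a model name), they plausibly look like this sketch:

from typing import Any, Protocol


class CredentialsProvider(Protocol):
    """Port for resolving provider credentials; shape inferred from test call sites."""

    def fetch(self, provider: str, model_name: str) -> dict[str, Any]: ...


class ModelFactory(Protocol):
    """Port for constructing a runtime model instance; shape inferred from test call sites."""

    def init_model_instance(self, provider: str, model_name: str) -> Any: ...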

View File

@@ -16,9 +16,33 @@ from tests.unit_tests.core.workflow.graph_engine.test_mock_factory import MockNo
def test_mock_factory_registers_iteration_node():
"""Test that MockNodeFactory has iteration node registered."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
# Create a MockNodeFactory instance
factory = MockNodeFactory(graph_init_params=None, graph_runtime_state=None, mock_config=None)
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={"nodes": [], "edges": []},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)
# Check that iteration node is registered
assert NodeType.ITERATION in factory._mock_node_types

View File

@@ -8,6 +8,7 @@ allowing tests to run without external dependencies.
import time
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Optional
from unittest.mock import MagicMock
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
@@ -18,6 +19,7 @@ from core.workflow.nodes.document_extractor import DocumentExtractorNode
from core.workflow.nodes.http_request import HttpRequestNode
from core.workflow.nodes.knowledge_retrieval import KnowledgeRetrievalNode
from core.workflow.nodes.llm import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.nodes.parameter_extractor import ParameterExtractorNode
from core.workflow.nodes.question_classifier import QuestionClassifierNode
from core.workflow.nodes.template_transform import TemplateTransformNode
@@ -42,6 +44,10 @@ class MockNodeMixin:
mock_config: Optional["MockConfig"] = None,
**kwargs: Any,
):
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("credentials_provider", MagicMock(spec=CredentialsProvider))
kwargs.setdefault("model_factory", MagicMock(spec=ModelFactory))
super().__init__(
id=id,
config=config,
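
Using MagicMock(spec=...) here rather than a bare Mock means a mistyped port call fails loudly instead of silently returning another mock. A runnable illustration against a stand-in protocol:

from typing import Any, Protocol
from unittest.mock import MagicMock


class CredentialsProvider(Protocol):  # stand-in for core.workflow.nodes.llm.protocols
    def fetch(self, provider: str, model_name: str) -> dict[str, Any]: ...


port = MagicMock(spec=CredentialsProvider)
port.fetch("openai", "gpt-4")  # allowed: declared on the spec
try:
    port.fetch_credentials("openai")  # typo: rejected by the spec
except AttributeError as error:
    print(error)  # Mock object has no attribute 'fetch_credentials'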

View File

@@ -101,11 +101,32 @@ def test_node_mock_config():
def test_mock_factory_detection():
"""Test MockNodeFactory node type detection."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
print("Testing MockNodeFactory detection...")
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None,
graph_runtime_state=None,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)
@@ -133,11 +154,32 @@ def test_mock_factory_detection():
def test_mock_factory_registration():
"""Test registering and unregistering mock node types."""
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
print("Testing MockNodeFactory registration...")
graph_init_params = GraphInitParams(
tenant_id="test",
app_id="test",
workflow_id="test",
graph_config={},
user_id="test",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.SERVICE_API,
call_depth=0,
)
graph_runtime_state = GraphRuntimeState(
variable_pool=VariablePool(environment_variables=[], conversation_variables=[], user_inputs={}),
start_at=0,
total_tokens=0,
node_run_steps=0,
)
factory = MockNodeFactory(
graph_init_params=None,
graph_runtime_state=None,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
mock_config=None,
)

View File

@@ -6,6 +6,7 @@ from unittest import mock
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
from core.app.llm.model_access import DifyCredentialsProvider, DifyModelFactory, fetch_model_config
from core.entities.provider_configuration import ProviderConfiguration, ProviderModelBundle
from core.entities.provider_entities import CustomConfiguration, SystemConfiguration
from core.model_runtime.entities.common_entities import I18nObject
@@ -32,6 +33,7 @@ from core.workflow.nodes.llm.entities import (
)
from core.workflow.nodes.llm.file_saver import LLMFileSaver
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.llm.protocols import CredentialsProvider, ModelFactory
from core.workflow.runtime import GraphRuntimeState, VariablePool
from core.workflow.system_variable import SystemVariable
from models.enums import UserFrom
@@ -100,6 +102,8 @@ def llm_node(
llm_node_data: LLMNodeData, graph_init_params: GraphInitParams, graph_runtime_state: GraphRuntimeState
) -> LLMNode:
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@@ -109,13 +113,29 @@ def llm_node(
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
llm_file_saver=mock_file_saver,
)
return node
@pytest.fixture
def model_config():
def model_config(monkeypatch):
from tests.integration_tests.model_runtime.__mock.plugin_model import MockModelClass
def mock_plugin_model_providers(_self):
providers = MockModelClass().fetch_model_providers("test")
for provider in providers:
provider.declaration.provider = f"{provider.plugin_id}/{provider.declaration.provider}"
return providers
monkeypatch.setattr(
ModelProviderFactory,
"get_plugin_model_providers",
mock_plugin_model_providers,
)
# Create actual provider and model type instances
model_provider_factory = ModelProviderFactory(tenant_id="test")
provider_instance = model_provider_factory.get_plugin_model_provider("openai")
@@ -125,7 +145,7 @@ def model_config():
provider_model_bundle = ProviderModelBundle(
configuration=ProviderConfiguration(
tenant_id="1",
provider=provider_instance,
provider=provider_instance.declaration,
preferred_provider_type=ProviderType.CUSTOM,
using_provider_type=ProviderType.CUSTOM,
system_configuration=SystemConfiguration(enabled=False),
@@ -153,6 +173,89 @@ def model_config():
)
def test_fetch_model_config_uses_ports(model_config: ModelConfigWithCredentialsEntity):
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
provider_model_bundle = model_config.provider_model_bundle
model_type_instance = provider_model_bundle.model_type_instance
provider_model = mock.MagicMock()
model_instance = mock.MagicMock(
model_type_instance=model_type_instance,
provider_model_bundle=provider_model_bundle,
)
mock_credentials_provider.fetch.return_value = {"api_key": "test"}
mock_model_factory.init_model_instance.return_value = model_instance
with (
mock.patch.object(
provider_model_bundle.configuration.__class__,
"get_provider_model",
return_value=provider_model,
),
mock.patch.object(
model_type_instance.__class__,
"get_model_schema",
return_value=model_config.model_schema,
),
):
fetch_model_config(
node_data_model=ModelConfig(provider="openai", name="gpt-3.5-turbo", mode="chat", completion_params={}),
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
)
mock_credentials_provider.fetch.assert_called_once_with("openai", "gpt-3.5-turbo")
mock_model_factory.init_model_instance.assert_called_once_with("openai", "gpt-3.5-turbo")
provider_model.raise_for_status.assert_called_once()
def test_dify_model_access_adapters_call_managers():
mock_provider_manager = mock.MagicMock()
mock_model_manager = mock.MagicMock()
mock_configurations = mock.MagicMock()
mock_provider_configuration = mock.MagicMock()
mock_provider_model = mock.MagicMock()
mock_configurations.get.return_value = mock_provider_configuration
mock_provider_configuration.get_provider_model.return_value = mock_provider_model
mock_provider_configuration.get_current_credentials.return_value = {"api_key": "test"}
credentials_provider = DifyCredentialsProvider(
tenant_id="tenant",
provider_manager=mock_provider_manager,
)
model_factory = DifyModelFactory(
tenant_id="tenant",
model_manager=mock_model_manager,
)
mock_provider_manager.get_configurations.return_value = mock_configurations
credentials_provider.fetch("openai", "gpt-3.5-turbo")
model_factory.init_model_instance("openai", "gpt-3.5-turbo")
mock_provider_manager.get_configurations.assert_called_once_with("tenant")
mock_configurations.get.assert_called_once_with("openai")
mock_provider_configuration.get_provider_model.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
mock_provider_configuration.get_current_credentials.assert_called_once_with(
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
mock_provider_model.raise_for_status.assert_called_once()
mock_model_manager.get_model_instance.assert_called_once_with(
tenant_id="tenant",
provider="openai",
model_type=ModelType.LLM,
model="gpt-3.5-turbo",
)
def test_fetch_files_with_file_segment():
file = File(
id="1",
@@ -485,6 +588,8 @@ def test_handle_list_messages_basic(llm_node):
@pytest.fixture
def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state) -> tuple[LLMNode, LLMFileSaver]:
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@@ -494,6 +599,8 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_state)
config=node_config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
credentials_provider=mock_credentials_provider,
model_factory=mock_model_factory,
llm_file_saver=mock_file_saver,
)
return node, mock_file_saver
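
test_fetch_model_config_uses_ports pins down the order of operations inside fetch_model_config: resolve credentials, build the model instance, check the provider model's status, then load the schema. The following is a reconstruction from those assertions, not the actual implementation; the types are stand-ins and the return shape is assumed.

from dataclasses import dataclass, field
from enum import Enum
from typing import Any


class ModelType(Enum):  # stand-in for the real ModelType enum
    LLM = "llm"


@dataclass
class ModelConfig:  # stand-in mirroring the node-data model used in the tests
    provider: str
    name: str
    mode: str
    completion_params: dict[str, Any] = field(default_factory=dict)


def fetch_model_config(node_data_model: ModelConfig, credentials_provider: Any, model_factory: Any):
    # Resolve credentials and the runtime model instance through the two ports.
    credentials = credentials_provider.fetch(node_data_model.provider, node_data_model.name)
    model_instance = model_factory.init_model_instance(node_data_model.provider, node_data_model.name)

    # Verify the provider actually serves this model, then load its schema.
    provider_model = model_instance.provider_model_bundle.configuration.get_provider_model(
        model_type=ModelType.LLM,
        model=node_data_model.name,
    )
    provider_model.raise_for_status()
    model_schema = model_instance.model_type_instance.get_model_schema(node_data_model.name, credentials)
    return model_instance, model_schema, credentials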

View File

@@ -642,8 +642,16 @@ class TestDatasetServiceUpdateRagPipelineDatasetSettings:
# Mock embedding model
mock_embedding_model = Mock()
mock_embedding_model.model = "text-embedding-ada-002"
mock_embedding_model.model_name = "text-embedding-ada-002"
mock_embedding_model.provider = "openai"
mock_embedding_model.credentials = {}
mock_model_schema = Mock()
mock_model_schema.features = []
mock_text_embedding_model = Mock()
mock_text_embedding_model.get_model_schema.return_value = mock_model_schema
mock_embedding_model.model_type_instance = mock_text_embedding_model
mock_model_instance = Mock()
mock_model_instance.get_model_instance.return_value = mock_embedding_model
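
The extra schema plumbing on this mock mirrors what the service now reads from a model instance: model_name instead of the legacy model attribute, plus the schema's feature list. A minimal sketch of that consumption pattern, assumed from the mock's shape rather than taken from the service code:

def resolve_embedding_settings(model_instance):
    """Read the fields the dataset service needs from a (possibly mocked) model instance."""
    schema = model_instance.model_type_instance.get_model_schema(
        model_instance.model_name,
        model_instance.credentials,
    )
    features = schema.features if schema and schema.features else []
    return model_instance.provider, model_instance.model_name, features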

View File

@@ -174,7 +174,7 @@ class DatasetServiceTestDataFactory:
Mock: Embedding model mock with model and provider attributes
"""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
@@ -434,7 +434,7 @@ class TestDatasetServiceCreateDataset:
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
assert result.embedding_model == embedding_model.model_name
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)

View File

@@ -46,7 +46,7 @@ class DatasetCreateTestDataFactory:
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
@@ -244,7 +244,7 @@ class TestDatasetServiceCreateEmptyDataset:
# Assert
assert result.indexing_technique == "high_quality"
assert result.embedding_model_provider == embedding_model.provider
assert result.embedding_model == embedding_model.model
assert result.embedding_model == embedding_model.model_name
mock_model_manager_instance.get_default_model_instance.assert_called_once_with(
tenant_id=tenant_id, model_type=ModelType.TEXT_EMBEDDING
)

View File

@@ -65,7 +65,7 @@ class DatasetUpdateTestDataFactory:
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.model_name = model
embedding_model.provider = provider
return embedding_model
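
With this last hunk, all three dataset test factories agree on model_name. For reference, the shared helper condenses to this runnable form, using plain unittest.mock with no Dify imports:

from unittest.mock import Mock


def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
    """Create a mock embedding model exposing the renamed model_name attribute."""
    embedding_model = Mock()
    embedding_model.model_name = model
    embedding_model.provider = provider
    return embedding_model


embedding_model = create_embedding_model_mock()
assert embedding_model.model_name == "text-embedding-ada-002"
assert embedding_model.provider == "openai"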