Compare commits

...

11 Commits

Author SHA1 Message Date
Xiyuan Chen
977ed79ea0 fix: enterprise API error handling and license enforcement (#33044)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-15 20:59:41 -07:00
Asuka Minato
dd39fcd9bc ci: Simplify nltk data download in Dockerfile (#33495) 2026-03-16 12:06:20 +09:00
dependabot[bot]
3c587097cd chore(deps): bump the python-packages group in /api with 13 updates (#33484)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-16 11:28:42 +09:00
dependabot[bot]
6a3fcc0a7b chore(deps): bump the llm group across 1 directory with 2 updates (#33491)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:23:51 +09:00
dependabot[bot]
8d3f2f56d9 chore(deps): bump the storage group in /api with 2 updates (#33481)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-16 11:10:07 +09:00
非法操作
09dad78a5d chore: add indexes for human_input_forms query patterns (#32849)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Bond Zhu <37842169+MRZHUH@users.noreply.github.com>
2026-03-16 10:10:03 +08:00
dependabot[bot]
c71ecd2fe0 chore(deps-dev): update faker requirement from ~=40.8.0 to ~=40.11.0 in /api in the dev group (#33482)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-16 11:09:41 +09:00
dependabot[bot]
808d186156 chore(deps): bump litellm from 1.82.1 to 1.82.2 in /api in the llm group (#33480)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-16 11:08:28 +09:00
dependabot[bot]
ec0a01a568 chore(deps): bump the github-actions-dependencies group with 4 updates (#33485)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:07:42 +09:00
dependabot[bot]
ac23a0409e chore(deps): bump the storage group in /api with 2 updates (#33488)
Signed-off-by: dependabot[bot] <support@github.com>
Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com>
2026-03-16 11:07:11 +09:00
wangxiaolei
6ef69ff880 refactor: llm decouple code executor module (#33400)
Co-authored-by: Byron.wang <byron@dify.ai>
2026-03-16 10:06:14 +08:00
36 changed files with 1653 additions and 705 deletions

View File

@@ -27,7 +27,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -39,7 +39,7 @@ jobs:
with:
python-version: "3.11"
- uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
- uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
- name: Generate Docker Compose
if: steps.docker-compose-changes.outputs.any_changed == 'true'

View File

@@ -113,7 +113,7 @@ jobs:
context: "web"
steps:
- name: Download digests
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: /tmp/digests
pattern: digests-${{ matrix.context }}-*

View File

@@ -19,7 +19,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
persist-credentials: false
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: "3.12"

View File

@@ -28,7 +28,7 @@ jobs:
migration-changed: ${{ steps.changes.outputs.migration }}
steps:
- uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
- uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
- uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
id: changes
with:
filters: |

View File

@@ -22,7 +22,7 @@ jobs:
fetch-depth: 0
- name: Setup Python & UV
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true

View File

@@ -33,7 +33,7 @@ jobs:
- name: Setup UV and Python
if: steps.changed-files.outputs.any_changed == 'true'
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: false
python-version: "3.12"

View File

@@ -120,7 +120,7 @@ jobs:
- name: Run Claude Code for Translation Sync
if: steps.detect_changes.outputs.CHANGED_FILES != ''
uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
with:
anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
github_token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -31,7 +31,7 @@ jobs:
remove_tool_cache: true
- name: Setup UV and Python
uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
with:
enable-cache: true
python-version: ${{ matrix.python-version }}

View File

@@ -77,7 +77,7 @@ jobs:
uses: ./.github/actions/setup-web
- name: Download blob reports
uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
with:
path: web/.vitest-reports
pattern: blob-report-*

View File

@@ -103,7 +103,6 @@ ignore_imports =
dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
dify_graph.nodes.llm.node -> core.helper.code_executor
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
dify_graph.nodes.llm.node -> core.model_manager

View File

@@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"
# Download nltk data
RUN mkdir -p /usr/local/share/nltk_data \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
&& NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
&& chmod -R 755 /usr/local/share/nltk_data
ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
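A quick runtime check that the baked-in corpora resolve from the image's NLTK_DATA path (an illustrative snippet, not part of the Dockerfile; nltk.data.find raises LookupError when a resource is missing):

import nltk

# The image sets NLTK_DATA=/usr/local/share/nltk_data, so these lookups should
# succeed for the three packages downloaded in the RUN step above.
for resource in ("tokenizers/punkt", "taggers/averaged_perceptron_tagger", "corpora/stopwords"):
    nltk.data.find(resource)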

View File

@@ -1,16 +1,45 @@
import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check.
# Defined at module level to avoid per-request tuple construction.
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - features: billing/plan features (ProviderContextProvider)
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
# - workspaces/current: workspace + model providers (AppContextProvider)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/features",
"/console/api/account/profile",
"/console/api/workspaces/current",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request
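The gate above is a plain prefix match on request.path. A standalone sketch of that check, reusing the module-level prefixes from this diff (the helper name is illustrative only):

def _is_license_exempt(path: str) -> bool:
    # Console bootstrap endpoints stay reachable on an invalid license so the
    # signin/expiry pages can load without triggering an infinite reload loop.
    if path.startswith("/console/api/"):
        return any(path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
    if path.startswith("/api/"):
        return path.startswith("/api/system-features")
    return True  # non-API routes are not subject to the license gate

# e.g. _is_license_exempt("/console/api/setup")      -> True
#      _is_license_exempt("/console/api/apps")       -> False
#      _is_license_exempt("/api/system-features")    -> True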

View File

@@ -45,6 +45,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
from dify_graph.nodes.http_request import build_http_request_config
from dify_graph.nodes.llm.entities import LLMNodeData
from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
from dify_graph.nodes.llm.protocols import TemplateRenderer
from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
from dify_graph.nodes.template_transform.template_renderer import (
@@ -228,6 +229,16 @@ class DefaultWorkflowCodeExecutor:
return isinstance(error, CodeExecutionError)
class DefaultLLMTemplateRenderer(TemplateRenderer):
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
result = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=inputs,
)
return str(result.get("result", ""))
@final
class DifyNodeFactory(NodeFactory):
"""
@@ -254,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
self._http_request_http_client = ssrf_proxy
self._http_request_tool_file_manager_factory = ToolFileManager
@@ -391,6 +403,8 @@ class DifyNodeFactory(NodeFactory):
model_instance=model_instance,
),
}
if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
node_init_kwargs["template_renderer"] = self._llm_template_renderer
if include_http_client:
node_init_kwargs["http_client"] = self._http_request_http_client
return node_init_kwargs

View File

@@ -1,34 +1,53 @@
from __future__ import annotations
from collections.abc import Sequence
from typing import cast
from typing import Any, cast
from core.model_manager import ModelInstance
from dify_graph.file import FileType, file_manager
from dify_graph.file.models import File
from dify_graph.model_runtime.entities import PromptMessageRole
from dify_graph.model_runtime.entities.message_entities import (
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
PromptMessageRole,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
from dify_graph.variables import ArrayFileSegment, FileSegment
from dify_graph.variables.segments import ArrayAnySegment, NoneSegment
from .exc import InvalidVariableTypeError
from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
from .exc import (
InvalidVariableTypeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
)
from .protocols import TemplateRenderer
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
model_instance.model_name,
model_instance.credentials,
dict(model_instance.credentials),
)
if not model_schema:
raise ValueError(f"Model schema not found for {model_instance.model_name}")
return model_schema
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
variable = variable_pool.get(selector)
if variable is None:
return []
@@ -89,3 +108,366 @@ def fetch_memory_text(
human_prefix=human_prefix,
ai_prefix=ai_prefix,
)
def fetch_prompt_messages(
*,
sys_query: str | None = None,
sys_files: Sequence[File],
context: str | None = None,
memory: PromptMessageMemory | None = None,
model_instance: ModelInstance,
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
stop: Sequence[str] | None = None,
memory_config: MemoryConfig | None = None,
vision_enabled: bool = False,
vision_detail: ImagePromptMessageContent.DETAIL,
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
prompt_messages.extend(
handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
prompt_messages.extend(
handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
)
if sys_query:
prompt_messages.extend(
handle_list_messages(
messages=[
LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
template_renderer=template_renderer,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
prompt_messages.extend(
handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
)
memory_text = handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
prompt_content = prompt_messages[0].content
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
if sys_query:
if isinstance(prompt_content, str):
prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
_append_file_prompts(
prompt_messages=prompt_messages,
files=sys_files,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
_append_file_prompts(
prompt_messages=prompt_messages,
files=context_files or [],
vision_enabled=vision_enabled,
vision_detail=vision_detail,
)
filtered_prompt_messages: list[PromptMessage] = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
if not model_schema.features:
if content_item.type == PromptMessageContentType.TEXT:
prompt_message_content.append(content_item)
continue
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if prompt_message_content:
prompt_message.content = prompt_message_content
filtered_prompt_messages.append(prompt_message)
elif not prompt_message.is_empty():
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
def handle_list_messages(
*,
messages: Sequence[LLMNodeChatModelMessage],
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=message.role,
)
)
continue
template = message.text.replace("{#context#}", context) if context else message.text
segment_group = variable_pool.convert_template(template)
file_contents: list[PromptMessageContentUnionTypes] = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_contents.append(
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
)
if segment_group.text:
prompt_messages.append(
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=segment_group.text)],
role=message.role,
)
)
if file_contents:
prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
return prompt_messages
def render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> str:
if not template:
return ""
if template_renderer is None:
raise ValueError("template_renderer is required for jinja2 prompt rendering")
jinja2_inputs: dict[str, Any] = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
def handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
if template.edition_type == "jinja2":
result_text = render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
template_renderer=template_renderer,
)
else:
template_text = template.text.replace("{#context#}", context) if context else template.text
result_text = variable_pool.convert_template(template_text).text
return [
combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)],
role=PromptMessageRole.USER,
)
]
def combine_message_content_with_role(
*,
contents: str | list[PromptMessageContentUnionTypes] | None = None,
role: PromptMessageRole,
) -> PromptMessage:
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
rest_tokens = 2000
runtime_model_schema = fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
if not memory or not memory_config:
return []
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
return memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
def handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
if not memory or not memory_config:
return ""
rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
return fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
def _append_file_prompts(
*,
prompt_messages: list[PromptMessage],
files: Sequence[File],
vision_enabled: bool,
vision_detail: ImagePromptMessageContent.DETAIL,
) -> None:
if not vision_enabled or not files:
return
file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
if (
prompt_messages
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
existing_contents = prompt_messages[-1].content
assert isinstance(existing_contents, list)
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
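For completion-model templates, fetch_prompt_messages splices memory and the current query into the single prompt through the #histories# and #sys.query# placeholders. A rough illustration of that substitution with made-up values:

# Hypothetical values; the real text comes from the node's prompt template and memory.
memory_text = "user: hi\nassistant: hello"
sys_query = "What is Dify?"
prompt = "#histories#\nAnswer the question: #sys.query#"

# "#histories#" is replaced in place when present (otherwise memory_text is prepended
# with a newline), then "#sys.query#" is substituted with the current query.
prompt = prompt.replace("#histories#", memory_text).replace("#sys.query#", sys_query)
# -> "user: hi\nassistant: hello\nAnswer the question: What is Dify?"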

View File

@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal
from sqlalchemy import select
from core.helper.code_executor import CodeExecutor, CodeLanguage
from core.llm_generator.output_parser.errors import OutputParserError
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelInstance
@@ -28,11 +27,10 @@ from dify_graph.enums import (
WorkflowNodeExecutionMetadataKey,
WorkflowNodeExecutionStatus,
)
from dify_graph.file import File, FileTransferMethod, FileType, file_manager
from dify_graph.file import File, FileTransferMethod, FileType
from dify_graph.model_runtime.entities import (
ImagePromptMessageContent,
PromptMessage,
PromptMessageContentType,
TextPromptMessageContent,
)
from dify_graph.model_runtime.entities.llm_entities import (
@@ -43,14 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
LLMStructuredOutput,
LLMUsage,
)
from dify_graph.model_runtime.entities.message_entities import (
AssistantPromptMessage,
PromptMessageContentUnionTypes,
PromptMessageRole,
SystemPromptMessage,
UserPromptMessage,
)
from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
from dify_graph.model_runtime.memory import PromptMessageMemory
from dify_graph.model_runtime.utils.encoders import jsonable_encoder
from dify_graph.node_events import (
@@ -64,13 +55,12 @@ from dify_graph.node_events import (
from dify_graph.nodes.base.entities import VariableSelector
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import VariablePool
from dify_graph.variables import (
ArrayFileSegment,
ArraySegment,
FileSegment,
NoneSegment,
ObjectSegment,
StringSegment,
@@ -89,9 +79,6 @@ from .exc import (
InvalidContextStructureError,
InvalidVariableTypeError,
LLMNodeError,
MemoryRolePrefixRequiredError,
NoPromptFoundError,
TemplateTypeNotSupportError,
VariableNotFoundError,
)
from .file_saver import FileSaverImpl, LLMFileSaver
@@ -118,6 +105,7 @@ class LLMNode(Node[LLMNodeData]):
_model_factory: ModelFactory
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@@ -130,6 +118,7 @@ class LLMNode(Node[LLMNodeData]):
model_factory: ModelFactory,
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -146,6 +135,7 @@ class LLMNode(Node[LLMNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@@ -240,6 +230,7 @@ class LLMNode(Node[LLMNodeData]):
variable_pool=variable_pool,
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
context_files=context_files,
template_renderer=self._template_renderer,
)
# handle invoke result
@@ -773,182 +764,24 @@ class LLMNode(Node[LLMNodeData]):
variable_pool: VariablePool,
jinja2_variables: Sequence[VariableSelector],
context_files: list[File] | None = None,
template_renderer: TemplateRenderer | None = None,
) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
prompt_messages: list[PromptMessage] = []
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if isinstance(prompt_template, list):
# For chat model
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
# Get memory messages for chat mode
memory_messages = _handle_memory_chat_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Extend prompt_messages with memory messages
prompt_messages.extend(memory_messages)
# Add current query to the prompt messages
if sys_query:
message = LLMNodeChatModelMessage(
text=sys_query,
role=PromptMessageRole.USER,
edition_type="basic",
)
prompt_messages.extend(
LLMNode.handle_list_messages(
messages=[message],
context="",
jinja2_variables=[],
variable_pool=variable_pool,
vision_detail_config=vision_detail,
)
)
elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
# For completion model
prompt_messages.extend(
_handle_completion_template(
template=prompt_template,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
)
# Get memory text for completion model
memory_text = _handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
)
# Insert histories into the prompt
prompt_content = prompt_messages[0].content
# For issue #11247 - Check if prompt content is a string or a list
if isinstance(prompt_content, str):
prompt_content = str(prompt_content)
if "#histories#" in prompt_content:
prompt_content = prompt_content.replace("#histories#", memory_text)
else:
prompt_content = memory_text + "\n" + prompt_content
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
if "#histories#" in content_item.data:
content_item.data = content_item.data.replace("#histories#", memory_text)
else:
content_item.data = memory_text + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
# Add current query to the prompt message
if sys_query:
if isinstance(prompt_content, str):
prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
prompt_messages[0].content = prompt_content
elif isinstance(prompt_content, list):
for content_item in prompt_content:
if isinstance(content_item, TextPromptMessageContent):
content_item.data = sys_query + "\n" + content_item.data
else:
raise ValueError("Invalid prompt content type")
else:
raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
# The sys_files will be deprecated later
if vision_enabled and sys_files:
file_prompts = []
for file in sys_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# The context_files
if vision_enabled and context_files:
file_prompts = []
for file in context_files:
file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
file_prompts.append(file_prompt)
# If last prompt is a user prompt, add files into its contents,
# otherwise append a new user prompt
if (
len(prompt_messages) > 0
and isinstance(prompt_messages[-1], UserPromptMessage)
and isinstance(prompt_messages[-1].content, list)
):
prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
else:
prompt_messages.append(UserPromptMessage(content=file_prompts))
# Remove empty messages and filter unsupported content
filtered_prompt_messages = []
for prompt_message in prompt_messages:
if isinstance(prompt_message.content, list):
prompt_message_content: list[PromptMessageContentUnionTypes] = []
for content_item in prompt_message.content:
# Skip content if features are not defined
if not model_schema.features:
if content_item.type != PromptMessageContentType.TEXT:
continue
prompt_message_content.append(content_item)
continue
# Skip content if corresponding feature is not supported
if (
(
content_item.type == PromptMessageContentType.IMAGE
and ModelFeature.VISION not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.DOCUMENT
and ModelFeature.DOCUMENT not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.VIDEO
and ModelFeature.VIDEO not in model_schema.features
)
or (
content_item.type == PromptMessageContentType.AUDIO
and ModelFeature.AUDIO not in model_schema.features
)
):
continue
prompt_message_content.append(content_item)
if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
prompt_message.content = prompt_message_content[0].data
else:
prompt_message.content = prompt_message_content
if prompt_message.is_empty():
continue
filtered_prompt_messages.append(prompt_message)
if len(filtered_prompt_messages) == 0:
raise NoPromptFoundError(
"No prompt found in the LLM configuration. "
"Please ensure a prompt is properly configured before proceeding."
)
return filtered_prompt_messages, stop
return llm_utils.fetch_prompt_messages(
sys_query=sys_query,
sys_files=sys_files,
context=context,
memory=memory,
model_instance=model_instance,
prompt_template=prompt_template,
stop=stop,
memory_config=memory_config,
vision_enabled=vision_enabled,
vision_detail=vision_detail,
variable_pool=variable_pool,
jinja2_variables=jinja2_variables,
context_files=context_files,
template_renderer=template_renderer,
)
@classmethod
def _extract_variable_selector_to_variable_mapping(
@@ -1048,59 +881,16 @@ class LLMNode(Node[LLMNodeData]):
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
vision_detail_config: ImagePromptMessageContent.DETAIL,
template_renderer: TemplateRenderer | None = None,
) -> Sequence[PromptMessage]:
prompt_messages: list[PromptMessage] = []
for message in messages:
if message.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=message.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=message.role
)
prompt_messages.append(prompt_message)
else:
# Get segment group from basic message
if context:
template = message.text.replace("{#context#}", context)
else:
template = message.text
segment_group = variable_pool.convert_template(template)
# Process segments for images
file_contents = []
for segment in segment_group.value:
if isinstance(segment, ArrayFileSegment):
for file in segment.value:
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
elif isinstance(segment, FileSegment):
file = segment.value
if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
file_content = file_manager.to_prompt_message_content(
file, image_detail_config=vision_detail_config
)
file_contents.append(file_content)
# Create message with text from all segments
plain_text = segment_group.text
if plain_text:
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=plain_text)], role=message.role
)
prompt_messages.append(prompt_message)
if file_contents:
# Create message with image contents
prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
prompt_messages.append(prompt_message)
return prompt_messages
return llm_utils.handle_list_messages(
messages=messages,
context=context,
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
vision_detail_config=vision_detail_config,
template_renderer=template_renderer,
)
@staticmethod
def handle_blocking_result(
@@ -1239,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]):
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
):
match role:
case PromptMessageRole.USER:
return UserPromptMessage(content=contents)
case PromptMessageRole.ASSISTANT:
return AssistantPromptMessage(content=contents)
case PromptMessageRole.SYSTEM:
return SystemPromptMessage(content=contents)
case _:
raise NotImplementedError(f"Role {role} is not supported")
def _render_jinja2_message(
*,
template: str,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
):
if not template:
return ""
jinja2_inputs = {}
for jinja2_variable in jinja2_variables:
variable = variable_pool.get(jinja2_variable.value_selector)
jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
code_execute_resp = CodeExecutor.execute_workflow_code_template(
language=CodeLanguage.JINJA2,
code=template,
inputs=jinja2_inputs,
)
result_text = code_execute_resp["result"]
return result_text
def _calculate_rest_token(
*,
prompt_messages: list[PromptMessage],
model_instance: ModelInstance,
) -> int:
rest_tokens = 2000
runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
runtime_model_parameters = model_instance.parameters
model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
if model_context_tokens:
curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
max_tokens = 0
for parameter_rule in runtime_model_schema.parameter_rules:
if parameter_rule.name == "max_tokens" or (
parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
):
max_tokens = (
runtime_model_parameters.get(parameter_rule.name)
or runtime_model_parameters.get(str(parameter_rule.use_template))
or 0
)
rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
rest_tokens = max(rest_tokens, 0)
return rest_tokens
def _handle_memory_chat_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> Sequence[PromptMessage]:
memory_messages: Sequence[PromptMessage] = []
# Get messages from memory for chat model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
memory_messages = memory.get_history_prompt_messages(
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
)
return memory_messages
def _handle_memory_completion_mode(
*,
memory: PromptMessageMemory | None,
memory_config: MemoryConfig | None,
model_instance: ModelInstance,
) -> str:
memory_text = ""
# Get history text from memory for completion model
if memory and memory_config:
rest_tokens = _calculate_rest_token(
prompt_messages=[],
model_instance=model_instance,
)
if not memory_config.role_prefix:
raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
memory_text = llm_utils.fetch_memory_text(
memory=memory,
max_token_limit=rest_tokens,
message_limit=memory_config.window.size if memory_config.window.enabled else None,
human_prefix=memory_config.role_prefix.user,
ai_prefix=memory_config.role_prefix.assistant,
)
return memory_text
def _handle_completion_template(
*,
template: LLMNodeCompletionModelPromptTemplate,
context: str | None,
jinja2_variables: Sequence[VariableSelector],
variable_pool: VariablePool,
) -> Sequence[PromptMessage]:
"""Handle completion template processing outside of LLMNode class.
Args:
template: The completion model prompt template
context: Optional context string
jinja2_variables: Variables for jinja2 template rendering
variable_pool: Variable pool for template conversion
Returns:
Sequence of prompt messages
"""
prompt_messages = []
if template.edition_type == "jinja2":
result_text = _render_jinja2_message(
template=template.jinja2_text or "",
jinja2_variables=jinja2_variables,
variable_pool=variable_pool,
)
else:
if context:
template_text = template.text.replace("{#context#}", context)
else:
template_text = template.text
result_text = variable_pool.convert_template(template_text).text
prompt_message = _combine_message_content_with_role(
contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
)
prompt_messages.append(prompt_message)
return prompt_messages

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any, Protocol
from core.model_manager import ModelInstance
@@ -19,3 +20,11 @@ class ModelFactory(Protocol):
def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
"""Create a model instance that is ready for schema lookup and invocation."""
...
class TemplateRenderer(Protocol):
"""Port for rendering prompt templates used by LLM-compatible nodes."""
def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
"""Render the given Jinja2 template into plain text."""
...
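DefaultLLMTemplateRenderer in the node factory satisfies this protocol by delegating to the sandboxed CodeExecutor. For unit tests, a simpler in-process implementation could look like the following sketch (assumes the jinja2 package is importable; it is not part of this diff):

from collections.abc import Mapping
from typing import Any

from jinja2 import Environment


class InProcessTemplateRenderer:
    # Illustrative TemplateRenderer implementation that renders locally instead of
    # going through the sandboxed code executor.
    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
        return Environment(autoescape=False).from_string(template).render(**inputs)


# InProcessTemplateRenderer().render_jinja2(template="Hello {{ name }}", inputs={"name": "Dify"})
# -> "Hello Dify"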

View File

@@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
llm_utils,
)
from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from libs.json_in_md_parser import parse_and_check_json_markdown
@@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
_model_factory: "ModelFactory"
_model_instance: ModelInstance
_memory: PromptMessageMemory | None
_template_renderer: TemplateRenderer
def __init__(
self,
@@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_factory: "ModelFactory",
model_instance: ModelInstance,
http_client: HttpClientProtocol,
template_renderer: TemplateRenderer,
memory: PromptMessageMemory | None = None,
llm_file_saver: LLMFileSaver | None = None,
):
@@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
self._model_factory = model_factory
self._model_instance = model_instance
self._memory = memory
self._template_renderer = template_renderer
if llm_file_saver is None:
dify_ctx = self.require_dify_context()
@@ -142,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
# If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
# two consecutive user prompts will be generated, causing model's error.
# To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
prompt_messages, stop = LLMNode.fetch_prompt_messages(
prompt_messages, stop = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
memory=memory,
@@ -153,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
result_text = ""
@@ -287,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
prompt_template = self._get_prompt_template(node_data, query, None, 2000)
prompt_messages, _ = LLMNode.fetch_prompt_messages(
prompt_messages, _ = llm_utils.fetch_prompt_messages(
prompt_template=prompt_template,
sys_query="",
sys_files=[],
@@ -300,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
vision_detail=node_data.vision.configs.detail,
variable_pool=self.graph_runtime_state.variable_pool,
jinja2_variables=[],
template_renderer=self._template_renderer,
)
rest_tokens = 2000

View File

@@ -0,0 +1,68 @@
"""add indexes for human_input_forms query patterns
Revision ID: 0ec65df55790
Revises: e288952f2994
Create Date: 2026-03-02 18:05:00.000000
"""
from alembic import op
# revision identifiers, used by Alembic.
revision = "0ec65df55790"
down_revision = "e288952f2994"
branch_labels = None
depends_on = None
def upgrade():
with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
batch_op.create_index(
"human_input_forms_workflow_run_id_node_id_idx",
["workflow_run_id", "node_id"],
unique=False,
)
batch_op.create_index(
"human_input_forms_status_created_at_idx",
["status", "created_at"],
unique=False,
)
batch_op.create_index(
"human_input_forms_status_expiration_time_idx",
["status", "expiration_time"],
unique=False,
)
with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
batch_op.create_index(
batch_op.f("human_input_form_deliveries_form_id_idx"),
["form_id"],
unique=False,
)
with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
batch_op.create_index(
batch_op.f("human_input_form_recipients_delivery_id_idx"),
["delivery_id"],
unique=False,
)
batch_op.create_index(
batch_op.f("human_input_form_recipients_form_id_idx"),
["form_id"],
unique=False,
)
def downgrade():
with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
batch_op.drop_index("human_input_forms_workflow_run_id_node_id_idx")
batch_op.drop_index("human_input_forms_status_expiration_time_idx")
batch_op.drop_index("human_input_forms_status_created_at_idx")
with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("human_input_form_recipients_form_id_idx"))
batch_op.drop_index(batch_op.f("human_input_form_recipients_delivery_id_idx"))
with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
batch_op.drop_index(batch_op.f("human_input_form_deliveries_form_id_idx"))

View File

@@ -30,6 +30,15 @@ def _generate_token() -> str:
class HumanInputForm(DefaultFieldsMixin, Base):
__tablename__ = "human_input_forms"
__table_args__ = (
sa.Index(
"human_input_forms_workflow_run_id_node_id_idx",
"workflow_run_id",
"node_id",
),
sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"),
sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"),
)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -84,6 +93,12 @@ class HumanInputForm(DefaultFieldsMixin, Base):
class HumanInputDelivery(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_deliveries"
__table_args__ = (
sa.Index(
None,
"form_id",
),
)
form_id: Mapped[str] = mapped_column(
StringUUID,
@@ -181,6 +196,10 @@ RecipientPayload = Annotated[
class HumanInputFormRecipient(DefaultFieldsMixin, Base):
__tablename__ = "human_input_form_recipients"
__table_args__ = (
sa.Index(None, "form_id"),
sa.Index(None, "delivery_id"),
)
form_id: Mapped[str] = mapped_column(
StringUUID,

View File

@@ -6,12 +6,12 @@ requires-python = ">=3.11,<3.13"
dependencies = [
"aliyun-log-python-sdk~=0.9.37",
"arize-phoenix-otel~=0.15.0",
"azure-identity==1.25.2",
"beautifulsoup4==4.12.2",
"boto3==1.42.65",
"azure-identity==1.25.3",
"beautifulsoup4==4.14.3",
"boto3==1.42.68",
"bs4~=0.0.1",
"cachetools~=5.3.0",
"celery~=5.5.2",
"celery~=5.6.2",
"charset-normalizer>=3.4.4",
"flask~=3.1.2",
"flask-compress>=1.17,<1.24",
@@ -35,12 +35,12 @@ dependencies = [
"jsonschema>=4.25.1",
"langfuse~=2.51.3",
"langsmith~=0.7.16",
"markdown~=3.8.1",
"markdown~=3.10.2",
"mlflow-skinny>=3.0.0",
"numpy~=1.26.4",
"openpyxl~=3.1.5",
"opik~=1.10.37",
"litellm==1.82.1", # Pinned to avoid madoka dependency issue
"litellm==1.82.2", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
@@ -58,7 +58,7 @@ dependencies = [
"opentelemetry-sdk==1.28.0",
"opentelemetry-semantic-conventions==0.49b0",
"opentelemetry-util-http==0.49b0",
"pandas[excel,output-formatting,performance]~=2.2.2",
"pandas[excel,output-formatting,performance]~=3.0.1",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
"pycryptodome==3.23.0",
@@ -66,22 +66,22 @@ dependencies = [
"pydantic-extra-types~=2.11.0",
"pydantic-settings~=2.13.1",
"pyjwt~=2.12.0",
"pypdfium2==5.2.0",
"pypdfium2==5.6.0",
"python-docx~=1.2.0",
"python-dotenv==1.0.1",
"python-dotenv==1.2.2",
"pyyaml~=6.0.1",
"readabilipy~=0.3.0",
"redis[hiredis]~=7.3.0",
"resend~=2.9.0",
"sentry-sdk[flask]~=2.28.0",
"resend~=2.23.0",
"sentry-sdk[flask]~=2.54.0",
"sqlalchemy~=2.0.29",
"starlette==0.49.1",
"starlette==0.52.1",
"tiktoken~=0.12.0",
"transformers~=5.3.0",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
"yarl~=1.23.0",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",
"sseclient-py~=1.9.0",
"httpx-sse~=0.4.0",
"sendgrid~=6.12.3",
"flask-restx~=1.3.2",
@@ -111,7 +111,7 @@ package = false
dev = [
"coverage~=7.13.4",
"dotenv-linter~=0.7.0",
"faker~=40.8.0",
"faker~=40.11.0",
"lxml-stubs~=0.5.1",
"basedpyright~=1.38.2",
"ruff~=0.15.5",
@@ -120,7 +120,7 @@ dev = [
"pytest-cov~=7.0.0",
"pytest-env~=1.1.3",
"pytest-mock~=3.15.1",
"testcontainers~=4.13.2",
"testcontainers~=4.14.1",
"types-aiofiles~=25.1.0",
"types-beautifulsoup4~=4.12.0",
"types-cachetools~=6.2.0",

View File

@@ -6,6 +6,13 @@ from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
EnterpriseAPIBadRequestError,
EnterpriseAPIError,
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
)
logger = logging.getLogger(__name__)
@@ -64,10 +71,51 @@ class BaseRequest:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
# Validate HTTP status and raise domain-specific errors
if not response.is_success:
cls._handle_error_response(response)
return response.json()
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
"""
Handle non-2xx HTTP responses by raising appropriate domain errors.
Attempts to extract error message from JSON response body,
falls back to status text if parsing fails.
"""
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
# Try to extract error message from JSON response
try:
error_data = response.json()
if isinstance(error_data, dict):
# Common error response formats:
# {"error": "...", "message": "..."}
# {"message": "..."}
# {"detail": "..."}
error_message = (
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
)
except Exception:
# If JSON parsing fails, use the default message
logger.debug(
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
)
# Raise specific error based on status code
if response.status_code == 400:
raise EnterpriseAPIBadRequestError(error_message)
elif response.status_code == 401:
raise EnterpriseAPIUnauthorizedError(error_message)
elif response.status_code == 403:
raise EnterpriseAPIForbiddenError(error_message)
elif response.status_code == 404:
raise EnterpriseAPINotFoundError(error_message)
else:
raise EnterpriseAPIError(error_message, status_code=response.status_code)
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
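With _handle_error_response in place, callers can branch on exception type rather than inspecting raw httpx responses. A hypothetical usage sketch (the /webapp/settings path is illustrative, not taken from this diff):

from services.enterprise.base import EnterpriseRequest
from services.errors.enterprise import (
    EnterpriseAPIError,
    EnterpriseAPINotFoundError,
    EnterpriseAPIUnauthorizedError,
)


def fetch_webapp_settings(app_id: str) -> dict | None:
    try:
        return EnterpriseRequest.send_request("GET", "/webapp/settings", params={"appId": app_id})
    except EnterpriseAPINotFoundError:
        return None  # treat a missing webapp record as "no settings"
    except EnterpriseAPIUnauthorizedError:
        raise  # surface credential problems unchanged
    except EnterpriseAPIError as exc:
        raise RuntimeError(f"Enterprise API failed with status {exc.status_code}") from exc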

View File

@@ -1,15 +1,26 @@
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppSettings(BaseModel):
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
@@ -115,7 +126,6 @@ class EnterpriseService:
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
@@ -223,3 +233,64 @@ class EnterpriseService:
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
balances prompt license-fix detection against DoS mitigation — without
caching, every request on an expired license would hit the enterprise API.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.debug("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result."""
from services.feature_service import LicenseStatus
try:
info = cls.get_info()
license_info = info.get("License")
if not license_info:
return None
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
ttl = (
VALID_LICENSE_CACHE_TTL
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
else INVALID_LICENSE_CACHE_TTL
)
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
except Exception:
logger.debug("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.debug("Failed to fetch enterprise license status", exc_info=True)
return None
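With the cached status in place, request-level license enforcement can avoid a per-request call to the enterprise API. A minimal sketch of such a gate, assuming a Flask-style view decorator; the decorator name and the 403 payload are illustrative and not part of this diff. `None` (enterprise disabled or status unavailable) is treated as allow, matching the soft-fail behavior above.

```python
# Hypothetical request gate built on get_cached_license_status; only
# EnterpriseService.get_cached_license_status and LicenseStatus come from this diff.
from functools import wraps

from flask import jsonify

from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus


def license_required(view):
    @wraps(view)
    def wrapper(*args, **kwargs):
        status = EnterpriseService.get_cached_license_status()
        # None means enterprise mode is disabled or the status could not be determined;
        # both cases are allowed through rather than hard-failing the request.
        if status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
            return jsonify({"code": "license_invalid", "message": f"license is {status.value}"}), 403
        return view(*args, **kwargs)

    return wrapper
```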

View File

@@ -70,7 +70,6 @@ class PluginManagerService:
"POST",
"/pre-uninstall-plugin",
json=body.model_dump(),
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
except Exception:

View File

@@ -7,6 +7,7 @@ from . import (
conversation,
dataset,
document,
enterprise,
file,
index,
message,
@@ -21,6 +22,7 @@ __all__ = [
"conversation",
"dataset",
"document",
"enterprise",
"file",
"index",
"message",

View File

@@ -0,0 +1,45 @@
"""Enterprise service errors."""
from services.errors.base import BaseServiceError
class EnterpriseServiceError(BaseServiceError):
"""Base exception for enterprise service errors."""
def __init__(self, description: str | None = None, status_code: int | None = None):
super().__init__(description)
self.status_code = status_code
class EnterpriseAPIError(EnterpriseServiceError):
"""Generic enterprise API error (non-2xx response)."""
pass
class EnterpriseAPINotFoundError(EnterpriseServiceError):
"""Enterprise API returned 404 Not Found."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=404)
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
"""Enterprise API returned 403 Forbidden."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=403)
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
"""Enterprise API returned 401 Unauthorized."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=401)
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
"""Enterprise API returned 400 Bad Request."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=400)
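Because every subclass records the originating status code, a web layer can translate these exceptions back into HTTP responses with a single handler. A sketch under the assumption of a Flask app and the `services.errors.enterprise` import path; the handler and its JSON shape are illustrative:

```python
# Hypothetical error handler; registering it for EnterpriseServiceError also covers
# every subclass, so the 400/401/403/404 mapping from _handle_error_response is
# carried through to the HTTP response unchanged.
from flask import Flask, jsonify

from services.errors.enterprise import EnterpriseServiceError

app = Flask(__name__)


@app.errorhandler(EnterpriseServiceError)
def handle_enterprise_error(error: EnterpriseServiceError):
    # Fall back to 502 when no upstream status is attached (bare EnterpriseServiceError).
    return jsonify({"code": "enterprise_api_error", "message": str(error)}), error.status_code or 502
```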

View File

@@ -379,14 +379,19 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if is_authenticated and (license_info := enterprise_info.get("License")):
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
# so the login page can detect an expired/inactive license after force-logout.
# All other license details (expiry date, workspace usage) remain auth-gated.
# This behavior reflects prior internal review of information-leakage risks.
if license_info := enterprise_info.get("License"):
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if is_authenticated:
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
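The license hunk above always exposes the license status but gates expiry and workspace details on `is_authenticated`. The sketch below restates that branching as a standalone helper purely for illustration; the real logic stays inline in `FeatureService` as shown.

```python
# Illustrative restatement of the auth-gated license exposure; not the actual
# FeatureService code, just the same branching pulled into a helper for clarity.
from services.feature_service import LicenseStatus


def summarize_license(enterprise_info: dict, is_authenticated: bool) -> dict:
    summary: dict = {"status": LicenseStatus.NONE, "expired_at": "", "workspaces": None}
    license_info = enterprise_info.get("License")
    if not license_info:
        return summary
    # Status alone is exposed pre-login so the UI can explain a forced logout.
    summary["status"] = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
    if is_authenticated:
        # Expiry date and workspace usage stay behind authentication.
        summary["expired_at"] = license_info.get("expiredAt", "")
        summary["workspaces"] = license_info.get("workspaces")
    return summary
```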

View File

@@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@@ -75,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(spec=ModelInstance),
template_renderer=MagicMock(spec=TemplateRenderer),
http_client=MagicMock(spec=HttpClientProtocol),
)
@@ -158,7 +159,7 @@ def test_execute_llm():
return mock_model_instance
# Mock fetch_prompt_messages to avoid database calls
def mock_fetch_prompt_messages_1(**_kwargs):
def mock_fetch_prompt_messages_1(*_args, **_kwargs):
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
return [

View File

@@ -358,10 +358,9 @@ class TestFeatureService:
assert result is not None
assert isinstance(result, SystemFeatureModel)
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
# Ensure only essential UI flags are returned to unauthenticated clients
# to keep the payload lightweight and adhere to architectural boundaries.
assert result.license.status == LicenseStatus.NONE
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
# Detailed license info (expiry, workspaces) remains auth-gated.
assert result.license.status == LicenseStatus.ACTIVE
assert result.license.expired_at == ""
assert result.license.workspaces.enabled is False
assert result.license.workspaces.limit == 0

View File

@@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode
from dify_graph.nodes.http_request import HttpRequestNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
@@ -68,6 +68,8 @@ class MockNodeMixin:
kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
# LLM-like nodes now require an http_client; provide a mock by default for tests.
kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
if isinstance(self, (LLMNode, QuestionClassifierNode)):
kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))
# Ensure TemplateTransformNode receives a renderer now required by constructor
if isinstance(self, TemplateTransformNode):

View File

@@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
VisionConfigOptions,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -107,6 +107,7 @@ def llm_node(
mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@@ -121,6 +122,7 @@ def llm_node(
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node
@@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
assert result[0].content == [TextPromptMessageContent(data="Hello, world")]
def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
messages = [
LLMNodeChatModelMessage(
text="",
jinja2_text="Hello, {{ name }}",
role=PromptMessageRole.USER,
edition_type="jinja2",
)
]
result = llm_node.handle_list_messages(
messages=messages,
context=None,
jinja2_variables=[],
variable_pool=llm_node.graph_runtime_state.variable_pool,
vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
template_renderer=llm_node._template_renderer,
)
assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
llm_node._template_renderer.render_jinja2.assert_called_once_with(
template="Hello, {{ name }}",
inputs={},
)
def test_handle_memory_completion_mode_uses_prompt_message_interface():
memory = mock.MagicMock(spec=MockTokenBufferMemory)
memory.get_history_prompt_messages.return_value = [
@@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
window=MemoryConfig.WindowConfig(enabled=True, size=3),
)
with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = _handle_memory_completion_mode(
with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
memory_text = llm_utils.handle_memory_completion_mode(
memory=memory,
memory_config=memory_config,
model_instance=model_instance,
@@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
mock_model_factory = mock.MagicMock(spec=ModelFactory)
mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
node_config = {
"id": "1",
"data": llm_node_data.model_dump(),
@@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
model_factory=mock_model_factory,
model_instance=mock.MagicMock(spec=ModelInstance),
llm_file_saver=mock_file_saver,
template_renderer=mock_template_renderer,
http_client=http_client,
)
return node, mock_file_saver
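These tests only ever call `render_jinja2(template=..., inputs=...)` on the renderer mock, which suggests `TemplateRenderer` is a single-method protocol along the following lines. This is inferred from usage, not the actual definition in `dify_graph.nodes.llm.protocols`:

```python
# Inferred shape of the TemplateRenderer protocol, based only on how the mocks in
# these tests are called; the real definition in dify_graph.nodes.llm.protocols
# may carry additional members.
from typing import Any, Protocol


class TemplateRenderer(Protocol):
    def render_jinja2(self, *, template: str, inputs: dict[str, Any]) -> str:
        """Render a Jinja2 template against the given inputs and return the text."""
        ...
```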

View File

@@ -1,5 +1,14 @@
from types import SimpleNamespace
from unittest.mock import MagicMock
from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import (
QuestionClassifierNode,
QuestionClassifierNodeData,
)
from tests.workflow_test_utils import build_test_graph_init_params
def test_init_question_classifier_node_data():
@@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
assert node_data.vision.enabled == False
assert node_data.vision.configs.variable_selector == ["sys", "files"]
assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH
def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
node_data = QuestionClassifierNodeData.model_validate(
{
"title": "test classifier node",
"query_variable_selector": ["id", "name"],
"model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
"classes": [{"id": "1", "name": "class 1"}],
"instruction": "This is a test instruction",
}
)
template_renderer = MagicMock(spec=TemplateRenderer)
node = QuestionClassifierNode(
id="node-id",
config={"id": "node-id", "data": node_data.model_dump(mode="json")},
graph_init_params=build_test_graph_init_params(
workflow_id="workflow-id",
graph_config={},
tenant_id="tenant-id",
app_id="app-id",
user_id="user-id",
),
graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
credentials_provider=MagicMock(spec=CredentialsProvider),
model_factory=MagicMock(spec=ModelFactory),
model_instance=MagicMock(),
http_client=MagicMock(spec=HttpClientProtocol),
llm_file_saver=MagicMock(),
template_renderer=template_renderer,
)
fetch_prompt_messages = MagicMock(return_value=([], None))
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
fetch_prompt_messages,
)
monkeypatch.setattr(
"dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
)
node._calculate_rest_token(
node_data=node_data,
query="hello",
model_instance=MagicMock(stop=(), parameters={}),
context="",
)
assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer

View File

@@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
assert executor.is_execution_error(RuntimeError("boom")) is False
class TestDefaultLLMTemplateRenderer:
def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
renderer = node_factory.DefaultLLMTemplateRenderer()
execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
monkeypatch.setattr(
node_factory.CodeExecutor,
"execute_workflow_code_template",
execute_workflow_code_template,
)
result = renderer.render_jinja2(
template="Hello {{ name }}",
inputs={"name": "world"},
)
assert result == "hello world"
execute_workflow_code_template.assert_called_once_with(
language=CodeLanguage.JINJA2,
code="Hello {{ name }}",
inputs={"name": "world"},
)
class TestDifyNodeFactoryInit:
def test_init_builds_default_dependencies(self):
graph_init_params = SimpleNamespace(run_context={"context": "value"})
@@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
http_request_config = sentinel.http_request_config
credentials_provider = sentinel.credentials_provider
model_factory = sentinel.model_factory
llm_template_renderer = sentinel.llm_template_renderer
with (
patch.object(
@@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
"build_http_request_config",
return_value=http_request_config,
),
patch.object(
node_factory,
"DefaultLLMTemplateRenderer",
return_value=llm_template_renderer,
) as llm_renderer_factory,
patch.object(
node_factory,
"build_dify_model_access",
@@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
build_dify_model_access.assert_called_once_with("tenant-id")
renderer_factory.assert_called_once()
llm_renderer_factory.assert_called_once()
assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
assert factory.graph_init_params is graph_init_params
assert factory.graph_runtime_state is graph_runtime_state
assert factory._dify_context is dify_context
assert factory._template_renderer is template_renderer
assert factory._llm_template_renderer is llm_template_renderer
assert factory._document_extractor_unstructured_api_config is unstructured_api_config
assert factory._http_request_config is http_request_config
assert factory._llm_credentials_provider is credentials_provider
@@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
factory._code_executor = sentinel.code_executor
factory._code_limits = sentinel.code_limits
factory._template_renderer = sentinel.template_renderer
factory._llm_template_renderer = sentinel.llm_template_renderer
factory._template_transform_max_output_length = 2048
factory._http_request_http_client = sentinel.http_client
factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
@pytest.mark.parametrize(
("node_type", "constructor_name", "expected_extra_kwargs"),
[
(BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
(BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
(
BuiltinNodeTypes.LLM,
"LLMNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(
BuiltinNodeTypes.QUESTION_CLASSIFIER,
"QuestionClassifierNode",
{
"http_client": sentinel.http_client,
"template_renderer": sentinel.llm_template_renderer,
},
),
(BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
],
)
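`TestDefaultLLMTemplateRenderer` above pins down the delegation contract: `render_jinja2` forwards the template and inputs to `CodeExecutor.execute_workflow_code_template` with `CodeLanguage.JINJA2` and returns the `result` field. A sketch consistent with that test; the import path for `CodeExecutor`/`CodeLanguage` is assumed, and the real `DefaultLLMTemplateRenderer` in the node factory module may differ in detail:

```python
# Sketch of a DefaultLLMTemplateRenderer consistent with TestDefaultLLMTemplateRenderer;
# the CodeExecutor/CodeLanguage import path is assumed (the tests reach them via the
# node_factory module), and the real implementation may differ in detail.
from typing import Any

from core.helper.code_executor.code_executor import CodeExecutor, CodeLanguage


class DefaultLLMTemplateRenderer:
    def render_jinja2(self, *, template: str, inputs: dict[str, Any]) -> str:
        # Delegate rendering to the sandboxed code executor, exactly as the test asserts.
        response = CodeExecutor.execute_workflow_code_template(
            language=CodeLanguage.JINJA2,
            code=template,
            inputs=inputs,
        )
        # The executor returns a mapping with the rendered text under "result".
        return str(response.get("result", ""))
```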

View File

@@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""
from unittest.mock import patch
@@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()

View File

@@ -34,7 +34,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
@@ -62,7 +61,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()
@@ -87,7 +85,6 @@ class TestTryPreUninstallPlugin:
"POST",
"/pre-uninstall-plugin",
json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
raise_for_status=True,
timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
)
mock_logger.exception.assert_called_once()

api/uv.lock (generated)

File diff suppressed because it is too large.