Mirror of https://github.com/langgenius/dify.git (synced 2026-03-16 04:37:04 +00:00)

Compare commits: test/share ... main (11 commits)
Commits:

- 977ed79ea0
- dd39fcd9bc
- 3c587097cd
- 6a3fcc0a7b
- 8d3f2f56d9
- 09dad78a5d
- c71ecd2fe0
- 808d186156
- ec0a01a568
- ac23a0409e
- 6ef69ff880
.github/workflows/api-tests.yml (2 changes)

```diff
@@ -27,7 +27,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
```
.github/workflows/autofix.yml (2 changes)

```diff
@@ -39,7 +39,7 @@ jobs:
         with:
           python-version: "3.11"

-      - uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+      - uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0

       - name: Generate Docker Compose
         if: steps.docker-compose-changes.outputs.any_changed == 'true'
```
.github/workflows/build-push.yml (2 changes)

```diff
@@ -113,7 +113,7 @@ jobs:
           context: "web"
     steps:
       - name: Download digests
-        uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
+        uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
         with:
           path: /tmp/digests
           pattern: digests-${{ matrix.context }}-*
```
.github/workflows/db-migration-test.yml (4 changes)

```diff
@@ -19,7 +19,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: true
           python-version: "3.12"
@@ -69,7 +69,7 @@ jobs:
           persist-credentials: false

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: true
           python-version: "3.12"
```
.github/workflows/main-ci.yml (2 changes)

```diff
@@ -28,7 +28,7 @@ jobs:
       migration-changed: ${{ steps.changes.outputs.migration }}
     steps:
       - uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
-      - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 # v3.0.2
+      - uses: dorny/paths-filter@fbd0ab8f3e69293af611ebaee6363fc25e6d187d # v4.0.1
        id: changes
        with:
          filters: |
```
.github/workflows/pyrefly-diff.yml (2 changes)

```diff
@@ -22,7 +22,7 @@ jobs:
           fetch-depth: 0

       - name: Setup Python & UV
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: true
```
.github/workflows/style.yml (2 changes)

```diff
@@ -33,7 +33,7 @@ jobs:

       - name: Setup UV and Python
         if: steps.changed-files.outputs.any_changed == 'true'
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: false
           python-version: "3.12"
```
.github/workflows/translate-i18n-claude.yml (2 changes)

```diff
@@ -120,7 +120,7 @@ jobs:

       - name: Run Claude Code for Translation Sync
         if: steps.detect_changes.outputs.CHANGED_FILES != ''
-        uses: anthropics/claude-code-action@26ec041249acb0a944c0a47b6c0c13f05dbc5b44 # v1.0.70
+        uses: anthropics/claude-code-action@cd77b50d2b0808657f8e6774085c8bf54484351c # v1.0.72
         with:
           anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }}
           github_token: ${{ secrets.GITHUB_TOKEN }}
```
.github/workflows/vdb-tests.yml (2 changes)

```diff
@@ -31,7 +31,7 @@ jobs:
           remove_tool_cache: true

       - name: Setup UV and Python
-        uses: astral-sh/setup-uv@6ee6290f1cbc4156c0bdd66691b2c144ef8df19a # v7.4.0
+        uses: astral-sh/setup-uv@e06108dd0aef18192324c70427afc47652e63a82 # v7.5.0
         with:
           enable-cache: true
           python-version: ${{ matrix.python-version }}
```
.github/workflows/web-tests.yml (2 changes)

```diff
@@ -77,7 +77,7 @@ jobs:
         uses: ./.github/actions/setup-web

       - name: Download blob reports
-        uses: actions/download-artifact@70fc10c6e5e1ce46ad2ea6f2b72d43f7d47b13c3 # v8.0.0
+        uses: actions/download-artifact@3e5f45b2cfb9172054b4087a40e8e0b5a5461e7c # v8.0.1
         with:
           path: web/.vitest-reports
           pattern: blob-report-*
```
Import-linter contract (file name lost in extraction; the removed rule matches the `core.helper.code_executor` import dropped from `dify_graph.nodes.llm.node` below):

```diff
@@ -103,7 +103,6 @@ ignore_imports =
    dify_graph.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
    dify_graph.nodes.question_classifier.question_classifier_node -> core.model_manager
    dify_graph.nodes.tool.tool_node -> core.tools.utils.message_transformer
-    dify_graph.nodes.llm.node -> core.helper.code_executor
    dify_graph.nodes.llm.node -> core.llm_generator.output_parser.errors
    dify_graph.nodes.llm.node -> core.llm_generator.output_parser.structured_output
    dify_graph.nodes.llm.node -> core.model_manager
```
```diff
@@ -97,7 +97,7 @@ ENV PATH="${VIRTUAL_ENV}/bin:${PATH}"

 # Download nltk data
 RUN mkdir -p /usr/local/share/nltk_data \
-    && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; from unstructured.nlp.tokenize import download_nltk_packages; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords'); download_nltk_packages()" \
+    && NLTK_DATA=/usr/local/share/nltk_data python -c "import nltk; nltk.download('punkt'); nltk.download('averaged_perceptron_tagger'); nltk.download('stopwords')" \
    && chmod -R 755 /usr/local/share/nltk_data

 ENV TIKTOKEN_CACHE_DIR=/app/api/.tiktoken_cache
```
Application factory (enterprise license gate; +/- attribution reconstructed from the hunk counts):

```diff
@@ -1,16 +1,45 @@
 import logging
 import time

 from flask import request
 from opentelemetry.trace import get_current_span
 from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID

 from configs import dify_config
 from contexts.wrapper import RecyclableContextVar
+from controllers.console.error import UnauthorizedAndForceLogout
 from core.logging.context import init_request_context
 from dify_app import DifyApp
+from services.enterprise.enterprise_service import EnterpriseService
+from services.feature_service import LicenseStatus
+
+logger = logging.getLogger(__name__)
+
+# Console bootstrap APIs exempt from license check.
+# Defined at module level to avoid per-request tuple construction.
+# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
+# - setup: install/setup status check (AppInitializer)
+# - init: init password validation for fresh install (InitPasswordPopup)
+# - login: auto-login after setup completion (InstallForm)
+# - features: billing/plan features (ProviderContextProvider)
+# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
+# - workspaces/current: workspace + model providers (AppContextProvider)
+# - version: version check (AppContextProvider)
+# - activate/check: invitation link validation (signin page)
+# Without these exemptions, the signin page triggers location.reload()
+# on unauthorized_and_force_logout, causing an infinite loop.
+_CONSOLE_EXEMPT_PREFIXES = (
+    "/console/api/system-features",
+    "/console/api/setup",
+    "/console/api/init",
+    "/console/api/login",
+    "/console/api/features",
+    "/console/api/account/profile",
+    "/console/api/workspaces/current",
+    "/console/api/version",
+    "/console/api/activate/check",
+)


 # ----------------------------
 # Application Factory Function
@@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
         init_request_context()
         RecyclableContextVar.increment_thread_recycles()

+        # Enterprise license validation for API endpoints (both console and webapp)
+        # When license expires, block all API access except bootstrap endpoints needed
+        # for the frontend to load the license expiration page without infinite reloads.
+        if dify_config.ENTERPRISE_ENABLED:
+            is_console_api = request.path.startswith("/console/api/")
+            is_webapp_api = request.path.startswith("/api/")
+
+            if is_console_api or is_webapp_api:
+                if is_console_api:
+                    is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
+                else:  # webapp API
+                    is_exempt = request.path.startswith("/api/system-features")
+
+                if not is_exempt:
+                    try:
+                        # Check license status (cached — see EnterpriseService for TTL details)
+                        license_status = EnterpriseService.get_cached_license_status()
+                        if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
+                            raise UnauthorizedAndForceLogout(
+                                f"Enterprise license is {license_status}. Please contact your administrator."
+                            )
+                        if license_status is None:
+                            raise UnauthorizedAndForceLogout(
+                                "Unable to verify enterprise license. Please contact your administrator."
+                            )
+                    except UnauthorizedAndForceLogout:
+                        raise
+                    except Exception:
+                        logger.exception("Failed to check enterprise license status")
+                        raise UnauthorizedAndForceLogout(
+                            "Unable to verify enterprise license. Please contact your administrator."
+                        )
+
     # add after request hook for injecting trace headers from OpenTelemetry span context
     # Only adds headers when OTEL is enabled and has valid context
     @dify_app.after_request
```
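The gate above reduces to a pure predicate on the request path. A minimal standalone sketch of that predicate, with a trimmed prefix list and an illustrative function name (nothing here is imported from the repository):

```python
# Standalone sketch of the exemption check; prefixes trimmed for brevity.
CONSOLE_EXEMPT_PREFIXES = ("/console/api/system-features", "/console/api/setup")


def is_license_exempt(path: str) -> bool:
    """Return True when a request path may bypass the license gate."""
    if path.startswith("/console/api/"):
        return any(path.startswith(p) for p in CONSOLE_EXEMPT_PREFIXES)
    if path.startswith("/api/"):
        return path.startswith("/api/system-features")
    return True  # non-API routes are never gated by this check


assert is_license_exempt("/console/api/setup") is True
assert is_license_exempt("/console/api/apps") is False
assert is_license_exempt("/api/chat-messages") is False
```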
Node factory (wires the new `TemplateRenderer` port into LLM-family nodes):

```diff
@@ -45,6 +45,7 @@ from dify_graph.nodes.document_extractor import UnstructuredApiConfig
 from dify_graph.nodes.http_request import build_http_request_config
 from dify_graph.nodes.llm.entities import LLMNodeData
 from dify_graph.nodes.llm.exc import LLMModeRequiredError, ModelNotExistError
+from dify_graph.nodes.llm.protocols import TemplateRenderer
 from dify_graph.nodes.parameter_extractor.entities import ParameterExtractorNodeData
 from dify_graph.nodes.question_classifier.entities import QuestionClassifierNodeData
 from dify_graph.nodes.template_transform.template_renderer import (
@@ -228,6 +229,16 @@ class DefaultWorkflowCodeExecutor:
         return isinstance(error, CodeExecutionError)


+class DefaultLLMTemplateRenderer(TemplateRenderer):
+    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
+        result = CodeExecutor.execute_workflow_code_template(
+            language=CodeLanguage.JINJA2,
+            code=template,
+            inputs=inputs,
+        )
+        return str(result.get("result", ""))
+
+
 @final
 class DifyNodeFactory(NodeFactory):
     """
@@ -254,6 +265,7 @@ class DifyNodeFactory(NodeFactory):
             max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
         )
         self._template_renderer = CodeExecutorJinja2TemplateRenderer(code_executor=self._code_executor)
+        self._llm_template_renderer: TemplateRenderer = DefaultLLMTemplateRenderer()
         self._template_transform_max_output_length = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
         self._http_request_http_client = ssrf_proxy
         self._http_request_tool_file_manager_factory = ToolFileManager
@@ -391,6 +403,8 @@ class DifyNodeFactory(NodeFactory):
                 model_instance=model_instance,
             ),
         }
+        if validated_node_data.type in {BuiltinNodeTypes.LLM, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
+            node_init_kwargs["template_renderer"] = self._llm_template_renderer
         if include_http_client:
             node_init_kwargs["http_client"] = self._http_request_http_client
         return node_init_kwargs
```
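`DefaultLLMTemplateRenderer` satisfies the protocol by delegating to the sandboxed `CodeExecutor` rather than rendering Jinja2 in-process. A hypothetical call, assuming a reachable code-executor backend (which is what `execute_workflow_code_template` talks to):

```python
# Hypothetical usage sketch; requires a running code-executor service.
renderer = DefaultLLMTemplateRenderer()
text = renderer.render_jinja2(template="Hello, {{ name }}!", inputs={"name": "Dify"})
# The executor returns a mapping; the renderer stringifies its "result" key
# (defaulting to ""), so on success `text` would be "Hello, Dify!".
```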
`dify_graph/nodes/llm/llm_utils.py` (imports reorganized for the code moved in from `node.py`; `fetch_model_schema` now passes a copy of the credentials):

```diff
@@ -1,34 +1,53 @@
 from __future__ import annotations

 from collections.abc import Sequence
-from typing import cast
+from typing import Any, cast

 from core.model_manager import ModelInstance
 from dify_graph.file import FileType, file_manager
 from dify_graph.file.models import File
-from dify_graph.model_runtime.entities import PromptMessageRole
-from dify_graph.model_runtime.entities.message_entities import (
+from dify_graph.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
+    PromptMessageRole,
     TextPromptMessageContent,
 )
-from dify_graph.model_runtime.entities.model_entities import AIModelEntity
+from dify_graph.model_runtime.entities.message_entities import (
+    AssistantPromptMessage,
+    PromptMessageContentUnionTypes,
+    SystemPromptMessage,
+    UserPromptMessage,
+)
+from dify_graph.model_runtime.entities.model_entities import AIModelEntity, ModelFeature, ModelPropertyKey
+from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
 from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.runtime import VariablePool
-from dify_graph.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment
+from dify_graph.variables import ArrayFileSegment, FileSegment
+from dify_graph.variables.segments import ArrayAnySegment, NoneSegment

-from .exc import InvalidVariableTypeError
+from .entities import LLMNodeChatModelMessage, LLMNodeCompletionModelPromptTemplate, MemoryConfig
+from .exc import (
+    InvalidVariableTypeError,
+    MemoryRolePrefixRequiredError,
+    NoPromptFoundError,
+    TemplateTypeNotSupportError,
+)
+from .protocols import TemplateRenderer


 def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
     model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
         model_instance.model_name,
-        model_instance.credentials,
+        dict(model_instance.credentials),
     )
     if not model_schema:
         raise ValueError(f"Model schema not found for {model_instance.model_name}")
     return model_schema


-def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence["File"]:
+def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequence[File]:
     variable = variable_pool.get(selector)
     if variable is None:
         return []
```
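The `dict(model_instance.credentials)` change passes a shallow copy where the old code passed the credentials mapping itself. One plausible motivation, sketched with stand-in types (the callee below is illustrative, not the repository's `get_model_schema`):

```python
from collections.abc import Mapping


def schema_lookup_stub(model: str, credentials: dict) -> None:
    # A dict-typed callee is free to mutate its argument in place.
    credentials.setdefault("mode", "chat")


creds: Mapping[str, str] = {"api_key": "sk-xxxx"}
schema_lookup_stub("gpt-4o", dict(creds))  # the copy absorbs any mutation
assert "mode" not in creds  # the caller's mapping is untouched
```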
```diff
@@ -89,3 +108,366 @@ def fetch_memory_text(
         human_prefix=human_prefix,
         ai_prefix=ai_prefix,
     )
+
+
+def fetch_prompt_messages(
+    *,
+    sys_query: str | None = None,
+    sys_files: Sequence[File],
+    context: str | None = None,
+    memory: PromptMessageMemory | None = None,
+    model_instance: ModelInstance,
+    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+    stop: Sequence[str] | None = None,
+    memory_config: MemoryConfig | None = None,
+    vision_enabled: bool = False,
+    vision_detail: ImagePromptMessageContent.DETAIL,
+    variable_pool: VariablePool,
+    jinja2_variables: Sequence[VariableSelector],
+    context_files: list[File] | None = None,
+    template_renderer: TemplateRenderer | None = None,
+) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
+    prompt_messages: list[PromptMessage] = []
+    model_schema = fetch_model_schema(model_instance=model_instance)
+
+    if isinstance(prompt_template, list):
+        prompt_messages.extend(
+            handle_list_messages(
+                messages=prompt_template,
+                context=context,
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                vision_detail_config=vision_detail,
+                template_renderer=template_renderer,
+            )
+        )
+
+        prompt_messages.extend(
+            handle_memory_chat_mode(
+                memory=memory,
+                memory_config=memory_config,
+                model_instance=model_instance,
+            )
+        )
+
+        if sys_query:
+            prompt_messages.extend(
+                handle_list_messages(
+                    messages=[
+                        LLMNodeChatModelMessage(
+                            text=sys_query,
+                            role=PromptMessageRole.USER,
+                            edition_type="basic",
+                        )
+                    ],
+                    context="",
+                    jinja2_variables=[],
+                    variable_pool=variable_pool,
+                    vision_detail_config=vision_detail,
+                    template_renderer=template_renderer,
+                )
+            )
+    elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
+        prompt_messages.extend(
+            handle_completion_template(
+                template=prompt_template,
+                context=context,
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                template_renderer=template_renderer,
+            )
+        )
+
+        memory_text = handle_memory_completion_mode(
+            memory=memory,
+            memory_config=memory_config,
+            model_instance=model_instance,
+        )
+        prompt_content = prompt_messages[0].content
+        if isinstance(prompt_content, str):
+            prompt_content = str(prompt_content)
+            if "#histories#" in prompt_content:
+                prompt_content = prompt_content.replace("#histories#", memory_text)
+            else:
+                prompt_content = memory_text + "\n" + prompt_content
+            prompt_messages[0].content = prompt_content
+        elif isinstance(prompt_content, list):
+            for content_item in prompt_content:
+                if isinstance(content_item, TextPromptMessageContent):
+                    if "#histories#" in content_item.data:
+                        content_item.data = content_item.data.replace("#histories#", memory_text)
+                    else:
+                        content_item.data = memory_text + "\n" + content_item.data
+        else:
+            raise ValueError("Invalid prompt content type")
+
+        if sys_query:
+            if isinstance(prompt_content, str):
+                prompt_messages[0].content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
+            elif isinstance(prompt_content, list):
+                for content_item in prompt_content:
+                    if isinstance(content_item, TextPromptMessageContent):
+                        content_item.data = sys_query + "\n" + content_item.data
+            else:
+                raise ValueError("Invalid prompt content type")
+    else:
+        raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
+
+    _append_file_prompts(
+        prompt_messages=prompt_messages,
+        files=sys_files,
+        vision_enabled=vision_enabled,
+        vision_detail=vision_detail,
+    )
+    _append_file_prompts(
+        prompt_messages=prompt_messages,
+        files=context_files or [],
+        vision_enabled=vision_enabled,
+        vision_detail=vision_detail,
+    )
+
+    filtered_prompt_messages: list[PromptMessage] = []
+    for prompt_message in prompt_messages:
+        if isinstance(prompt_message.content, list):
+            prompt_message_content: list[PromptMessageContentUnionTypes] = []
+            for content_item in prompt_message.content:
+                if not model_schema.features:
+                    if content_item.type == PromptMessageContentType.TEXT:
+                        prompt_message_content.append(content_item)
+                    continue
+
+                if (
+                    (
+                        content_item.type == PromptMessageContentType.IMAGE
+                        and ModelFeature.VISION not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.DOCUMENT
+                        and ModelFeature.DOCUMENT not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.VIDEO
+                        and ModelFeature.VIDEO not in model_schema.features
+                    )
+                    or (
+                        content_item.type == PromptMessageContentType.AUDIO
+                        and ModelFeature.AUDIO not in model_schema.features
+                    )
+                ):
+                    continue
+                prompt_message_content.append(content_item)
+            if prompt_message_content:
+                prompt_message.content = prompt_message_content
+                filtered_prompt_messages.append(prompt_message)
+        elif not prompt_message.is_empty():
+            filtered_prompt_messages.append(prompt_message)
+
+    if len(filtered_prompt_messages) == 0:
+        raise NoPromptFoundError(
+            "No prompt found in the LLM configuration. Please ensure a prompt is properly configured before proceeding."
+        )
+
+    return filtered_prompt_messages, stop
```
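In the completion branch above, memory is spliced in via the `#histories#` placeholder when present, and prepended otherwise. A small worked example with made-up strings:

```python
# Illustrative values only.
memory_text = "Human: hi\nAssistant: hello"

with_placeholder = "#histories#\nAnswer concisely."
assert with_placeholder.replace("#histories#", memory_text) == (
    "Human: hi\nAssistant: hello\nAnswer concisely."
)

# Without the placeholder, memory is prepended instead:
without_placeholder = "Answer concisely."
assert memory_text + "\n" + without_placeholder == (
    "Human: hi\nAssistant: hello\nAnswer concisely."
)
```

The same hunk continues below with the helpers this function calls.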
```diff
+def handle_list_messages(
+    *,
+    messages: Sequence[LLMNodeChatModelMessage],
+    context: str | None,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    vision_detail_config: ImagePromptMessageContent.DETAIL,
+    template_renderer: TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    prompt_messages: list[PromptMessage] = []
+    for message in messages:
+        if message.edition_type == "jinja2":
+            result_text = render_jinja2_message(
+                template=message.jinja2_text or "",
+                jinja2_variables=jinja2_variables,
+                variable_pool=variable_pool,
+                template_renderer=template_renderer,
+            )
+            prompt_messages.append(
+                combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=result_text)],
+                    role=message.role,
+                )
+            )
+            continue
+
+        template = message.text.replace("{#context#}", context) if context else message.text
+        segment_group = variable_pool.convert_template(template)
+        file_contents: list[PromptMessageContentUnionTypes] = []
+        for segment in segment_group.value:
+            if isinstance(segment, ArrayFileSegment):
+                for file in segment.value:
+                    if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                        file_contents.append(
+                            file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
+                        )
+            elif isinstance(segment, FileSegment):
+                file = segment.value
+                if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
+                    file_contents.append(
+                        file_manager.to_prompt_message_content(file, image_detail_config=vision_detail_config)
+                    )
+
+        if segment_group.text:
+            prompt_messages.append(
+                combine_message_content_with_role(
+                    contents=[TextPromptMessageContent(data=segment_group.text)],
+                    role=message.role,
+                )
+            )
+        if file_contents:
+            prompt_messages.append(combine_message_content_with_role(contents=file_contents, role=message.role))
+
+    return prompt_messages
```
```diff
+def render_jinja2_message(
+    *,
+    template: str,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    template_renderer: TemplateRenderer | None = None,
+) -> str:
+    if not template:
+        return ""
+    if template_renderer is None:
+        raise ValueError("template_renderer is required for jinja2 prompt rendering")
+
+    jinja2_inputs: dict[str, Any] = {}
+    for jinja2_variable in jinja2_variables:
+        variable = variable_pool.get(jinja2_variable.value_selector)
+        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
+    return template_renderer.render_jinja2(template=template, inputs=jinja2_inputs)
```
```diff
+def handle_completion_template(
+    *,
+    template: LLMNodeCompletionModelPromptTemplate,
+    context: str | None,
+    jinja2_variables: Sequence[VariableSelector],
+    variable_pool: VariablePool,
+    template_renderer: TemplateRenderer | None = None,
+) -> Sequence[PromptMessage]:
+    if template.edition_type == "jinja2":
+        result_text = render_jinja2_message(
+            template=template.jinja2_text or "",
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            template_renderer=template_renderer,
+        )
+    else:
+        template_text = template.text.replace("{#context#}", context) if context else template.text
+        result_text = variable_pool.convert_template(template_text).text
+    return [
+        combine_message_content_with_role(
+            contents=[TextPromptMessageContent(data=result_text)],
+            role=PromptMessageRole.USER,
+        )
+    ]
```
```diff
+def combine_message_content_with_role(
+    *,
+    contents: str | list[PromptMessageContentUnionTypes] | None = None,
+    role: PromptMessageRole,
+) -> PromptMessage:
+    match role:
+        case PromptMessageRole.USER:
+            return UserPromptMessage(content=contents)
+        case PromptMessageRole.ASSISTANT:
+            return AssistantPromptMessage(content=contents)
+        case PromptMessageRole.SYSTEM:
+            return SystemPromptMessage(content=contents)
+        case _:
+            raise NotImplementedError(f"Role {role} is not supported")
```
```diff
+def calculate_rest_token(*, prompt_messages: list[PromptMessage], model_instance: ModelInstance) -> int:
+    rest_tokens = 2000
+    runtime_model_schema = fetch_model_schema(model_instance=model_instance)
+    runtime_model_parameters = model_instance.parameters
+
+    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
+    if model_context_tokens:
+        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
+
+        max_tokens = 0
+        for parameter_rule in runtime_model_schema.parameter_rules:
+            if parameter_rule.name == "max_tokens" or (
+                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
+            ):
+                max_tokens = (
+                    runtime_model_parameters.get(parameter_rule.name)
+                    or runtime_model_parameters.get(str(parameter_rule.use_template))
+                    or 0
+                )
+
+        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
+        rest_tokens = max(rest_tokens, 0)
+
+    return rest_tokens
```
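`calculate_rest_token` budgets memory against the model's context window: context size minus the tokens reserved for the completion minus the tokens the prompt already consumes, floored at zero. A worked example with invented numbers:

```python
# Invented numbers: an 8192-token context window, max_tokens=1024 reserved
# for the completion, and a prompt currently measuring 2000 tokens.
model_context_tokens = 8192
max_tokens = 1024
curr_message_tokens = 2000

rest_tokens = max(model_context_tokens - max_tokens - curr_message_tokens, 0)
assert rest_tokens == 5168  # remaining budget handed to memory retrieval
```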
```diff
+def handle_memory_chat_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: ModelInstance,
+) -> Sequence[PromptMessage]:
+    if not memory or not memory_config:
+        return []
+    rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
+    return memory.get_history_prompt_messages(
+        max_token_limit=rest_tokens,
+        message_limit=memory_config.window.size if memory_config.window.enabled else None,
+    )
+
+
+def handle_memory_completion_mode(
+    *,
+    memory: PromptMessageMemory | None,
+    memory_config: MemoryConfig | None,
+    model_instance: ModelInstance,
+) -> str:
+    if not memory or not memory_config:
+        return ""
+
+    rest_tokens = calculate_rest_token(prompt_messages=[], model_instance=model_instance)
+    if not memory_config.role_prefix:
+        raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
+
+    return fetch_memory_text(
+        memory=memory,
+        max_token_limit=rest_tokens,
+        message_limit=memory_config.window.size if memory_config.window.enabled else None,
+        human_prefix=memory_config.role_prefix.user,
+        ai_prefix=memory_config.role_prefix.assistant,
+    )
+
+
+def _append_file_prompts(
+    *,
+    prompt_messages: list[PromptMessage],
+    files: Sequence[File],
+    vision_enabled: bool,
+    vision_detail: ImagePromptMessageContent.DETAIL,
+) -> None:
+    if not vision_enabled or not files:
+        return
+
+    file_prompts = [file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in files]
+    if (
+        prompt_messages
+        and isinstance(prompt_messages[-1], UserPromptMessage)
+        and isinstance(prompt_messages[-1].content, list)
+    ):
+        existing_contents = prompt_messages[-1].content
+        assert isinstance(existing_contents, list)
+        prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
+    else:
+        prompt_messages.append(UserPromptMessage(content=file_prompts))
```
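`_append_file_prompts` either merges file contents into a trailing user message whose content is already a list, or appends a fresh user message. The rule in isolation, using plain dicts as stand-ins for the real message types:

```python
# Plain dicts stand in for PromptMessage/UserPromptMessage; the merge rule is
# the same: prepend files to a trailing list-content user message, otherwise
# append a new user message.
prompt_messages = [{"role": "user", "content": [{"type": "text", "data": "hi"}]}]
file_prompts = [{"type": "image", "data": "<image-bytes>"}]

last = prompt_messages[-1]
if last["role"] == "user" and isinstance(last["content"], list):
    prompt_messages[-1] = {"role": "user", "content": file_prompts + last["content"]}
else:
    prompt_messages.append({"role": "user", "content": file_prompts})

assert prompt_messages[-1]["content"][0]["type"] == "image"
```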
`dify_graph/nodes/llm/node.py` (imports trimmed to match the logic moved into `llm_utils`; +/- attribution reconstructed from the hunk counts):

```diff
@@ -11,7 +11,6 @@ from typing import TYPE_CHECKING, Any, Literal

 from sqlalchemy import select

-from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
 from core.model_manager import ModelInstance
@@ -28,11 +27,10 @@ from dify_graph.enums import (
     WorkflowNodeExecutionMetadataKey,
     WorkflowNodeExecutionStatus,
 )
-from dify_graph.file import File, FileTransferMethod, FileType, file_manager
+from dify_graph.file import File, FileTransferMethod, FileType
 from dify_graph.model_runtime.entities import (
     ImagePromptMessageContent,
     PromptMessage,
-    PromptMessageContentType,
     TextPromptMessageContent,
 )
 from dify_graph.model_runtime.entities.llm_entities import (
@@ -43,14 +41,7 @@ from dify_graph.model_runtime.entities.llm_entities import (
     LLMStructuredOutput,
     LLMUsage,
 )
-from dify_graph.model_runtime.entities.message_entities import (
-    AssistantPromptMessage,
-    PromptMessageContentUnionTypes,
-    PromptMessageRole,
-    SystemPromptMessage,
-    UserPromptMessage,
-)
-from dify_graph.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
+from dify_graph.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
 from dify_graph.model_runtime.memory import PromptMessageMemory
 from dify_graph.model_runtime.utils.encoders import jsonable_encoder
 from dify_graph.node_events import (
@@ -64,13 +55,12 @@ from dify_graph.node_events import (
 from dify_graph.nodes.base.entities import VariableSelector
 from dify_graph.nodes.base.node import Node
 from dify_graph.nodes.base.variable_template_parser import VariableTemplateParser
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.protocols import HttpClientProtocol
 from dify_graph.runtime import VariablePool
 from dify_graph.variables import (
-    ArrayFileSegment,
     ArraySegment,
     FileSegment,
     NoneSegment,
     ObjectSegment,
     StringSegment,
@@ -89,9 +79,6 @@ from .exc import (
     InvalidContextStructureError,
     InvalidVariableTypeError,
     LLMNodeError,
-    MemoryRolePrefixRequiredError,
-    NoPromptFoundError,
-    TemplateTypeNotSupportError,
     VariableNotFoundError,
 )
 from .file_saver import FileSaverImpl, LLMFileSaver
```
```diff
@@ -118,6 +105,7 @@ class LLMNode(Node[LLMNodeData]):
     _model_factory: ModelFactory
     _model_instance: ModelInstance
     _memory: PromptMessageMemory | None
+    _template_renderer: TemplateRenderer

     def __init__(
         self,
@@ -130,6 +118,7 @@ class LLMNode(Node[LLMNodeData]):
         model_factory: ModelFactory,
         model_instance: ModelInstance,
         http_client: HttpClientProtocol,
+        template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -146,6 +135,7 @@ class LLMNode(Node[LLMNodeData]):
         self._model_factory = model_factory
         self._model_instance = model_instance
         self._memory = memory
+        self._template_renderer = template_renderer

         if llm_file_saver is None:
             dify_ctx = self.require_dify_context()
@@ -240,6 +230,7 @@ class LLMNode(Node[LLMNodeData]):
             variable_pool=variable_pool,
             jinja2_variables=self.node_data.prompt_config.jinja2_variables,
             context_files=context_files,
+            template_renderer=self._template_renderer,
         )

         # handle invoke result
```
```diff
@@ -773,182 +764,24 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool: VariablePool,
         jinja2_variables: Sequence[VariableSelector],
         context_files: list[File] | None = None,
+        template_renderer: TemplateRenderer | None = None,
     ) -> tuple[Sequence[PromptMessage], Sequence[str] | None]:
-        prompt_messages: list[PromptMessage] = []
-        model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
-
-        if isinstance(prompt_template, list):
-            # For chat model
-            prompt_messages.extend(
-                LLMNode.handle_list_messages(
-                    messages=prompt_template,
-                    context=context,
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                    vision_detail_config=vision_detail,
-                )
-            )
-
-            # Get memory messages for chat mode
-            memory_messages = _handle_memory_chat_mode(
-                memory=memory,
-                memory_config=memory_config,
-                model_instance=model_instance,
-            )
-            # Extend prompt_messages with memory messages
-            prompt_messages.extend(memory_messages)
-
-            # Add current query to the prompt messages
-            if sys_query:
-                message = LLMNodeChatModelMessage(
-                    text=sys_query,
-                    role=PromptMessageRole.USER,
-                    edition_type="basic",
-                )
-                prompt_messages.extend(
-                    LLMNode.handle_list_messages(
-                        messages=[message],
-                        context="",
-                        jinja2_variables=[],
-                        variable_pool=variable_pool,
-                        vision_detail_config=vision_detail,
-                    )
-                )
-
-        elif isinstance(prompt_template, LLMNodeCompletionModelPromptTemplate):
-            # For completion model
-            prompt_messages.extend(
-                _handle_completion_template(
-                    template=prompt_template,
-                    context=context,
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                )
-            )
-
-            # Get memory text for completion model
-            memory_text = _handle_memory_completion_mode(
-                memory=memory,
-                memory_config=memory_config,
-                model_instance=model_instance,
-            )
-            # Insert histories into the prompt
-            prompt_content = prompt_messages[0].content
-            # For issue #11247 - Check if prompt content is a string or a list
-            if isinstance(prompt_content, str):
-                prompt_content = str(prompt_content)
-                if "#histories#" in prompt_content:
-                    prompt_content = prompt_content.replace("#histories#", memory_text)
-                else:
-                    prompt_content = memory_text + "\n" + prompt_content
-                prompt_messages[0].content = prompt_content
-            elif isinstance(prompt_content, list):
-                for content_item in prompt_content:
-                    if isinstance(content_item, TextPromptMessageContent):
-                        if "#histories#" in content_item.data:
-                            content_item.data = content_item.data.replace("#histories#", memory_text)
-                        else:
-                            content_item.data = memory_text + "\n" + content_item.data
-            else:
-                raise ValueError("Invalid prompt content type")
-
-            # Add current query to the prompt message
-            if sys_query:
-                if isinstance(prompt_content, str):
-                    prompt_content = str(prompt_messages[0].content).replace("#sys.query#", sys_query)
-                    prompt_messages[0].content = prompt_content
-                elif isinstance(prompt_content, list):
-                    for content_item in prompt_content:
-                        if isinstance(content_item, TextPromptMessageContent):
-                            content_item.data = sys_query + "\n" + content_item.data
-                else:
-                    raise ValueError("Invalid prompt content type")
-        else:
-            raise TemplateTypeNotSupportError(type_name=str(type(prompt_template)))
-
-        # The sys_files will be deprecated later
-        if vision_enabled and sys_files:
-            file_prompts = []
-            for file in sys_files:
-                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
-                file_prompts.append(file_prompt)
-            # If last prompt is a user prompt, add files into its contents,
-            # otherwise append a new user prompt
-            if (
-                len(prompt_messages) > 0
-                and isinstance(prompt_messages[-1], UserPromptMessage)
-                and isinstance(prompt_messages[-1].content, list)
-            ):
-                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
-            else:
-                prompt_messages.append(UserPromptMessage(content=file_prompts))
-
-        # The context_files
-        if vision_enabled and context_files:
-            file_prompts = []
-            for file in context_files:
-                file_prompt = file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
-                file_prompts.append(file_prompt)
-            # If last prompt is a user prompt, add files into its contents,
-            # otherwise append a new user prompt
-            if (
-                len(prompt_messages) > 0
-                and isinstance(prompt_messages[-1], UserPromptMessage)
-                and isinstance(prompt_messages[-1].content, list)
-            ):
-                prompt_messages[-1] = UserPromptMessage(content=file_prompts + prompt_messages[-1].content)
-            else:
-                prompt_messages.append(UserPromptMessage(content=file_prompts))
-
-        # Remove empty messages and filter unsupported content
-        filtered_prompt_messages = []
-        for prompt_message in prompt_messages:
-            if isinstance(prompt_message.content, list):
-                prompt_message_content: list[PromptMessageContentUnionTypes] = []
-                for content_item in prompt_message.content:
-                    # Skip content if features are not defined
-                    if not model_schema.features:
-                        if content_item.type != PromptMessageContentType.TEXT:
-                            continue
-                        prompt_message_content.append(content_item)
-                        continue
-
-                    # Skip content if corresponding feature is not supported
-                    if (
-                        (
-                            content_item.type == PromptMessageContentType.IMAGE
-                            and ModelFeature.VISION not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.DOCUMENT
-                            and ModelFeature.DOCUMENT not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.VIDEO
-                            and ModelFeature.VIDEO not in model_schema.features
-                        )
-                        or (
-                            content_item.type == PromptMessageContentType.AUDIO
-                            and ModelFeature.AUDIO not in model_schema.features
-                        )
-                    ):
-                        continue
-                    prompt_message_content.append(content_item)
-                if len(prompt_message_content) == 1 and prompt_message_content[0].type == PromptMessageContentType.TEXT:
-                    prompt_message.content = prompt_message_content[0].data
-                else:
-                    prompt_message.content = prompt_message_content
-            if prompt_message.is_empty():
-                continue
-            filtered_prompt_messages.append(prompt_message)
-
-        if len(filtered_prompt_messages) == 0:
-            raise NoPromptFoundError(
-                "No prompt found in the LLM configuration. "
-                "Please ensure a prompt is properly configured before proceeding."
-            )
-
-        return filtered_prompt_messages, stop
+        return llm_utils.fetch_prompt_messages(
+            sys_query=sys_query,
+            sys_files=sys_files,
+            context=context,
+            memory=memory,
+            model_instance=model_instance,
+            prompt_template=prompt_template,
+            stop=stop,
+            memory_config=memory_config,
+            vision_enabled=vision_enabled,
+            vision_detail=vision_detail,
+            variable_pool=variable_pool,
+            jinja2_variables=jinja2_variables,
+            context_files=context_files,
+            template_renderer=template_renderer,
+        )

     @classmethod
     def _extract_variable_selector_to_variable_mapping(
```
```diff
@@ -1048,59 +881,16 @@ class LLMNode(Node[LLMNodeData]):
         jinja2_variables: Sequence[VariableSelector],
         variable_pool: VariablePool,
         vision_detail_config: ImagePromptMessageContent.DETAIL,
+        template_renderer: TemplateRenderer | None = None,
     ) -> Sequence[PromptMessage]:
-        prompt_messages: list[PromptMessage] = []
-        for message in messages:
-            if message.edition_type == "jinja2":
-                result_text = _render_jinja2_message(
-                    template=message.jinja2_text or "",
-                    jinja2_variables=jinja2_variables,
-                    variable_pool=variable_pool,
-                )
-                prompt_message = _combine_message_content_with_role(
-                    contents=[TextPromptMessageContent(data=result_text)], role=message.role
-                )
-                prompt_messages.append(prompt_message)
-            else:
-                # Get segment group from basic message
-                if context:
-                    template = message.text.replace("{#context#}", context)
-                else:
-                    template = message.text
-                segment_group = variable_pool.convert_template(template)
-
-                # Process segments for images
-                file_contents = []
-                for segment in segment_group.value:
-                    if isinstance(segment, ArrayFileSegment):
-                        for file in segment.value:
-                            if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                                file_content = file_manager.to_prompt_message_content(
-                                    file, image_detail_config=vision_detail_config
-                                )
-                                file_contents.append(file_content)
-                    elif isinstance(segment, FileSegment):
-                        file = segment.value
-                        if file.type in {FileType.IMAGE, FileType.VIDEO, FileType.AUDIO, FileType.DOCUMENT}:
-                            file_content = file_manager.to_prompt_message_content(
-                                file, image_detail_config=vision_detail_config
-                            )
-                            file_contents.append(file_content)
-
-                # Create message with text from all segments
-                plain_text = segment_group.text
-                if plain_text:
-                    prompt_message = _combine_message_content_with_role(
-                        contents=[TextPromptMessageContent(data=plain_text)], role=message.role
-                    )
-                    prompt_messages.append(prompt_message)
-
-                if file_contents:
-                    # Create message with image contents
-                    prompt_message = _combine_message_content_with_role(contents=file_contents, role=message.role)
-                    prompt_messages.append(prompt_message)
-
-        return prompt_messages
+        return llm_utils.handle_list_messages(
+            messages=messages,
+            context=context,
+            jinja2_variables=jinja2_variables,
+            variable_pool=variable_pool,
+            vision_detail_config=vision_detail_config,
+            template_renderer=template_renderer,
+        )

     @staticmethod
     def handle_blocking_result(
```
```diff
@@ -1239,152 +1029,3 @@ class LLMNode(Node[LLMNodeData]):
     @property
     def model_instance(self) -> ModelInstance:
         return self._model_instance
-
-
-def _combine_message_content_with_role(
-    *, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole
-):
-    match role:
-        case PromptMessageRole.USER:
-            return UserPromptMessage(content=contents)
-        case PromptMessageRole.ASSISTANT:
-            return AssistantPromptMessage(content=contents)
-        case PromptMessageRole.SYSTEM:
-            return SystemPromptMessage(content=contents)
-        case _:
-            raise NotImplementedError(f"Role {role} is not supported")
-
-
-def _render_jinja2_message(
-    *,
-    template: str,
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-):
-    if not template:
-        return ""
-
-    jinja2_inputs = {}
-    for jinja2_variable in jinja2_variables:
-        variable = variable_pool.get(jinja2_variable.value_selector)
-        jinja2_inputs[jinja2_variable.variable] = variable.to_object() if variable else ""
-    code_execute_resp = CodeExecutor.execute_workflow_code_template(
-        language=CodeLanguage.JINJA2,
-        code=template,
-        inputs=jinja2_inputs,
-    )
-    result_text = code_execute_resp["result"]
-    return result_text
-
-
-def _calculate_rest_token(
-    *,
-    prompt_messages: list[PromptMessage],
-    model_instance: ModelInstance,
-) -> int:
-    rest_tokens = 2000
-    runtime_model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
-    runtime_model_parameters = model_instance.parameters
-
-    model_context_tokens = runtime_model_schema.model_properties.get(ModelPropertyKey.CONTEXT_SIZE)
-    if model_context_tokens:
-        curr_message_tokens = model_instance.get_llm_num_tokens(prompt_messages)
-
-        max_tokens = 0
-        for parameter_rule in runtime_model_schema.parameter_rules:
-            if parameter_rule.name == "max_tokens" or (
-                parameter_rule.use_template and parameter_rule.use_template == "max_tokens"
-            ):
-                max_tokens = (
-                    runtime_model_parameters.get(parameter_rule.name)
-                    or runtime_model_parameters.get(str(parameter_rule.use_template))
-                    or 0
-                )
-
-        rest_tokens = model_context_tokens - max_tokens - curr_message_tokens
-        rest_tokens = max(rest_tokens, 0)
-
-    return rest_tokens
-
-
-def _handle_memory_chat_mode(
-    *,
-    memory: PromptMessageMemory | None,
-    memory_config: MemoryConfig | None,
-    model_instance: ModelInstance,
-) -> Sequence[PromptMessage]:
-    memory_messages: Sequence[PromptMessage] = []
-    # Get messages from memory for chat model
-    if memory and memory_config:
-        rest_tokens = _calculate_rest_token(
-            prompt_messages=[],
-            model_instance=model_instance,
-        )
-        memory_messages = memory.get_history_prompt_messages(
-            max_token_limit=rest_tokens,
-            message_limit=memory_config.window.size if memory_config.window.enabled else None,
-        )
-    return memory_messages
-
-
-def _handle_memory_completion_mode(
-    *,
-    memory: PromptMessageMemory | None,
-    memory_config: MemoryConfig | None,
-    model_instance: ModelInstance,
-) -> str:
-    memory_text = ""
-    # Get history text from memory for completion model
-    if memory and memory_config:
-        rest_tokens = _calculate_rest_token(
-            prompt_messages=[],
-            model_instance=model_instance,
-        )
-        if not memory_config.role_prefix:
-            raise MemoryRolePrefixRequiredError("Memory role prefix is required for completion model.")
-        memory_text = llm_utils.fetch_memory_text(
-            memory=memory,
-            max_token_limit=rest_tokens,
-            message_limit=memory_config.window.size if memory_config.window.enabled else None,
-            human_prefix=memory_config.role_prefix.user,
-            ai_prefix=memory_config.role_prefix.assistant,
-        )
-    return memory_text
-
-
-def _handle_completion_template(
-    *,
-    template: LLMNodeCompletionModelPromptTemplate,
-    context: str | None,
-    jinja2_variables: Sequence[VariableSelector],
-    variable_pool: VariablePool,
-) -> Sequence[PromptMessage]:
-    """Handle completion template processing outside of LLMNode class.
-
-    Args:
-        template: The completion model prompt template
-        context: Optional context string
-        jinja2_variables: Variables for jinja2 template rendering
-        variable_pool: Variable pool for template conversion
-
-    Returns:
-        Sequence of prompt messages
-    """
-    prompt_messages = []
-    if template.edition_type == "jinja2":
-        result_text = _render_jinja2_message(
-            template=template.jinja2_text or "",
-            jinja2_variables=jinja2_variables,
-            variable_pool=variable_pool,
-        )
-    else:
-        if context:
-            template_text = template.text.replace("{#context#}", context)
-        else:
-            template_text = template.text
-        result_text = variable_pool.convert_template(template_text).text
-    prompt_message = _combine_message_content_with_role(
-        contents=[TextPromptMessageContent(data=result_text)], role=PromptMessageRole.USER
-    )
-    prompt_messages.append(prompt_message)
-    return prompt_messages
```
`dify_graph/nodes/llm/protocols.py` (adds the `TemplateRenderer` port):

```diff
@@ -1,5 +1,6 @@
 from __future__ import annotations

+from collections.abc import Mapping
 from typing import Any, Protocol

 from core.model_manager import ModelInstance
@@ -19,3 +20,11 @@ class ModelFactory(Protocol):
     def init_model_instance(self, provider_name: str, model_name: str) -> ModelInstance:
         """Create a model instance that is ready for schema lookup and invocation."""
         ...
+
+
+class TemplateRenderer(Protocol):
+    """Port for rendering prompt templates used by LLM-compatible nodes."""
+
+    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
+        """Render the given Jinja2 template into plain text."""
+        ...
```
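Because `TemplateRenderer` is a `Protocol`, any object with a matching `render_jinja2` method satisfies it structurally; no subclassing is required. A hypothetical in-process test double, assuming the plain `jinja2` package is available (unsuitable for untrusted templates, which is exactly what the sandboxed production renderer exists for):

```python
from collections.abc import Mapping
from typing import Any

from jinja2 import Environment  # assumed available in a test environment


class InProcessTemplateRenderer:
    """Hypothetical test double: renders in-process with no sandbox,
    so it is only appropriate for trusted templates in unit tests."""

    def render_jinja2(self, *, template: str, inputs: Mapping[str, Any]) -> str:
        return Environment(autoescape=False).from_string(template).render(**inputs)


# Structural typing: this instance satisfies TemplateRenderer without inheritance.
assert InProcessTemplateRenderer().render_jinja2(
    template="Hi {{ name }}", inputs={"name": "Dify"}
) == "Hi Dify"
```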
Question classifier node (same injection pattern; also switches from `LLMNode.fetch_prompt_messages` to the shared `llm_utils.fetch_prompt_messages`):

```diff
@@ -28,7 +28,7 @@ from dify_graph.nodes.llm import (
     llm_utils,
 )
 from dify_graph.nodes.llm.file_saver import FileSaverImpl, LLMFileSaver
-from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
+from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
 from dify_graph.nodes.protocols import HttpClientProtocol
 from libs.json_in_md_parser import parse_and_check_json_markdown

@@ -59,6 +59,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
     _model_factory: "ModelFactory"
     _model_instance: ModelInstance
     _memory: PromptMessageMemory | None
+    _template_renderer: TemplateRenderer

     def __init__(
         self,
@@ -71,6 +72,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         model_factory: "ModelFactory",
         model_instance: ModelInstance,
         http_client: HttpClientProtocol,
+        template_renderer: TemplateRenderer,
         memory: PromptMessageMemory | None = None,
         llm_file_saver: LLMFileSaver | None = None,
     ):
@@ -87,6 +89,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self._model_factory = model_factory
         self._model_instance = model_instance
         self._memory = memory
+        self._template_renderer = template_renderer

         if llm_file_saver is None:
             dify_ctx = self.require_dify_context()
@@ -142,7 +145,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         # If both self._get_prompt_template and self._fetch_prompt_messages append a user prompt,
         # two consecutive user prompts will be generated, causing model's error.
         # To avoid this, set sys_query to an empty string so that only one user prompt is appended at the end.
-        prompt_messages, stop = LLMNode.fetch_prompt_messages(
+        prompt_messages, stop = llm_utils.fetch_prompt_messages(
             prompt_template=prompt_template,
             sys_query="",
             memory=memory,
@@ -153,6 +156,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             vision_detail=node_data.vision.configs.detail,
             variable_pool=variable_pool,
             jinja2_variables=[],
+            template_renderer=self._template_renderer,
         )

         result_text = ""
@@ -287,7 +291,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)

         prompt_template = self._get_prompt_template(node_data, query, None, 2000)
-        prompt_messages, _ = LLMNode.fetch_prompt_messages(
+        prompt_messages, _ = llm_utils.fetch_prompt_messages(
             prompt_template=prompt_template,
             sys_query="",
             sys_files=[],
@@ -300,6 +304,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             vision_detail=node_data.vision.configs.detail,
             variable_pool=self.graph_runtime_state.variable_pool,
             jinja2_variables=[],
+            template_renderer=self._template_renderer,
         )
         rest_tokens = 2000
```
New Alembic migration (new file, all lines added):

```diff
@@ -0,0 +1,68 @@
+"""add indexes for human_input_forms query patterns
+
+Revision ID: 0ec65df55790
+Revises: e288952f2994
+Create Date: 2026-03-02 18:05:00.000000
+
+"""
+
+from alembic import op
+
+# revision identifiers, used by Alembic.
+revision = "0ec65df55790"
+down_revision = "e288952f2994"
+branch_labels = None
+depends_on = None
+
+
+def upgrade():
+    with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
+        batch_op.create_index(
+            "human_input_forms_workflow_run_id_node_id_idx",
+            ["workflow_run_id", "node_id"],
+            unique=False,
+        )
+        batch_op.create_index(
+            "human_input_forms_status_created_at_idx",
+            ["status", "created_at"],
+            unique=False,
+        )
+        batch_op.create_index(
+            "human_input_forms_status_expiration_time_idx",
+            ["status", "expiration_time"],
+            unique=False,
+        )
+
+    with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
+        batch_op.create_index(
+            batch_op.f("human_input_form_deliveries_form_id_idx"),
+            ["form_id"],
+            unique=False,
+        )
+
+    with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
+        batch_op.create_index(
+            batch_op.f("human_input_form_recipients_delivery_id_idx"),
+            ["delivery_id"],
+            unique=False,
+        )
+        batch_op.create_index(
+            batch_op.f("human_input_form_recipients_form_id_idx"),
+            ["form_id"],
+            unique=False,
+        )
+
+
+def downgrade():
+    with op.batch_alter_table("human_input_forms", schema=None) as batch_op:
+        batch_op.drop_index("human_input_forms_workflow_run_id_node_id_idx")
+        batch_op.drop_index("human_input_forms_status_expiration_time_idx")
+        batch_op.drop_index("human_input_forms_status_created_at_idx")
+
+    with op.batch_alter_table("human_input_form_recipients", schema=None) as batch_op:
+        batch_op.drop_index(batch_op.f("human_input_form_recipients_form_id_idx"))
+        batch_op.drop_index(batch_op.f("human_input_form_recipients_delivery_id_idx"))
+
+    with op.batch_alter_table("human_input_form_deliveries", schema=None) as batch_op:
+        batch_op.drop_index(batch_op.f("human_input_form_deliveries_form_id_idx"))
```
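These composite indexes line up with the query patterns named in the migration message. Hypothetical query shapes they would serve, assuming the model import path below (illustrative only):

```python
from datetime import datetime, timezone

from sqlalchemy import select

from models.human_input import HumanInputForm  # hypothetical import path

run_id, node_id = "run-123", "node-1"  # placeholder values
now = datetime.now(timezone.utc)

# Served by human_input_forms_workflow_run_id_node_id_idx:
by_run_and_node = select(HumanInputForm).where(
    HumanInputForm.workflow_run_id == run_id,
    HumanInputForm.node_id == node_id,
)

# Served by human_input_forms_status_expiration_time_idx:
expiring = select(HumanInputForm).where(
    HumanInputForm.status == "pending",
    HumanInputForm.expiration_time < now,
)
```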
@@ -30,6 +30,15 @@ def _generate_token() -> str:

class HumanInputForm(DefaultFieldsMixin, Base):
    __tablename__ = "human_input_forms"
    __table_args__ = (
        sa.Index(
            "human_input_forms_workflow_run_id_node_id_idx",
            "workflow_run_id",
            "node_id",
        ),
        sa.Index("human_input_forms_status_expiration_time_idx", "status", "expiration_time"),
        sa.Index("human_input_forms_status_created_at_idx", "status", "created_at"),
    )

    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
@@ -84,6 +93,12 @@ class HumanInputForm(DefaultFieldsMixin, Base):

class HumanInputDelivery(DefaultFieldsMixin, Base):
    __tablename__ = "human_input_form_deliveries"
    __table_args__ = (
        sa.Index(
            None,
            "form_id",
        ),
    )

    form_id: Mapped[str] = mapped_column(
        StringUUID,
@@ -181,6 +196,10 @@ RecipientPayload = Annotated[

class HumanInputFormRecipient(DefaultFieldsMixin, Base):
    __tablename__ = "human_input_form_recipients"
    __table_args__ = (
        sa.Index(None, "form_id"),
        sa.Index(None, "delivery_id"),
    )

    form_id: Mapped[str] = mapped_column(
        StringUUID,
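Note the split naming styles above: HumanInputForm names its indexes explicitly, while the other two tables pass None and let SQLAlchemy's metadata naming convention derive the name, which is what lets the migration refer to them via batch_op.f(...). A minimal standalone sketch of that mechanism; the %(table_name)s_%(column_0_name)s_idx convention here is an assumption chosen to match the names above, not the convention actually defined on Dify's model base:

from sqlalchemy import Column, Index, Integer, MetaData, String, Table
from sqlalchemy.schema import CreateIndex

# Assumed convention; Dify's real one lives on its declarative base metadata.
metadata = MetaData(naming_convention={"ix": "%(table_name)s_%(column_0_name)s_idx"})

deliveries = Table(
    "human_input_form_deliveries",
    metadata,
    Column("id", Integer, primary_key=True),
    Column("form_id", String(36)),
    Index(None, "form_id"),  # unnamed: the convention fills in the name
)

# Emits: CREATE INDEX human_input_form_deliveries_form_id_idx ON ...
print(CreateIndex(next(iter(deliveries.indexes))))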
@@ -6,12 +6,12 @@ requires-python = ">=3.11,<3.13"
dependencies = [
    "aliyun-log-python-sdk~=0.9.37",
    "arize-phoenix-otel~=0.15.0",
    "azure-identity==1.25.2",
    "beautifulsoup4==4.12.2",
    "boto3==1.42.65",
    "azure-identity==1.25.3",
    "beautifulsoup4==4.14.3",
    "boto3==1.42.68",
    "bs4~=0.0.1",
    "cachetools~=5.3.0",
    "celery~=5.5.2",
    "celery~=5.6.2",
    "charset-normalizer>=3.4.4",
    "flask~=3.1.2",
    "flask-compress>=1.17,<1.24",
@@ -35,12 +35,12 @@ dependencies = [
    "jsonschema>=4.25.1",
    "langfuse~=2.51.3",
    "langsmith~=0.7.16",
    "markdown~=3.8.1",
    "markdown~=3.10.2",
    "mlflow-skinny>=3.0.0",
    "numpy~=1.26.4",
    "openpyxl~=3.1.5",
    "opik~=1.10.37",
    "litellm==1.82.1",  # Pinned to avoid madoka dependency issue
    "litellm==1.82.2",  # Pinned to avoid madoka dependency issue
    "opentelemetry-api==1.28.0",
    "opentelemetry-distro==0.49b0",
    "opentelemetry-exporter-otlp==1.28.0",
@@ -58,7 +58,7 @@ dependencies = [
    "opentelemetry-sdk==1.28.0",
    "opentelemetry-semantic-conventions==0.49b0",
    "opentelemetry-util-http==0.49b0",
    "pandas[excel,output-formatting,performance]~=2.2.2",
    "pandas[excel,output-formatting,performance]~=3.0.1",
    "psycogreen~=1.0.2",
    "psycopg2-binary~=2.9.6",
    "pycryptodome==3.23.0",
@@ -66,22 +66,22 @@ dependencies = [
    "pydantic-extra-types~=2.11.0",
    "pydantic-settings~=2.13.1",
    "pyjwt~=2.12.0",
    "pypdfium2==5.2.0",
    "pypdfium2==5.6.0",
    "python-docx~=1.2.0",
    "python-dotenv==1.0.1",
    "python-dotenv==1.2.2",
    "pyyaml~=6.0.1",
    "readabilipy~=0.3.0",
    "redis[hiredis]~=7.3.0",
    "resend~=2.9.0",
    "sentry-sdk[flask]~=2.28.0",
    "resend~=2.23.0",
    "sentry-sdk[flask]~=2.54.0",
    "sqlalchemy~=2.0.29",
    "starlette==0.49.1",
    "starlette==0.52.1",
    "tiktoken~=0.12.0",
    "transformers~=5.3.0",
    "unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
    "yarl~=1.18.3",
    "unstructured[docx,epub,md,ppt,pptx]~=0.21.5",
    "yarl~=1.23.0",
    "webvtt-py~=0.5.1",
    "sseclient-py~=1.8.0",
    "sseclient-py~=1.9.0",
    "httpx-sse~=0.4.0",
    "sendgrid~=6.12.3",
    "flask-restx~=1.3.2",
@@ -111,7 +111,7 @@ package = false
dev = [
    "coverage~=7.13.4",
    "dotenv-linter~=0.7.0",
    "faker~=40.8.0",
    "faker~=40.11.0",
    "lxml-stubs~=0.5.1",
    "basedpyright~=1.38.2",
    "ruff~=0.15.5",
@@ -120,7 +120,7 @@ dev = [
    "pytest-cov~=7.0.0",
    "pytest-env~=1.1.3",
    "pytest-mock~=3.15.1",
    "testcontainers~=4.13.2",
    "testcontainers~=4.14.1",
    "types-aiofiles~=25.1.0",
    "types-beautifulsoup4~=4.12.0",
    "types-cachetools~=6.2.0",
@@ -6,6 +6,13 @@ from typing import Any
import httpx

from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
    EnterpriseAPIBadRequestError,
    EnterpriseAPIError,
    EnterpriseAPIForbiddenError,
    EnterpriseAPINotFoundError,
    EnterpriseAPIUnauthorizedError,
)

logger = logging.getLogger(__name__)

@@ -64,10 +71,51 @@ class BaseRequest:
            request_kwargs["timeout"] = timeout

        response = client.request(method, url, **request_kwargs)
        if raise_for_status:
            response.raise_for_status()

        # Validate HTTP status and raise domain-specific errors
        if not response.is_success:
            cls._handle_error_response(response)
        return response.json()

    @classmethod
    def _handle_error_response(cls, response: httpx.Response) -> None:
        """
        Handle non-2xx HTTP responses by raising appropriate domain errors.

        Attempts to extract error message from JSON response body,
        falls back to status text if parsing fails.
        """
        error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"

        # Try to extract error message from JSON response
        try:
            error_data = response.json()
            if isinstance(error_data, dict):
                # Common error response formats:
                # {"error": "...", "message": "..."}
                # {"message": "..."}
                # {"detail": "..."}
                error_message = (
                    error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
                )
        except Exception:
            # If JSON parsing fails, use the default message
            logger.debug(
                "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
            )

        # Raise specific error based on status code
        if response.status_code == 400:
            raise EnterpriseAPIBadRequestError(error_message)
        elif response.status_code == 401:
            raise EnterpriseAPIUnauthorizedError(error_message)
        elif response.status_code == 403:
            raise EnterpriseAPIForbiddenError(error_message)
        elif response.status_code == 404:
            raise EnterpriseAPINotFoundError(error_message)
        else:
            raise EnterpriseAPIError(error_message, status_code=response.status_code)


class EnterpriseRequest(BaseRequest):
    base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
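One consequence of this change: callers no longer pass raise_for_status=True (later hunks in this diff delete that argument at every call site) and instead catch typed errors. A hedged sketch of a call site under the new scheme; the GET path and the handling policy are illustrative, while send_request and the error classes are the ones shown in this changeset:

from services.enterprise.base import EnterpriseRequest
from services.errors.enterprise import (
    EnterpriseAPINotFoundError,
    EnterpriseServiceError,
)


def fetch_webapp_setting(app_id: str) -> dict | None:
    """Illustrative caller: treat 404 as absence, surface everything else."""
    try:
        # Path is hypothetical; send_request is the classmethod patched above.
        return EnterpriseRequest.send_request("GET", f"/webapp/{app_id}")
    except EnterpriseAPINotFoundError:
        return None  # a missing resource is an expected outcome here
    except EnterpriseServiceError:
        # Every enterprise API error (including the generic EnterpriseAPIError)
        # subclasses EnterpriseServiceError, so one clause is the catch-all.
        raise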
@@ -1,15 +1,26 @@
from __future__ import annotations

import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING

from pydantic import BaseModel, ConfigDict, Field, model_validator

from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest

if TYPE_CHECKING:
    from services.feature_service import LicenseStatus

logger = logging.getLogger(__name__)

DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
VALID_LICENSE_CACHE_TTL = 600  # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30  # 30 seconds — short so admin fixes are picked up quickly


class WebAppSettings(BaseModel):
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
    model_config = ConfigDict(extra="forbid", populate_by_name=True)

    @model_validator(mode="after")
    def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
    def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
        if self.joined and not self.workspace_id:
            raise ValueError("workspace_id must be non-empty when joined is True")
        return self
@@ -115,7 +126,6 @@ class EnterpriseService:
            "/default-workspace/members",
            json={"account_id": account_id},
            timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
            raise_for_status=True,
        )
        if not isinstance(data, dict):
            raise ValueError("Invalid response format from enterprise default workspace API")
@@ -223,3 +233,64 @@ class EnterpriseService:

        params = {"appId": app_id}
        EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)

    @classmethod
    def get_cached_license_status(cls) -> LicenseStatus | None:
        """Get enterprise license status with Redis caching to reduce HTTP calls.

        Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
        (inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
        balances prompt license-fix detection against DoS mitigation — without
        caching, every request on an expired license would hit the enterprise API.

        Returns:
            LicenseStatus enum value, or None if enterprise is disabled / unreachable.
        """
        if not dify_config.ENTERPRISE_ENABLED:
            return None

        cached = cls._read_cached_license_status()
        if cached is not None:
            return cached

        return cls._fetch_and_cache_license_status()

    @classmethod
    def _read_cached_license_status(cls) -> LicenseStatus | None:
        """Read license status from Redis cache, returning None on miss or failure."""
        from services.feature_service import LicenseStatus

        try:
            raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
            if raw:
                value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
                return LicenseStatus(value)
        except Exception:
            logger.debug("Failed to read license status from cache", exc_info=True)
        return None

    @classmethod
    def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
        """Fetch license status from enterprise API and cache the result."""
        from services.feature_service import LicenseStatus

        try:
            info = cls.get_info()
            license_info = info.get("License")
            if not license_info:
                return None

            status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
            ttl = (
                VALID_LICENSE_CACHE_TTL
                if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
                else INVALID_LICENSE_CACHE_TTL
            )
            try:
                redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
            except Exception:
                logger.debug("Failed to cache license status", exc_info=True)
            return status
        except Exception:
            logger.debug("Failed to fetch enterprise license status", exc_info=True)
            return None
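The helper split keeps the failure modes independent: a Redis outage degrades to one API call per request, and an API outage degrades to None rather than an exception. A sketch of how a gate might consume the cached status; the function and its fail-open policy are hypothetical, while get_cached_license_status and the LicenseStatus members come from this diff:

from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus


def license_allows_requests() -> bool:
    """Hypothetical gate built on the cached status."""
    status = EnterpriseService.get_cached_license_status()
    if status is None:
        # Enterprise disabled, or status unknown: fail open by policy here.
        return True
    return status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)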
@@ -70,7 +70,6 @@ class PluginManagerService:
                "POST",
                "/pre-uninstall-plugin",
                json=body.model_dump(),
                raise_for_status=True,
                timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
            )
        except Exception:
@@ -7,6 +7,7 @@ from . import (
    conversation,
    dataset,
    document,
    enterprise,
    file,
    index,
    message,
@@ -21,6 +22,7 @@ __all__ = [
    "conversation",
    "dataset",
    "document",
    "enterprise",
    "file",
    "index",
    "message",
45
api/services/errors/enterprise.py
Normal file
@@ -0,0 +1,45 @@
"""Enterprise service errors."""

from services.errors.base import BaseServiceError


class EnterpriseServiceError(BaseServiceError):
    """Base exception for enterprise service errors."""

    def __init__(self, description: str | None = None, status_code: int | None = None):
        super().__init__(description)
        self.status_code = status_code


class EnterpriseAPIError(EnterpriseServiceError):
    """Generic enterprise API error (non-2xx response)."""

    pass


class EnterpriseAPINotFoundError(EnterpriseServiceError):
    """Enterprise API returned 404 Not Found."""

    def __init__(self, description: str | None = None):
        super().__init__(description, status_code=404)


class EnterpriseAPIForbiddenError(EnterpriseServiceError):
    """Enterprise API returned 403 Forbidden."""

    def __init__(self, description: str | None = None):
        super().__init__(description, status_code=403)


class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
    """Enterprise API returned 401 Unauthorized."""

    def __init__(self, description: str | None = None):
        super().__init__(description, status_code=401)


class EnterpriseAPIBadRequestError(EnterpriseServiceError):
    """Enterprise API returned 400 Bad Request."""

    def __init__(self, description: str | None = None):
        super().__init__(description, status_code=400)
@@ -379,14 +379,19 @@ class FeatureService:
        )
        features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")

        if is_authenticated and (license_info := enterprise_info.get("License")):
        # SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
        # so the login page can detect an expired/inactive license after force-logout.
        # All other license details (expiry date, workspace usage) remain auth-gated.
        # This behavior reflects prior internal review of information-leakage risks.
        if license_info := enterprise_info.get("License"):
            features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
            features.license.expired_at = license_info.get("expiredAt", "")

            if workspaces_info := license_info.get("workspaces"):
                features.license.workspaces.enabled = workspaces_info.get("enabled", False)
                features.license.workspaces.limit = workspaces_info.get("limit", 0)
                features.license.workspaces.size = workspaces_info.get("used", 0)
            if is_authenticated:
                features.license.expired_at = license_info.get("expiredAt", "")
                if workspaces_info := license_info.get("workspaces"):
                    features.license.workspaces.enabled = workspaces_info.get("enabled", False)
                    features.license.workspaces.limit = workspaces_info.get("limit", 0)
                    features.license.workspaces.size = workspaces_info.get("used", 0)

        if "PluginInstallationPermission" in enterprise_info:
            plugin_installation_info = enterprise_info["PluginInstallationPermission"]
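The net effect, which the updated TestFeatureService case later in this diff asserts: unauthenticated callers now see the real license status but default-valued details. A schematic comparison; the values are invented for illustration, and only the field names and defaults come from the code and tests in this changeset:

# Schematic only: values are illustrative, shapes follow the license block.
unauthenticated_license_view = {
    "status": "expired",  # exposed so the login page can react after force-logout
    "expired_at": "",     # auth-gated detail stays at its default
    "workspaces": {"enabled": False, "limit": 0, "size": 0},
}
authenticated_license_view = {
    "status": "expired",
    "expired_at": "2026-01-31",  # invented example value
    "workspaces": {"enabled": True, "limit": 5, "size": 3},  # invented values
}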
@@ -10,7 +10,7 @@ from core.model_manager import ModelInstance
from dify_graph.enums import WorkflowNodeExecutionStatus
from dify_graph.node_events import StreamCompletedEvent
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
@@ -75,6 +75,7 @@ def init_llm_node(config: dict) -> LLMNode:
        credentials_provider=MagicMock(spec=CredentialsProvider),
        model_factory=MagicMock(spec=ModelFactory),
        model_instance=MagicMock(spec=ModelInstance),
        template_renderer=MagicMock(spec=TemplateRenderer),
        http_client=MagicMock(spec=HttpClientProtocol),
    )

@@ -158,7 +159,7 @@ def test_execute_llm():
        return mock_model_instance

    # Mock fetch_prompt_messages to avoid database calls
    def mock_fetch_prompt_messages_1(**_kwargs):
    def mock_fetch_prompt_messages_1(*_args, **_kwargs):
        from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage

        return [
@@ -358,10 +358,9 @@ class TestFeatureService:
        assert result is not None
        assert isinstance(result, SystemFeatureModel)

        # --- 1. Verify Response Payload Optimization (Data Minimization) ---
        # Ensure only essential UI flags are returned to unauthenticated clients
        # to keep the payload lightweight and adhere to architectural boundaries.
        assert result.license.status == LicenseStatus.NONE
        # --- 1. Verify only license *status* is exposed to unauthenticated clients ---
        # Detailed license info (expiry, workspaces) remains auth-gated.
        assert result.license.status == LicenseStatus.ACTIVE
        assert result.license.expired_at == ""
        assert result.license.workspaces.enabled is False
        assert result.license.workspaces.limit == 0
@@ -20,7 +20,7 @@ from dify_graph.nodes.code import CodeNode
from dify_graph.nodes.document_extractor import DocumentExtractorNode
from dify_graph.nodes.http_request import HttpRequestNode
from dify_graph.nodes.llm import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.parameter_extractor import ParameterExtractorNode
from dify_graph.nodes.protocols import HttpClientProtocol, ToolFileManagerProtocol
from dify_graph.nodes.question_classifier import QuestionClassifierNode
@@ -68,6 +68,8 @@ class MockNodeMixin:
        kwargs.setdefault("model_instance", MagicMock(spec=ModelInstance))
        # LLM-like nodes now require an http_client; provide a mock by default for tests.
        kwargs.setdefault("http_client", MagicMock(spec=HttpClientProtocol))
        if isinstance(self, (LLMNode, QuestionClassifierNode)):
            kwargs.setdefault("template_renderer", MagicMock(spec=TemplateRenderer))

        # Ensure TemplateTransformNode receives a renderer now required by constructor
        if isinstance(self, TemplateTransformNode):
@@ -34,8 +34,8 @@ from dify_graph.nodes.llm.entities import (
    VisionConfigOptions,
)
from dify_graph.nodes.llm.file_saver import LLMFileSaver
from dify_graph.nodes.llm.node import LLMNode, _handle_memory_completion_mode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory
from dify_graph.nodes.llm.node import LLMNode
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.runtime import GraphRuntimeState, VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables import ArrayAnySegment, ArrayFileSegment, NoneSegment
@@ -107,6 +107,7 @@ def llm_node(
    mock_file_saver = mock.MagicMock(spec=LLMFileSaver)
    mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
    mock_model_factory = mock.MagicMock(spec=ModelFactory)
    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
    node_config = {
        "id": "1",
        "data": llm_node_data.model_dump(),
@@ -121,6 +122,7 @@ def llm_node(
        model_factory=mock_model_factory,
        model_instance=mock.MagicMock(spec=ModelInstance),
        llm_file_saver=mock_file_saver,
        template_renderer=mock_template_renderer,
        http_client=http_client,
    )
    return node
@@ -590,6 +592,33 @@ def test_handle_list_messages_basic(llm_node):
    assert result[0].content == [TextPromptMessageContent(data="Hello, world")]


def test_handle_list_messages_jinja2_uses_template_renderer(llm_node):
    llm_node._template_renderer.render_jinja2.return_value = "Hello, world"
    messages = [
        LLMNodeChatModelMessage(
            text="",
            jinja2_text="Hello, {{ name }}",
            role=PromptMessageRole.USER,
            edition_type="jinja2",
        )
    ]

    result = llm_node.handle_list_messages(
        messages=messages,
        context=None,
        jinja2_variables=[],
        variable_pool=llm_node.graph_runtime_state.variable_pool,
        vision_detail_config=ImagePromptMessageContent.DETAIL.HIGH,
        template_renderer=llm_node._template_renderer,
    )

    assert result == [UserPromptMessage(content=[TextPromptMessageContent(data="Hello, world")])]
    llm_node._template_renderer.render_jinja2.assert_called_once_with(
        template="Hello, {{ name }}",
        inputs={},
    )


def test_handle_memory_completion_mode_uses_prompt_message_interface():
    memory = mock.MagicMock(spec=MockTokenBufferMemory)
    memory.get_history_prompt_messages.return_value = [
@@ -613,8 +642,8 @@ def test_handle_memory_completion_mode_uses_prompt_message_interface():
        window=MemoryConfig.WindowConfig(enabled=True, size=3),
    )

    with mock.patch("dify_graph.nodes.llm.node._calculate_rest_token", return_value=2000) as mock_rest_token:
        memory_text = _handle_memory_completion_mode(
    with mock.patch("dify_graph.nodes.llm.llm_utils.calculate_rest_token", return_value=2000) as mock_rest_token:
        memory_text = llm_utils.handle_memory_completion_mode(
            memory=memory,
            memory_config=memory_config,
            model_instance=model_instance,
@@ -630,6 +659,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
    mock_file_saver: LLMFileSaver = mock.MagicMock(spec=LLMFileSaver)
    mock_credentials_provider = mock.MagicMock(spec=CredentialsProvider)
    mock_model_factory = mock.MagicMock(spec=ModelFactory)
    mock_template_renderer = mock.MagicMock(spec=TemplateRenderer)
    node_config = {
        "id": "1",
        "data": llm_node_data.model_dump(),
@@ -644,6 +674,7 @@ def llm_node_for_multimodal(llm_node_data, graph_init_params, graph_runtime_stat
        model_factory=mock_model_factory,
        model_instance=mock.MagicMock(spec=ModelInstance),
        llm_file_saver=mock_file_saver,
        template_renderer=mock_template_renderer,
        http_client=http_client,
    )
    return node, mock_file_saver
@@ -1,5 +1,14 @@
from types import SimpleNamespace
from unittest.mock import MagicMock

from dify_graph.model_runtime.entities import ImagePromptMessageContent
from dify_graph.nodes.question_classifier import QuestionClassifierNodeData
from dify_graph.nodes.llm.protocols import CredentialsProvider, ModelFactory, TemplateRenderer
from dify_graph.nodes.protocols import HttpClientProtocol
from dify_graph.nodes.question_classifier import (
    QuestionClassifierNode,
    QuestionClassifierNodeData,
)
from tests.workflow_test_utils import build_test_graph_init_params


def test_init_question_classifier_node_data():
@@ -65,3 +74,52 @@ def test_init_question_classifier_node_data_without_vision_config():
    assert node_data.vision.enabled == False
    assert node_data.vision.configs.variable_selector == ["sys", "files"]
    assert node_data.vision.configs.detail == ImagePromptMessageContent.DETAIL.HIGH


def test_question_classifier_calculate_rest_token_uses_shared_prompt_builder(monkeypatch):
    node_data = QuestionClassifierNodeData.model_validate(
        {
            "title": "test classifier node",
            "query_variable_selector": ["id", "name"],
            "model": {"provider": "openai", "name": "gpt-3.5-turbo", "mode": "completion", "completion_params": {}},
            "classes": [{"id": "1", "name": "class 1"}],
            "instruction": "This is a test instruction",
        }
    )
    template_renderer = MagicMock(spec=TemplateRenderer)
    node = QuestionClassifierNode(
        id="node-id",
        config={"id": "node-id", "data": node_data.model_dump(mode="json")},
        graph_init_params=build_test_graph_init_params(
            workflow_id="workflow-id",
            graph_config={},
            tenant_id="tenant-id",
            app_id="app-id",
            user_id="user-id",
        ),
        graph_runtime_state=SimpleNamespace(variable_pool=MagicMock()),
        credentials_provider=MagicMock(spec=CredentialsProvider),
        model_factory=MagicMock(spec=ModelFactory),
        model_instance=MagicMock(),
        http_client=MagicMock(spec=HttpClientProtocol),
        llm_file_saver=MagicMock(),
        template_renderer=template_renderer,
    )
    fetch_prompt_messages = MagicMock(return_value=([], None))
    monkeypatch.setattr(
        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_prompt_messages",
        fetch_prompt_messages,
    )
    monkeypatch.setattr(
        "dify_graph.nodes.question_classifier.question_classifier_node.llm_utils.fetch_model_schema",
        MagicMock(return_value=SimpleNamespace(model_properties={}, parameter_rules=[])),
    )

    node._calculate_rest_token(
        node_data=node_data,
        query="hello",
        model_instance=MagicMock(stop=(), parameters={}),
        context="",
    )

    assert fetch_prompt_messages.call_args.kwargs["template_renderer"] is template_renderer
@@ -140,6 +140,29 @@ class TestDefaultWorkflowCodeExecutor:
        assert executor.is_execution_error(RuntimeError("boom")) is False


class TestDefaultLLMTemplateRenderer:
    def test_render_jinja2_delegates_to_code_executor(self, monkeypatch):
        renderer = node_factory.DefaultLLMTemplateRenderer()
        execute_workflow_code_template = MagicMock(return_value={"result": "hello world"})
        monkeypatch.setattr(
            node_factory.CodeExecutor,
            "execute_workflow_code_template",
            execute_workflow_code_template,
        )

        result = renderer.render_jinja2(
            template="Hello {{ name }}",
            inputs={"name": "world"},
        )

        assert result == "hello world"
        execute_workflow_code_template.assert_called_once_with(
            language=CodeLanguage.JINJA2,
            code="Hello {{ name }}",
            inputs={"name": "world"},
        )

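The assertions above pin down the contract of node_factory.DefaultLLMTemplateRenderer.render_jinja2: delegate to CodeExecutor.execute_workflow_code_template with CodeLanguage.JINJA2 and unwrap the "result" key. A sketch of an implementation consistent with the test; the real class in node_factory is not shown in this diff and may differ in details:

# Reconstructed from the test's assertions; CodeExecutor and CodeLanguage are
# the same names the test module imports. Hedged sketch, not the actual source.
class DefaultLLMTemplateRenderer:
    def render_jinja2(self, *, template: str, inputs: dict) -> str:
        result = CodeExecutor.execute_workflow_code_template(
            language=CodeLanguage.JINJA2,
            code=template,
            inputs=inputs,
        )
        return str(result["result"])
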
class TestDifyNodeFactoryInit:
    def test_init_builds_default_dependencies(self):
        graph_init_params = SimpleNamespace(run_context={"context": "value"})
@@ -150,6 +173,7 @@ class TestDifyNodeFactoryInit:
        http_request_config = sentinel.http_request_config
        credentials_provider = sentinel.credentials_provider
        model_factory = sentinel.model_factory
        llm_template_renderer = sentinel.llm_template_renderer

        with (
            patch.object(
@@ -172,6 +196,11 @@ class TestDifyNodeFactoryInit:
                "build_http_request_config",
                return_value=http_request_config,
            ),
            patch.object(
                node_factory,
                "DefaultLLMTemplateRenderer",
                return_value=llm_template_renderer,
            ) as llm_renderer_factory,
            patch.object(
                node_factory,
                "build_dify_model_access",
@@ -186,11 +215,14 @@ class TestDifyNodeFactoryInit:
        resolve_dify_context.assert_called_once_with(graph_init_params.run_context)
        build_dify_model_access.assert_called_once_with("tenant-id")
        renderer_factory.assert_called_once()
        llm_renderer_factory.assert_called_once()
        assert renderer_factory.call_args.kwargs["code_executor"] is factory._code_executor
        assert factory.graph_init_params is graph_init_params
        assert factory.graph_runtime_state is graph_runtime_state
        assert factory._dify_context is dify_context
        assert factory._template_renderer is template_renderer
        assert factory._llm_template_renderer is llm_template_renderer
        assert factory._document_extractor_unstructured_api_config is unstructured_api_config
        assert factory._http_request_config is http_request_config
        assert factory._llm_credentials_provider is credentials_provider
@@ -242,6 +274,7 @@ class TestDifyNodeFactoryCreateNode:
        factory._code_executor = sentinel.code_executor
        factory._code_limits = sentinel.code_limits
        factory._template_renderer = sentinel.template_renderer
        factory._llm_template_renderer = sentinel.llm_template_renderer
        factory._template_transform_max_output_length = 2048
        factory._http_request_http_client = sentinel.http_client
        factory._http_request_tool_file_manager_factory = sentinel.tool_file_manager_factory
@@ -378,8 +411,22 @@ class TestDifyNodeFactoryCreateNode:
    @pytest.mark.parametrize(
        ("node_type", "constructor_name", "expected_extra_kwargs"),
        [
            (BuiltinNodeTypes.LLM, "LLMNode", {"http_client": sentinel.http_client}),
            (BuiltinNodeTypes.QUESTION_CLASSIFIER, "QuestionClassifierNode", {"http_client": sentinel.http_client}),
            (
                BuiltinNodeTypes.LLM,
                "LLMNode",
                {
                    "http_client": sentinel.http_client,
                    "template_renderer": sentinel.llm_template_renderer,
                },
            ),
            (
                BuiltinNodeTypes.QUESTION_CLASSIFIER,
                "QuestionClassifierNode",
                {
                    "http_client": sentinel.http_client,
                    "template_renderer": sentinel.llm_template_renderer,
                },
            ),
            (BuiltinNodeTypes.PARAMETER_EXTRACTOR, "ParameterExtractorNode", {}),
        ],
    )
@@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.

This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""

from unittest.mock import patch
@@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest

from services.enterprise.enterprise_service import (
    INVALID_LICENSE_CACHE_TTL,
    LICENSE_STATUS_CACHE_KEY,
    VALID_LICENSE_CACHE_TTL,
    DefaultWorkspaceJoinResult,
    EnterpriseService,
    try_join_default_workspace,
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
            "/default-workspace/members",
            json={"account_id": account_id},
            timeout=1.0,
            raise_for_status=True,
        )

    def test_join_default_workspace_invalid_response_format_raises(self):
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:

        # Should not raise even though UUID parsing fails inside join_default_workspace
        try_join_default_workspace("not-a-uuid")


# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------

_EE_SVC = "services.enterprise.enterprise_service"


class TestGetCachedLicenseStatus:
    """Tests for EnterpriseService.get_cached_license_status."""

    def test_returns_none_when_enterprise_disabled(self):
        with patch(f"{_EE_SVC}.dify_config") as mock_config:
            mock_config.ENTERPRISE_ENABLED = False

            assert EnterpriseService.get_cached_license_status() is None

    def test_cache_hit_returns_license_status_enum(self):
        from services.feature_service import LicenseStatus

        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = b"active"

            result = EnterpriseService.get_cached_license_status()

            assert result == LicenseStatus.ACTIVE
            assert isinstance(result, LicenseStatus)
            mock_get_info.assert_not_called()

    def test_cache_miss_fetches_api_and_caches_valid_status(self):
        from services.feature_service import LicenseStatus

        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = None
            mock_get_info.return_value = {"License": {"status": "active"}}

            result = EnterpriseService.get_cached_license_status()

            assert result == LicenseStatus.ACTIVE
            mock_redis.setex.assert_called_once_with(
                LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
            )

    def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
        from services.feature_service import LicenseStatus

        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = None
            mock_get_info.return_value = {"License": {"status": "expired"}}

            result = EnterpriseService.get_cached_license_status()

            assert result == LicenseStatus.EXPIRED
            mock_redis.setex.assert_called_once_with(
                LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
            )

    def test_redis_read_failure_falls_through_to_api(self):
        from services.feature_service import LicenseStatus

        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.side_effect = ConnectionError("redis down")
            mock_get_info.return_value = {"License": {"status": "active"}}

            result = EnterpriseService.get_cached_license_status()

            assert result == LicenseStatus.ACTIVE
            mock_get_info.assert_called_once()

    def test_redis_write_failure_still_returns_status(self):
        from services.feature_service import LicenseStatus

        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = None
            mock_redis.setex.side_effect = ConnectionError("redis down")
            mock_get_info.return_value = {"License": {"status": "expiring"}}

            result = EnterpriseService.get_cached_license_status()

            assert result == LicenseStatus.EXPIRING

    def test_api_failure_returns_none(self):
        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = None
            mock_get_info.side_effect = Exception("network failure")

            assert EnterpriseService.get_cached_license_status() is None

    def test_api_returns_no_license_info(self):
        with (
            patch(f"{_EE_SVC}.dify_config") as mock_config,
            patch(f"{_EE_SVC}.redis_client") as mock_redis,
            patch.object(EnterpriseService, "get_info") as mock_get_info,
        ):
            mock_config.ENTERPRISE_ENABLED = True
            mock_redis.get.return_value = None
            mock_get_info.return_value = {}  # no "License" key

            assert EnterpriseService.get_cached_license_status() is None
            mock_redis.setex.assert_not_called()
@@ -34,7 +34,6 @@ class TestTryPreUninstallPlugin:
            "POST",
            "/pre-uninstall-plugin",
            json={"tenant_id": "tenant-123", "plugin_unique_identifier": "com.example.my_plugin"},
            raise_for_status=True,
            timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
        )

@@ -62,7 +61,6 @@ class TestTryPreUninstallPlugin:
            "POST",
            "/pre-uninstall-plugin",
            json={"tenant_id": "tenant-456", "plugin_unique_identifier": "com.example.other_plugin"},
            raise_for_status=True,
            timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
        )
        mock_logger.exception.assert_called_once()
@@ -87,7 +85,6 @@ class TestTryPreUninstallPlugin:
            "POST",
            "/pre-uninstall-plugin",
            json={"tenant_id": "tenant-789", "plugin_unique_identifier": "com.example.failing_plugin"},
            raise_for_status=True,
            timeout=dify_config.ENTERPRISE_REQUEST_TIMEOUT,
        )
        mock_logger.exception.assert_called_once()
790
api/uv.lock
generated
File diff suppressed because it is too large