Compare commits

...

7 Commits

28 changed files with 518 additions and 702 deletions

View File

@@ -29,6 +29,8 @@ ignore_imports =
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph_engine
core.workflow.nodes.iteration.iteration_node -> core.workflow.graph
@@ -107,14 +109,12 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.tools.tool_manager
core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
core.workflow.nodes.llm.llm_utils -> configs
core.workflow.nodes.llm.llm_utils -> core.model_manager
core.workflow.nodes.llm.protocols -> core.model_manager
core.workflow.nodes.llm.llm_utils -> core.model_runtime.model_providers.__base.large_language_model
core.workflow.nodes.llm.llm_utils -> models.model
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
@@ -135,8 +135,8 @@ ignore_imports =
core.workflow.nodes.start.start_node -> core.app.app_config.entities
core.workflow.workflow_entry -> core.app.apps.exc
core.workflow.workflow_entry -> core.app.entities.app_invoke_entities
core.workflow.workflow_entry -> core.app.workflow.layers.llm_quota
core.workflow.workflow_entry -> core.app.workflow.node_factory
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
core.workflow.nodes.tool.tool_node -> core.tools.utils.message_transformer
@@ -180,7 +180,7 @@ ignore_imports =
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums
core.workflow.nodes.llm.llm_utils -> models.provider_ids
core.workflow.nodes.loop.loop_node -> core.app.workflow.layers.llm_quota
core.workflow.nodes.llm.node -> models.model
core.workflow.workflow_entry -> models.enums
core.workflow.nodes.agent.agent_node -> services

View File

@@ -1 +1,5 @@
"""LLM-related application services."""
from .quota import deduct_llm_quota, ensure_llm_quota_available
__all__ = ["deduct_llm_quota", "ensure_llm_quota_available"]

93
api/core/app/llm/quota.py Normal file
View File

@@ -0,0 +1,93 @@
from sqlalchemy import update
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.model_entities import ModelStatus
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
def ensure_llm_quota_available(*, model_instance: ModelInstance) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
provider_model = provider_configuration.get_provider_model(
model_type=model_instance.model_type_instance.model_type,
model=model_instance.model_name,
)
if provider_model and provider_model.status == ModelStatus.QUOTA_EXCEEDED:
raise QuotaExceededError(f"Model provider {model_instance.provider} quota exceeded.")
def deduct_llm_quota(*, tenant_id: str, model_instance: ModelInstance, usage: LLMUsage) -> None:
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@@ -1,9 +1,11 @@
"""Workflow-level GraphEngine layers that depend on outer infrastructure."""
from .llm_quota import LLMQuotaLayer
from .observability import ObservabilityLayer
from .persistence import PersistenceWorkflowInfo, WorkflowPersistenceLayer
__all__ = [
"LLMQuotaLayer",
"ObservabilityLayer",
"PersistenceWorkflowInfo",
"WorkflowPersistenceLayer",

View File

@@ -0,0 +1,128 @@
"""
LLM quota deduction layer for GraphEngine.
This layer centralizes model-quota deduction outside node implementations.
"""
import logging
from typing import TYPE_CHECKING, cast, final
from typing_extensions import override
from core.app.llm import deduct_llm_quota, ensure_llm_quota_available
from core.errors.error import QuotaExceededError
from core.model_manager import ModelInstance
from core.workflow.enums import NodeType
from core.workflow.graph_engine.entities.commands import AbortCommand, CommandType
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import GraphEngineEvent, GraphNodeEventBase
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.nodes.base.node import Node
if TYPE_CHECKING:
from core.workflow.nodes.llm.node import LLMNode
from core.workflow.nodes.parameter_extractor.parameter_extractor_node import ParameterExtractorNode
from core.workflow.nodes.question_classifier.question_classifier_node import QuestionClassifierNode
logger = logging.getLogger(__name__)
@final
class LLMQuotaLayer(GraphEngineLayer):
"""Graph layer that applies LLM quota deduction after node execution."""
def __init__(self) -> None:
super().__init__()
self._abort_sent = False
@override
def on_graph_start(self) -> None:
self._abort_sent = False
@override
def on_event(self, event: GraphEngineEvent) -> None:
_ = event
@override
def on_graph_end(self, error: Exception | None) -> None:
_ = error
@override
def on_node_run_start(self, node: Node) -> None:
if self._abort_sent:
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
ensure_llm_quota_available(model_instance=model_instance)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota check failed, node_id=%s, error=%s", node.id, exc)
@override
def on_node_run_end(
self, node: Node, error: Exception | None, result_event: GraphNodeEventBase | None = None
) -> None:
if error is not None or not isinstance(result_event, NodeRunSucceededEvent):
return
model_instance = self._extract_model_instance(node)
if model_instance is None:
return
try:
deduct_llm_quota(
tenant_id=node.tenant_id,
model_instance=model_instance,
usage=result_event.node_run_result.llm_usage,
)
except QuotaExceededError as exc:
self._set_stop_event(node)
self._send_abort_command(reason=str(exc))
logger.warning("LLM quota deduction exceeded, node_id=%s, error=%s", node.id, exc)
except Exception:
logger.exception("LLM quota deduction failed, node_id=%s", node.id)
@staticmethod
def _set_stop_event(node: Node) -> None:
stop_event = getattr(node.graph_runtime_state, "stop_event", None)
if stop_event is not None:
stop_event.set()
def _send_abort_command(self, *, reason: str) -> None:
if not self.command_channel or self._abort_sent:
return
try:
self.command_channel.send_command(
AbortCommand(
command_type=CommandType.ABORT,
reason=reason,
)
)
self._abort_sent = True
except Exception:
logger.exception("Failed to send quota abort command")
@staticmethod
def _extract_model_instance(node: Node) -> ModelInstance | None:
try:
match node.node_type:
case NodeType.LLM:
return cast("LLMNode", node).model_instance
case NodeType.PARAMETER_EXTRACTOR:
return cast("ParameterExtractorNode", node).model_instance
case NodeType.QUESTION_CLASSIFIER:
return cast("QuestionClassifierNode", node).model_instance
case _:
return None
except AttributeError:
logger.warning(
"LLMQuotaLayer skipped quota deduction because node does not expose a model instance, node_id=%s",
node.id,
)
return None

View File

@@ -2,6 +2,7 @@ import tempfile
from binascii import hexlify, unhexlify
from collections.abc import Generator
from core.app.llm import deduct_llm_quota
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
from core.model_manager import ModelManager
from core.model_runtime.entities.llm_entities import (
@@ -29,7 +30,6 @@ from core.plugin.entities.request import (
)
from core.tools.entities.tool_entities import ToolProviderType
from core.tools.utils.model_invocation_utils import ModelInvocationUtils
from core.workflow.nodes.llm import llm_utils
from models.account import Tenant
@@ -63,16 +63,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunk, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(response: LLMResult) -> Generator[LLMResultChunk, None, None]:
yield LLMResultChunk(
@@ -126,16 +124,14 @@ class PluginModelBackwardsInvocation(BaseBackwardsInvocation):
def handle() -> Generator[LLMResultChunkWithStructuredOutput, None, None]:
for chunk in response:
if chunk.delta.usage:
llm_utils.deduct_llm_quota(
tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage
)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=chunk.delta.usage)
chunk.prompt_messages = []
yield chunk
return handle()
else:
if response.usage:
llm_utils.deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
deduct_llm_quota(tenant_id=tenant.id, model_instance=model_instance, usage=response.usage)
def handle_non_streaming(
response: LLMResultWithStructuredOutput,

View File

@@ -8,6 +8,7 @@ from typing import Any, cast
logger = logging.getLogger(__name__)
from core.app.llm import deduct_llm_quota
from core.entities.knowledge_entities import PreviewDetail
from core.llm_generator.prompts import DEFAULT_GENERATOR_SUMMARY_PROMPT
from core.model_manager import ModelInstance
@@ -35,7 +36,6 @@ from core.rag.models.document import AttachmentDocument, Document, MultimodalGen
from core.rag.retrieval.retrieval_methods import RetrievalMethod
from core.tools.utils.text_processing_utils import remove_leading_symbols
from core.workflow.file import File, FileTransferMethod, FileType, file_manager
from core.workflow.nodes.llm import llm_utils
from extensions.ext_database import db
from factories.file_factory import build_from_mapping
from libs import helper
@@ -474,7 +474,7 @@ class ParagraphIndexProcessor(BaseIndexProcessor):
# Deduct quota for summary generation (same as workflow nodes)
try:
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
except Exception as e:
# Log but don't fail summary generation if quota deduction fails
logger.warning("Failed to deduct quota for summary generation: %s", str(e))

View File

@@ -2,6 +2,7 @@ from collections.abc import Generator, Sequence
from typing import Union
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.app.llm import deduct_llm_quota
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
@@ -9,7 +10,6 @@ from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
from core.rag.retrieval.output_parser.react_output import ReactAction
from core.rag.retrieval.output_parser.structured_chat import StructuredChatOutputParser
from core.workflow.nodes.llm import llm_utils
PREFIX = """Respond to the human as helpfully and accurately as possible. You have access to the following tools:"""
@@ -162,7 +162,7 @@ class ReactMultiDatasetRouter:
text, usage = self._handle_invoke_result(invoke_result=invoke_result)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
deduct_llm_quota(tenant_id=tenant_id, model_instance=model_instance, usage=usage)
return text, usage

View File

@@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@@ -77,13 +76,10 @@ class GraphEngine:
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
@@ -163,7 +159,6 @@ class GraphEngine:
layers=self._layers,
execution_context=execution_context,
config=self._config,
stop_event=self._stop_event,
)
# === Orchestration ===
@@ -194,7 +189,6 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@@ -314,7 +308,6 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@@ -348,7 +341,6 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it

View File

@@ -44,7 +44,6 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@@ -62,7 +61,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = stop_event
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
@@ -70,12 +69,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

View File

@@ -42,7 +42,6 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -63,16 +62,13 @@ class Worker(threading.Thread):
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = stop_event
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:

View File

@@ -37,7 +37,6 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -64,7 +63,6 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@@ -135,7 +133,6 @@ class WorkerPool:
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()

View File

@@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@@ -588,6 +588,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
def _create_graph_engine(self, index: int, item: object):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -642,5 +643,6 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -1,14 +1,11 @@
from collections.abc import Sequence
from typing import cast
from sqlalchemy import select, update
from sqlalchemy import select
from sqlalchemy.orm import Session
from configs import dify_config
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.model_entities import AIModelEntity
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
@@ -17,10 +14,7 @@ from core.workflow.file.models import File
from core.workflow.runtime import VariablePool
from core.workflow.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import Conversation
from models.provider import Provider, ProviderType
from models.provider_ids import ModelProviderID
from .exc import InvalidVariableTypeError
@@ -68,68 +62,3 @@ def fetch_memory(
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
return memory
def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
provider_model_bundle = model_instance.provider_model_bundle
provider_configuration = provider_model_bundle.configuration
if provider_configuration.using_provider_type != ProviderType.SYSTEM:
return
system_configuration = provider_configuration.system_configuration
quota_unit = None
for quota_configuration in system_configuration.quota_configurations:
if quota_configuration.quota_type == system_configuration.current_quota_type:
quota_unit = quota_configuration.quota_unit
if quota_configuration.quota_limit == -1:
return
break
used_quota = None
if quota_unit:
if quota_unit == QuotaUnit.TOKENS:
used_quota = usage.total_tokens
elif quota_unit == QuotaUnit.CREDITS:
used_quota = dify_config.get_model_credits(model_instance.model_name)
else:
used_quota = 1
if used_quota is not None and system_configuration.current_quota_type is not None:
if system_configuration.current_quota_type == ProviderQuotaType.TRIAL:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
)
elif system_configuration.current_quota_type == ProviderQuotaType.PAID:
from services.credit_pool_service import CreditPoolService
CreditPoolService.check_and_deduct_credits(
tenant_id=tenant_id,
credits_required=used_quota,
pool_type="paid",
)
else:
with Session(db.engine) as session:
stmt = (
update(Provider)
.where(
Provider.tenant_id == tenant_id,
# TODO: Use provider name with prefix after the data migration.
Provider.provider_name == ModelProviderID(model_instance.provider).provider_name,
Provider.provider_type == ProviderType.SYSTEM.value,
Provider.quota_type == system_configuration.current_quota_type.value,
Provider.quota_limit > Provider.quota_used,
)
.values(
quota_used=Provider.quota_used + used_quota,
last_used=naive_utc_now(),
)
)
session.execute(stmt)
session.commit()

View File

@@ -278,8 +278,6 @@ class LLMNode(Node[LLMNodeData]):
else None
)
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
break
elif isinstance(event, LLMStructuredOutput):
structured_output = event
@@ -1234,6 +1232,10 @@ class LLMNode(Node[LLMNodeData]):
def retry(self) -> bool:
return self.node_data.retry_config.retry_enabled
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
def _combine_message_content_with_role(
*, contents: str | list[PromptMessageContentUnionTypes] | None = None, role: PromptMessageRole

View File

@@ -413,6 +413,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _create_graph_engine(self, start_at: datetime, root_node_id: str):
# Import dependencies
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.graph import Graph
@@ -454,5 +455,6 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
command_channel=InMemoryChannel(), # Use InMemoryChannel for sub-graphs
config=GraphEngineConfig(),
)
graph_engine.layer(LLMQuotaLayer())
return graph_engine

View File

@@ -308,9 +308,6 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
usage = invoke_result.usage
tool_call = invoke_result.message.tool_calls[0] if invoke_result.message.tool_calls else None
# deduct quota
llm_utils.deduct_llm_quota(tenant_id=self.tenant_id, model_instance=model_instance, usage=usage)
return text, usage, tool_call
def _generate_function_call_prompt(
@@ -828,6 +825,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return rest_tokens
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -240,6 +240,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
llm_usage=usage,
)
@property
def model_instance(self) -> ModelInstance:
return self._model_instance
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@@ -219,8 +218,6 @@ class GraphRuntimeState:
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@@ -6,6 +6,7 @@ from typing import Any, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.app.workflow.layers.observability import ObservabilityLayer
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.constants import ENVIRONMENT_VARIABLE_NODE_ID
@@ -106,6 +107,7 @@ class WorkflowEntry:
max_steps=dify_config.WORKFLOW_MAX_EXECUTION_STEPS, max_time=dify_config.WORKFLOW_MAX_EXECUTION_TIME
)
self.graph_engine.layer(limits_layer)
self.graph_engine.layer(LLMQuotaLayer())
# Add observability layer when OTel is enabled
if dify_config.ENABLE_OTEL or is_instrument_flag_enabled():

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
@@ -125,7 +126,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()

View File

@@ -12,8 +12,6 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -21,12 +19,6 @@ from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
@pytest.fixture(autouse=True)
def _register_dict_adapter_for_psycopg2():
"""Align test DB adapter behavior with dict payloads used in task update flow."""
register_adapter(dict, Json)
class DocumentIndexingSyncTaskTestDataFactory:
"""Create real DB entities for document indexing sync integration tests."""

View File

@@ -0,0 +1,174 @@
import threading
from datetime import datetime
from unittest.mock import MagicMock, patch
from core.app.workflow.layers.llm_quota import LLMQuotaLayer
from core.errors.error import QuotaExceededError
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.graph_engine.entities.commands import CommandType
from core.workflow.graph_events.node import NodeRunSucceededEvent
from core.workflow.node_events import NodeRunResult
def _build_succeeded_event() -> NodeRunSucceededEvent:
return NodeRunSucceededEvent(
id="execution-id",
node_id="llm-node-id",
node_type=NodeType.LLM,
start_at=datetime.now(),
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
inputs={"question": "hello"},
llm_usage=LLMUsage.empty_usage(),
),
)
def test_deduct_quota_called_for_successful_llm_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_deduct_quota_called_for_question_classifier_node() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "question-classifier-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.QUESTION_CLASSIFIER
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_called_once_with(
tenant_id="tenant-id",
model_instance=node.model_instance,
usage=result_event.node_run_result.llm_usage,
)
def test_non_llm_node_is_ignored() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "start-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.START
node.tenant_id = "tenant-id"
node._model_instance = object()
result_event = _build_succeeded_event()
with patch("core.app.workflow.layers.llm_quota.deduct_llm_quota", autospec=True) as mock_deduct:
layer.on_node_run_end(node=node, error=None, result_event=result_event)
mock_deduct.assert_not_called()
def test_quota_error_is_handled_in_layer() -> None:
layer = LLMQuotaLayer()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=ValueError("quota exceeded"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
def test_quota_deduction_exceeded_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.execution_id = "execution-id"
node.node_type = NodeType.LLM
node.tenant_id = "tenant-id"
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
result_event = _build_succeeded_event()
with patch(
"core.app.workflow.layers.llm_quota.deduct_llm_quota",
autospec=True,
side_effect=QuotaExceededError("No credits remaining"),
):
layer.on_node_run_end(node=node, error=None, result_event=result_event)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "No credits remaining"
def test_quota_precheck_failure_aborts_workflow_immediately() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch(
"core.app.workflow.layers.llm_quota.ensure_llm_quota_available",
autospec=True,
side_effect=QuotaExceededError("Model provider openai quota exceeded."),
):
layer.on_node_run_start(node)
assert stop_event.is_set()
layer.command_channel.send_command.assert_called_once()
abort_command = layer.command_channel.send_command.call_args.args[0]
assert abort_command.command_type == CommandType.ABORT
assert abort_command.reason == "Model provider openai quota exceeded."
def test_quota_precheck_passes_without_abort() -> None:
layer = LLMQuotaLayer()
stop_event = threading.Event()
layer.command_channel = MagicMock()
node = MagicMock()
node.id = "llm-node-id"
node.node_type = NodeType.LLM
node.model_instance = object()
node.graph_runtime_state = MagicMock()
node.graph_runtime_state.stop_event = stop_event
with patch("core.app.workflow.layers.llm_quota.ensure_llm_quota_available", autospec=True) as mock_check:
layer.on_node_run_start(node)
assert not stop_event.is_set()
mock_check.assert_called_once_with(model_instance=node.model_instance)
layer.command_channel.send_command.assert_not_called()

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import queue
import threading
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=execution_coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
assert event_queue.empty()
@@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
@@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@@ -1,5 +1,4 @@
import queue
import threading
from datetime import datetime
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
event_handler=handler,
execution_coordinator=coordinator,
event_emitter=None,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@@ -1,550 +0,0 @@
"""
Unit tests for stop_event functionality in GraphEngine.
Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""
import threading
import time
from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
class TestStopEventPropagation:
"""Test suite for stop_event propagation through GraphEngine components."""
def test_graph_engine_creates_stop_event(self):
"""Test that GraphEngine creates a stop_event on initialization."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify stop_event was created
assert engine._stop_event is not None
assert isinstance(engine._stop_event, threading.Event)
# Verify it was set in graph_runtime_state
assert runtime_state.stop_event is not None
assert runtime_state.stop_event is engine._stop_event
def test_stop_event_cleared_on_start(self):
"""Test that stop_event is cleared when execution starts."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set the stop_event before running
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear the stop_event)
events = list(engine.run())
# After running, stop_event should be set again (by _stop_execution)
# But during start it was cleared
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
def test_stop_event_set_on_stop(self):
"""Test that stop_event is set when execution stops."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Initially not set
assert not engine._stop_event.is_set()
# Run the engine
list(engine.run())
# After execution completes, stop_event should be set
assert engine._stop_event.is_set()
def test_stop_event_passed_to_worker_pool(self):
"""Test that stop_event is passed to WorkerPool."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify WorkerPool has the stop_event
assert engine._worker_pool._stop_event is not None
assert engine._worker_pool._stop_event is engine._stop_event
def test_stop_event_passed_to_dispatcher(self):
"""Test that stop_event is passed to Dispatcher."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify Dispatcher has the stop_event
assert engine._dispatcher._stop_event is not None
assert engine._dispatcher._stop_event is engine._stop_event
class TestNodeStopCheck:
"""Test suite for Node._should_stop() functionality."""
def test_node_should_stop_checks_runtime_state(self):
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Initially stop_event is not set
assert not answer_node._should_stop()
# Set the stop_event
runtime_state.stop_event.set()
# Now _should_stop should return True
assert answer_node._should_stop()
def test_node_run_checks_stop_event_between_yields(self):
"""Test that Node.run() checks stop_event between yielding events."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Set stop_event BEFORE running the node
runtime_state.stop_event.set()
# Run the node - should yield start event then detect stop
# The node should check stop_event before processing
assert answer_node._should_stop(), "stop_event should be set"
# Run and collect events
events = list(answer_node.run())
# Since stop_event is set at the start, we should get:
# 1. NodeRunStartedEvent (always yielded first)
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
assert len(events) >= 2
assert isinstance(events[0], NodeRunStartedEvent)
# Note: AnswerNode is very simple and might complete before stop check
# The important thing is that _should_stop() returns True when stop_event is set
assert answer_node._should_stop()
class TestStopEventIntegration:
"""Integration tests for stop_event in workflow execution."""
def test_simple_workflow_respects_stop_event(self):
"""Test that a simple workflow respects stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.nodes["answer"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set stop_event before running
runtime_state.stop_event.set()
# Run the engine
events = list(engine.run())
# Should get started event but not succeeded (due to stop)
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
# The workflow should still complete (start node runs quickly)
# but answer node might be cancelled depending on timing
def test_stop_event_with_concurrent_nodes(self):
"""Test stop_event behavior with multiple concurrent nodes."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
# Create multiple nodes
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes[f"answer_{i}"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# All nodes should share the same stop_event
for node in mock_graph.nodes.values():
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
assert node.graph_runtime_state.stop_event is engine._stop_event
class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
dispatcher = engine._dispatcher
dispatcher.start() # This will create and start the mocked thread
mock_thread_instance = mock_thread_cls.return_value
mock_thread_instance.is_alive.return_value = True
dispatcher.stop()
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
worker_pool = engine._worker_pool
worker_pool.start(initial_count=1) # Start with one worker
mock_worker_instance = mock_worker_cls.return_value
mock_worker_instance.is_alive.return_value = True
worker_pool.stop()
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
class TestStopEventResumeBehavior:
"""Test stop_event behavior during workflow resume."""
def test_stop_event_cleared_on_resume(self):
"""Test that stop_event is cleared when resuming a paused workflow."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Simulate a previous execution that set stop_event
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear stop_event in _start_execution)
events = list(engine.run())
# Execution should complete successfully
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
class TestWorkerStopBehavior:
"""Test Worker behavior with shared stop_event."""
def test_worker_uses_shared_stop_event(self):
"""Test that Worker uses shared stop_event from GraphEngine."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Get the worker pool and check workers
worker_pool = engine._worker_pool
# Start the worker pool to create workers
worker_pool.start()
# Check that at least one worker was created
assert len(worker_pool._workers) > 0
# Verify workers use the shared stop_event
for worker in worker_pool._workers:
assert worker._stop_event is engine._stop_event
# Clean up
worker_pool.stop()
def test_worker_stop_is_noop(self):
"""Test that Worker.stop() is now a no-op."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a mock worker
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_engine.worker import Worker
ready_queue = InMemoryReadyQueue()
event_queue = MagicMock()
# Create a proper mock graph with real dict
mock_graph = Mock(spec=Graph)
mock_graph.nodes = {} # Use real dict
stop_event = threading.Event()
worker = Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=mock_graph,
layers=[],
stop_event=stop_event,
)
# Calling stop() should do nothing (no-op)
# and should NOT set the stop_event
worker.stop()
assert not stop_event.is_set()

View File

@@ -5,6 +5,7 @@ These tests intentionally stay in unit scope because they validate call argument
for external collaborators rather than SQL-backed state transitions.
"""
import json
import uuid
from unittest.mock import MagicMock, Mock, patch
@@ -196,3 +197,78 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
class TestDataSourceInfoSerialization:
"""Regression test: data_source_info must be written as a JSON string, not a raw dict.
See https://github.com/langgenius/dify/issues/32705
psycopg2 raises ``ProgrammingError: can't adapt type 'dict'`` when a Python
dict is passed directly to a text/LongText column.
"""
def test_data_source_info_serialized_as_json_string(
self,
mock_document,
mock_dataset,
dataset_id,
document_id,
):
"""data_source_info must be serialized with json.dumps before DB write."""
with (
patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory,
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class,
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class,
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_ipf,
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class,
):
# External collaborators
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "token"}
mock_service_class.return_value = mock_service
mock_extractor = MagicMock()
# Return a *different* timestamp so the task enters the sync/update branch
mock_extractor.get_notion_last_edited_time.return_value = "2024-02-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor
mock_ip = MagicMock()
mock_ipf.return_value.init_index_processor.return_value = mock_ip
mock_runner = MagicMock()
mock_runner_class.return_value = mock_runner
# DB session mock — shared across all ``session_factory.create_session()`` calls
session = MagicMock()
session.scalars.return_value.all.return_value = []
# .where() path: session 1 reads document + dataset, session 2 reads dataset
session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
]
# .filter_by() path: session 3 (update), session 4 (indexing)
session.query.return_value.filter_by.return_value.first.side_effect = [
mock_document,
mock_document,
]
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
begin_cm.__exit__.return_value = False
session.begin.return_value = begin_cm
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = False
mock_session_factory.create_session.return_value = session_cm
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert: data_source_info must be a JSON *string*, not a dict
assert isinstance(mock_document.data_source_info, str), (
f"data_source_info should be a JSON string, got {type(mock_document.data_source_info).__name__}"
)
parsed = json.loads(mock_document.data_source_info)
assert parsed["last_edited_time"] == "2024-02-01T00:00:00Z"