Compare commits

...

15 Commits

Author SHA1 Message Date
Yanli 盐粒
8548498f25 fix(api): restore advanced chat refresh_model contract 2026-03-18 19:41:00 +08:00
Yanli 盐粒
d014f0b91a fix(api): address typing review feedback 2026-03-18 19:16:48 +08:00
Yanli 盐粒
cc5aac268a fix(api): support tool typed dicts on py311 2026-03-18 18:59:49 +08:00
Yanli 盐粒
4c1d27431b fix(api): restore workflow node compatibility 2026-03-18 18:43:35 +08:00
Yanli 盐粒
9a86f280eb fix(api): avoid recursive loop type adapters 2026-03-18 18:20:43 +08:00
Yanli 盐粒
c5920fb28a Merge remote-tracking branch 'origin/main' into yanli/phase3-code-scope 2026-03-18 17:52:03 +08:00
Yanli 盐粒
2f81d5dfdf fix(api): restore typedict py311 compatibility 2026-03-17 20:30:18 +08:00
Yanli 盐粒
7639d8e43f fix(api): reuse advanced chat refresh session 2026-03-17 20:18:21 +08:00
Yanli 盐粒
1dce81c604 refactor(api): type single node workflow helpers 2026-03-17 20:16:14 +08:00
Yanli 盐粒
f874ca183e chore(api): remove phase 3 pyrefly excludes 2026-03-17 20:04:55 +08:00
Yanli 盐粒
0d805e624e Type phase 3 loop values 2026-03-17 19:39:54 +08:00
Yanli 盐粒
61196180b8 Type phase 3 tool inputs 2026-03-17 19:31:00 +08:00
Yanli 盐粒
79433b0091 Refine phase 3 typing boundaries 2026-03-17 19:13:12 +08:00
Yanli 盐粒
c4aeaa35d4 Type phase 3 schema contracts 2026-03-17 18:56:22 +08:00
Yanli 盐粒
9f0d79b8b0 Tighten phase 3 runtime typing 2026-03-17 18:49:14 +08:00
34 changed files with 729 additions and 369 deletions

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -47,7 +47,6 @@ from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -522,8 +521,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
workflow = _refresh_model(session=session, model=workflow)
message = _refresh_model(session=session, model=message)
assert message is not None
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
@@ -690,11 +690,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e
_T = TypeVar("_T", bound=Base)
@overload
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model
@overload
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
if session is not None:
detached_model = session.get(type(model), model.id)
assert detached_model is not None
return detached_model
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
detached_model = refresh_session.get(type(model), model.id)
assert detached_model is not None
return detached_model

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from collections.abc import Generator, Iterator, Mapping
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
stream_response = response
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(stream_response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
stream_response = response
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(stream_response)
return _generate_simple_response()
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError

View File

@@ -224,6 +224,7 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory(
session: Session,
@@ -240,7 +241,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
user=debug_account,
)
else:

View File

@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
assert conversation is not None
assert message is not None
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=generated_conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=generated_message_id,
)
# new thread with request context
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
conversation_id=generated_conversation_id,
message_id=generated_message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message_id,
)
# new thread with request context
@@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
message_id=message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,13 +1,17 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Protocol, TypeAlias
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
node_id: str
inputs: Mapping[str, object]
class WorkflowBasedAppRunner:
def __init__(
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_config: GraphConfigMapping,
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: SingleNodeRunEntity | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
nodes = graph_config.get("nodes")
edges = graph_config.get("edges")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(edges, list):
raise ValueError("edges in workflow graph must be a list")
validated_nodes: list[GraphConfigMapping] = []
for node in nodes:
if not isinstance(node, Mapping):
raise ValueError("nodes in workflow graph must be mappings")
validated_nodes.append(node)
validated_edges: list[GraphConfigMapping] = []
for edge in edges:
if not isinstance(edge, Mapping):
raise ValueError("edges in workflow graph must be mappings")
validated_edges.append(edge)
return validated_nodes, validated_edges
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
if node_config is None:
return None
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return None
start_node_id = node_data.get("start_node_id")
return start_node_id if isinstance(start_node_id, str) else None
@classmethod
def _build_single_node_graph_config(
cls,
*,
graph_config: GraphConfigMapping,
node_id: str,
node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
node_configs, edge_configs = cls._get_graph_items(graph_config)
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
start_node_id = cls._extract_start_node_id(main_node_config)
filtered_node_configs = [
dict(node)
for node in node_configs
if node.get("id") == node_id
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
or (start_node_id and node.get("id") == start_node_id)
]
if not filtered_node_configs:
raise ValueError(f"node id {node_id} not found in workflow graph")
filtered_node_ids = {
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
}
filtered_edge_configs = [
dict(edge)
for edge in edge_configs
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
]
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
if target_node_config is None:
raise ValueError(f"node id {node_id} not found in workflow graph")
return (
{
"nodes": filtered_node_configs,
"edges": filtered_edge_configs,
},
NodeConfigDictAdapter.validate_python(target_node_config),
)
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
graph_config, target_node_config = self._build_single_node_graph_config(
graph_config=graph_config,
node_id=node_id,
node_type_filter_key=node_type_filter_key,
)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
@@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)

View File

@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None

View File

@@ -1045,9 +1045,10 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
parameter_value = variable.value
elif tool_input.type == "constant":
parameter_value = tool_input.value

View File

@@ -1,13 +1,24 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel
from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
@@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput]

View File

@@ -1,16 +1,17 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import TypeAlias
from packaging.version import Version
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
@@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport:
def build_parameters(
@@ -39,12 +48,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
invoke_from: InvokeFrom,
for_log: bool = False,
) -> dict[str, Any]:
) -> dict[str, object]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -54,9 +63,10 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
raise AgentVariableNotFoundError(str(variable_selector))
parameter_value = variable.value
case "mixed" | "constant":
try:
@@ -79,60 +89,38 @@ class AgentRuntimeSupport:
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = self._normalize_tool_payloads(
strategy=strategy,
tools=tool_payloads,
variable_pool=variable_pool,
)
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = self._coerce_json_object(tool.get("parameters")) or {}
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_id=provider_id,
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_name=tool_name,
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
plugin_unique_identifier=plugin_unique_identifier,
credential_id=credential_id,
)
extra = tool.get("extra", {})
extra = self._coerce_json_object(tool.get("extra")) or {}
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
@@ -145,8 +133,9 @@ class AgentRuntimeSupport:
runtime_variable_pool,
)
if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
description_override or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
@@ -167,13 +156,13 @@ class AgentRuntimeSupport:
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"credential_id": credential_id,
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
value = _JSON_OBJECT_ADAPTER.validate_python(value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
@@ -199,17 +188,27 @@ class AgentRuntimeSupport:
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
return credentials
def fetch_memory(
@@ -232,14 +231,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
provider=str(value.get("provider", "")),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_name = str(value.get("model", ""))
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
@@ -249,7 +248,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model_type=ModelType(str(value.get("model_type", ""))),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@@ -268,9 +267,88 @@ class AgentRuntimeSupport:
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
tools: JsonObjectList,
) -> JsonObjectList:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
self,
*,
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
variable_pool: VariablePool,
) -> JsonObjectList:
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
for tool in normalized_tools:
tool.pop("schemas", None)
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
tool["settings"] = self._resolve_tool_settings(tool)
return normalized_tools
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
if parameter_configs is None:
raw_parameters = self._coerce_json_object(tool.get("parameters"))
return raw_parameters or {}
resolved_parameters: JsonObject = {}
for key, parameter_config in parameter_configs.items():
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
value_param = self._coerce_json_object(parameter_config.get("value"))
if value_param and value_param.get("type") == "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
resolved_parameters[key] = variable.value
else:
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
else:
resolved_parameters[key] = None
return resolved_parameters
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
if settings is None:
return {}
return {key: setting.get("value") for key, setting in settings.items()}
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
try:
return _JSON_OBJECT_ADAPTER.validate_python(value)
except ValidationError:
return None
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
if isinstance(value, ToolProviderType):
return value
if isinstance(value, str):
return ToolProviderType(value)
return ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
if not isinstance(value, dict):
return None
coerced: dict[str, JsonObject] = {}
for key, item in value.items():
if not isinstance(key, str):
return None
json_object = cls._coerce_json_object(item)
if json_object is None:
return None
coerced[key] = json_object
return coerced

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, TypeAlias, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -32,6 +32,13 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
SpecialValueScalar: TypeAlias = str | int | float | bool | None
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
SerializedSpecialValue: TypeAlias = (
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder:
@staticmethod
@@ -276,10 +283,10 @@ class WorkflowEntry:
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: dict[str, Any],
node_data: Mapping[str, object],
node_width: int = 114,
node_height: int = 514,
) -> dict[str, Any]:
) -> SingleNodeGraphConfig:
"""
Create a minimal graph structure for testing a single node in isolation.
@@ -289,14 +296,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node
"""
node_config = {
node_config: dict[str, object] = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": node_data,
"data": dict(node_data),
}
start_node_config = {
start_node_config: dict[str, object] = {
"id": "start",
"width": node_width,
"height": node_height,
@@ -321,7 +328,12 @@ class WorkflowEntry:
@classmethod
def run_free_node(
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@@ -339,6 +351,8 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
@@ -369,7 +383,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -405,30 +419,34 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
@staticmethod
def _handle_special_values(value: Any):
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
if value is None:
return value
if isinstance(value, dict):
res = {}
if isinstance(value, Mapping):
res: dict[str, SerializedSpecialValue] = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list = []
res_list: list[SerializedSpecialValue] = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return dict(value.to_dict())
return value
@classmethod

View File

@@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return base64.b64encode(data).decode("utf-8")

View File

@@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason)

View File

@@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_executions = graph_execution.node_executions
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
@@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self.execution_id
yield event
yield event.model_copy(update={"id": self.execution_id})
else:
yield event
except Exception as e:

View File

@@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
it = iter(doc.element.body)
doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
part = next(it, None)
i = 0
while part is not None:

View File

@@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Literal, NotRequired
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
schema: Mapping[str, object]
name: NotRequired[str]
description: NotRequired[str]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = Field(default_factory=dict)
completion_params: dict[str, object] = Field(default_factory=dict)
class ContextConfig(BaseModel):
@@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before")
@classmethod
def convert_none_configs(cls, v: Any):
def convert_none_configs(cls, v: object):
if v is None:
return VisionConfigOptions()
return v
@@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
def convert_none_jinja2_variables(cls, v: object):
if v is None:
return []
return v
@@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
structured_output: StructuredOutputConfig | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
@@ -90,11 +97,30 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
def convert_none_prompt_config(cls, v: object):
if v is None:
return PromptConfig()
return v
@field_validator("structured_output", mode="before")
@classmethod
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
if not isinstance(v, Mapping):
return v
schema = v.get("schema")
if schema is None:
return None
normalized: StructuredOutputConfig = {"schema": schema}
name = v.get("name")
description = v.get("description")
if isinstance(name, str):
normalized["name"] = name
if isinstance(description, str):
normalized["description"] = description
return normalized
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -9,6 +9,7 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from pydantic import TypeAdapter
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
@@ -74,6 +75,7 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
StructuredOutputConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -88,6 +90,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
class LLMNode(Node[LLMNodeData]):
@@ -354,7 +357,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output: StructuredOutputConfig | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -367,8 +370,10 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
structured_output=structured_output,
)
request_start_time = time.perf_counter()
@@ -920,6 +925,12 @@ class LLMNode(Node[LLMNodeData]):
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
structured_output = (
dict(invoke_result.structured_output)
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
else None
)
event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
@@ -928,7 +939,7 @@ class LLMNode(Node[LLMNodeData]):
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None),
structured_output=structured_output,
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
@@ -962,27 +973,18 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
structured_output: StructuredOutputConfig,
) -> dict[str, object]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
dict[str, object]: The structured output schema
"""
if not structured_output:
schema = structured_output.get("schema")
if not schema:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
return _JSON_OBJECT_ADAPTER.validate_python(schema)
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(

View File

@@ -1,7 +1,10 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from __future__ import annotations
from pydantic import AfterValidator, BaseModel, Field, field_validator
from enum import StrEnum
from typing import Annotated, Any, Literal, TypeAlias, cast
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
@@ -9,6 +12,12 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
@@ -29,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
@@ -37,7 +76,29 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Any | list[str] | None = None
value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
raise ValueError("Variable loop inputs require a selector")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _validate_loop_value(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
if self.value_type != "variable":
raise ValueError(f"Expected variable loop input, got {self.value_type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _validate_loop_value(self.value)
class LoopNodeData(BaseLoopNodeData):
@@ -46,14 +107,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
outputs: LoopValueMapping = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
return {}
return v
return _validate_loop_value_mapping(value)
class LoopStartNodeData(BaseNodeData):
@@ -77,8 +138,8 @@ class LoopState(BaseLoopState):
Loop State.
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any = None
outputs: list[LoopValue] = Field(default_factory=list)
current_output: LoopValue | None = None
class MetaData(BaseLoopState.MetaData):
"""
@@ -87,7 +148,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any:
def get_last_output(self) -> LoopValue | None:
"""
Get last output.
"""
@@ -95,7 +156,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any:
def get_current_output(self) -> LoopValue | None:
"""
Get current output.
"""

View File

@@ -3,7 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
)
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
inputs: dict[str, object] = {"loop_count": loop_count}
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
loop_variable_selectors: dict[str, list[str]] = {}
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
),
}
for loop_variable in self.node_data.loop_variables:
@@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
single_loop_variable: dict[str, LoopValue] = {}
for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
@@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
node_id: str,
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
variable_mapping = {}
variable_mapping: dict[str, Sequence[str]] = {}
# Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
continue
# variable selector to variable mapping
@@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
selector = loop_variable.require_variable_selector()
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
@@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
@@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
loop_node_ids: set[str] = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
raw_nodes = graph_config.get("nodes")
if not isinstance(raw_nodes, list):
return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
@@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
"""Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal
from typing import Annotated, Literal
from pydantic import (
BaseModel,
@@ -6,6 +6,7 @@ from pydantic import (
Field,
field_validator,
)
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
def validate_name(cls, value: object) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
@@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
return element_type
class JsonSchemaArrayItems(TypedDict):
type: str
class ParameterJsonSchemaProperty(TypedDict, total=False):
description: str
type: str
items: JsonSchemaArrayItems
enum: list[str]
class ParameterJsonSchema(TypedDict):
type: Literal["object"]
properties: dict[str, ParameterJsonSchemaProperty]
required: list[str]
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
@@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def set_reasoning_mode(cls, v: object) -> str:
return str(v) if v else "function_call"
def get_parameter_json_schema(self):
def get_parameter_json_schema(self) -> ParameterJsonSchema:
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
@@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
parameter_schema["type"] = parameter.type.value
if parameter.options:
parameter_schema["enum"] = parameter.options

View File

@@ -5,6 +5,8 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from pydantic import TypeAdapter
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -63,6 +65,7 @@ from .prompts import (
)
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
@@ -70,7 +73,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
def extract_json(text):
def extract_json(text: str) -> str | None:
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
@@ -392,10 +395,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
# generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
)
return prompt_messages, [tool]
@@ -602,19 +610,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
"""
Transform result into standard format.
"""
transformed_result: dict[str, Any] = {}
transformed_result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
if isinstance(param_value, (bool, int, float, str)):
numeric_value: bool | int | float | str = param_value
transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
@@ -661,7 +671,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
"""
Extract complete json response.
"""
@@ -672,11 +682,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
"""
Extract json from tool call.
"""
@@ -690,16 +700,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData):
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
"""
Generate default result.
"""
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0

View File

@@ -1,12 +1,66 @@
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel, field_validator
from typing import Literal, TypeAlias, cast
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import TypedDict
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class WorkflowToolInputValue(TypedDict):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
class ToolInputPayload(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
if isinstance(value, (str, int, float, bool)):
return cast(ToolConfigurationEntry, value)
if isinstance(value, dict):
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
class ToolEntity(BaseModel):
provider_id: str
@@ -14,52 +68,29 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
tool_configurations: ToolConfigurations
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
raise TypeError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
normalized: ToolConfigurations = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("tool_configurations keys must be strings")
normalized[key] = _validate_tool_configuration_entry(item)
return normalized
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
class ToolInput(ToolInputPayload):
pass
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
@@ -69,7 +100,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before")
@classmethod
def filter_none_tool_inputs(cls, value):
def filter_none_tool_inputs(cls, value: object) -> object:
if not isinstance(value, dict):
return value
@@ -80,8 +111,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
}
@staticmethod
def _has_valid_value(tool_input):
def _has_valid_value(tool_input: object) -> bool:
"""Check if the value is valid"""
if isinstance(tool_input, dict):
return tool_input.get("value") is not None
return getattr(tool_input, "value", None) is not None
if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False

View File

@@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
@@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
variable_selector = input.require_variable_selector()
selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
case "constant":
pass

View File

@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import SegmentType, VariableBase
from dify_graph.variables import Segment, SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode
@@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return node execution state keyed by node id for resume support."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...
@@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for per-node execution state used during resume."""
execution_id: str | None
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""

View File

@@ -13,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
@@ -108,35 +93,6 @@ core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py

View File

@@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
refreshed = _refresh_model(session=None, model=source_model)
assert refreshed is detached_model

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from copy import deepcopy
from unittest.mock import MagicMock, patch
import pytest
@@ -33,8 +33,8 @@ def _make_graph_state():
],
)
def test_run_uses_single_node_execution_branch(
single_iteration_run: Any,
single_loop_run: Any,
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
) -> None:
app_config = MagicMock()
app_config.app_id = "app"
@@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [],
"logical_operator": "and",
},
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
}
],
"edges": [],
}
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
@@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value)
return original_validate_python(value)
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
@@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []