Compare commits

..

12 Commits

Author SHA1 Message Date
Yanli 盐粒
2f81d5dfdf fix(api): restore TypedDict py311 compatibility 2026-03-17 20:30:18 +08:00
Yanli 盐粒
7639d8e43f fix(api): reuse advanced chat refresh session 2026-03-17 20:18:21 +08:00
Yanli 盐粒
1dce81c604 refactor(api): type single node workflow helpers 2026-03-17 20:16:14 +08:00
Yanli 盐粒
f874ca183e chore(api): remove phase 3 pyrefly excludes 2026-03-17 20:04:55 +08:00
Yanli 盐粒
0d805e624e Type phase 3 loop values 2026-03-17 19:39:54 +08:00
Yanli 盐粒
61196180b8 Type phase 3 tool inputs 2026-03-17 19:31:00 +08:00
Yanli 盐粒
79433b0091 Refine phase 3 typing boundaries 2026-03-17 19:13:12 +08:00
Yanli 盐粒
c4aeaa35d4 Type phase 3 schema contracts 2026-03-17 18:56:22 +08:00
Yanli 盐粒
9f0d79b8b0 Tighten phase 3 runtime typing 2026-03-17 18:49:14 +08:00
盐粒 Yanli
a717519822 refactor(api): tighten phase 1 shared type contracts (#33453) 2026-03-17 17:50:51 +08:00
zyssyz123
a592c53573 fix: auto-activate credential when provider record exists without act… (#33503) 2026-03-17 17:27:11 +08:00
-LAN-
239e09473e fix(web): preserve public workflow SSE reconnect after pause (#33552) 2026-03-17 16:41:08 +08:00
63 changed files with 1059 additions and 727 deletions

View File

@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
```python
from datetime import datetime

View File

@@ -1,4 +1,4 @@
from typing import Literal, Protocol
from typing import Literal, Protocol, cast
from urllib.parse import quote_plus, urlunparse
from pydantic import AliasChoices, Field
@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
def _redis_defaults(config: object) -> RedisConfigDefaults:
return cast(RedisConfigDefaults, config)
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
class RedisPubSubConfig(BaseSettings):
"""
Configuration settings for event transport between API and workers.
@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
defaults = _redis_defaults(self)
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -47,7 +47,6 @@ from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -524,6 +523,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
assert message is not None
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
@@ -690,11 +690,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e
_T = TypeVar("_T", bound=Base)
@overload
def _refresh_model(session: Session, model: Workflow) -> Workflow: ...
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model
@overload
def _refresh_model(session: Session, model: Message) -> Message: ...
def _refresh_model(session: Session, model: Workflow | Message) -> Workflow | Message:
if isinstance(model, Workflow):
detached_workflow = session.get(Workflow, model.id)
assert detached_workflow is not None
return detached_workflow
detached_message = session.get(Message, model.id)
assert detached_message is not None
return detached_message

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -56,7 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]:
"""
Convert stream full response.
@@ -87,7 +87,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, Any, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from collections.abc import Generator, Iterator, Mapping
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
stream_response = response
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(stream_response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
stream_response = response
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(stream_response)
return _generate_simple_response()
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError

View File

@@ -224,6 +224,7 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory(
session: Session,
@@ -240,7 +241,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
user=debug_account,
)
else:

View File

@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
assert conversation is not None
assert message is not None
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=generated_conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=generated_message_id,
)
# new thread with request context
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
conversation_id=generated_conversation_id,
message_id=generated_message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message_id,
)
# new thread with request context
@@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
message_id=message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,13 +1,17 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Protocol, TypeAlias
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
node_id: str
inputs: Mapping[str, object]
class WorkflowBasedAppRunner:
def __init__(
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_config: GraphConfigMapping,
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: SingleNodeRunEntity | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
nodes = graph_config.get("nodes")
edges = graph_config.get("edges")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(edges, list):
raise ValueError("edges in workflow graph must be a list")
validated_nodes: list[GraphConfigMapping] = []
for node in nodes:
if not isinstance(node, Mapping):
raise ValueError("nodes in workflow graph must be mappings")
validated_nodes.append(node)
validated_edges: list[GraphConfigMapping] = []
for edge in edges:
if not isinstance(edge, Mapping):
raise ValueError("edges in workflow graph must be mappings")
validated_edges.append(edge)
return validated_nodes, validated_edges
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
if node_config is None:
return None
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return None
start_node_id = node_data.get("start_node_id")
return start_node_id if isinstance(start_node_id, str) else None
@classmethod
def _build_single_node_graph_config(
cls,
*,
graph_config: GraphConfigMapping,
node_id: str,
node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
node_configs, edge_configs = cls._get_graph_items(graph_config)
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
start_node_id = cls._extract_start_node_id(main_node_config)
filtered_node_configs = [
dict(node)
for node in node_configs
if node.get("id") == node_id
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
or (start_node_id and node.get("id") == start_node_id)
]
if not filtered_node_configs:
raise ValueError(f"node id {node_id} not found in workflow graph")
filtered_node_ids = {
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
}
filtered_edge_configs = [
dict(edge)
for edge in edge_configs
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
]
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
if target_node_config is None:
raise ValueError(f"node id {node_id} not found in workflow graph")
return (
{
"nodes": filtered_node_configs,
"edges": filtered_edge_configs,
},
NodeConfigDictAdapter.validate_python(target_node_config),
)
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
graph_config, target_node_config = self._build_single_node_graph_config(
graph_config=graph_config,
node_id=node_id,
node_type_filter_key=node_type_filter_key,
)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
@@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)

View File

@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None

View File

@@ -473,9 +473,21 @@ class ProviderConfiguration(BaseModel):
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
else:
# some historical data may have a provider record but not be set as valid
provider_record.is_valid = True
if provider_record.credential_id is None:
provider_record.credential_id = new_record.id
provider_record.updated_at = naive_utc_now()
provider_model_credentials_cache = ProviderCredentialsCache(
tenant_id=self.tenant_id,
identity_id=provider_record.id,
cache_type=ProviderCredentialsCacheType.PROVIDER,
)
provider_model_credentials_cache.delete()
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
session.commit()
except Exception:
session.rollback()

View File

@@ -196,6 +196,8 @@ class ProviderManager:
if preferred_provider_type_record:
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
preferred_provider_type = ProviderType.SYSTEM
elif custom_configuration.provider or custom_configuration.models:
preferred_provider_type = ProviderType.CUSTOM
elif system_configuration.enabled:

View File

@@ -1040,9 +1040,10 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
parameter_value = variable.value
elif tool_input.type == "constant":
parameter_value = tool_input.value

View File

@@ -1,13 +1,24 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel
from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
@@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput]

View File

@@ -1,16 +1,17 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import TypeAlias
from packaging.version import Version
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
@@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport:
def build_parameters(
@@ -39,12 +48,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
invoke_from: InvokeFrom,
for_log: bool = False,
) -> dict[str, Any]:
) -> dict[str, object]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -54,9 +63,10 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
raise AgentVariableNotFoundError(str(variable_selector))
parameter_value = variable.value
case "mixed" | "constant":
try:
@@ -79,60 +89,38 @@ class AgentRuntimeSupport:
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = self._normalize_tool_payloads(
strategy=strategy,
tools=tool_payloads,
variable_pool=variable_pool,
)
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = self._coerce_json_object(tool.get("parameters")) or {}
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_id=provider_id,
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_name=tool_name,
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
plugin_unique_identifier=plugin_unique_identifier,
credential_id=credential_id,
)
extra = tool.get("extra", {})
extra = self._coerce_json_object(tool.get("extra")) or {}
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
@@ -145,8 +133,9 @@ class AgentRuntimeSupport:
runtime_variable_pool,
)
if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
description_override or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
@@ -167,13 +156,13 @@ class AgentRuntimeSupport:
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"credential_id": credential_id,
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
value = _JSON_OBJECT_ADAPTER.validate_python(value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
@@ -199,17 +188,27 @@ class AgentRuntimeSupport:
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
return credentials
def fetch_memory(
@@ -232,14 +231,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
provider=str(value.get("provider", "")),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_name = str(value.get("model", ""))
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
@@ -249,7 +248,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model_type=ModelType(str(value.get("model_type", ""))),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@@ -268,9 +267,88 @@ class AgentRuntimeSupport:
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
tools: JsonObjectList,
) -> JsonObjectList:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
    self,
    *,
    strategy: ResolvedAgentStrategy,
    tools: JsonObjectList,
    variable_pool: VariablePool,
) -> JsonObjectList:
    """Prepare enabled tool payloads for invocation.

    Copies each enabled tool payload, drops MCP tools when the strategy's
    meta version does not support them, strips the ``schemas`` entry, and
    resolves the ``parameters``/``settings`` sections on the copies.
    """
    # Shallow-copy each payload so the caller's objects are never mutated.
    candidates = [dict(entry) for entry in tools if bool(entry.get("enabled", False))]
    filtered = self._filter_mcp_type_tool(strategy, candidates)
    for entry in filtered:
        entry.pop("schemas", None)
        entry["parameters"] = self._resolve_tool_parameters(tool=entry, variable_pool=variable_pool)
        entry["settings"] = self._resolve_tool_settings(entry)
    return filtered
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
    """Resolve a tool's ``parameters`` section to concrete values.

    Manually configured parameters (auto generation turned off) resolve
    either from the variable pool (``type == "variable"``) or from their
    literal ``value``; auto-generated parameters resolve to None. When
    the section is not shaped as named parameter configs, the raw object
    (or an empty mapping) is returned unchanged.

    Raises AgentVariableNotFoundError when a referenced variable is
    absent from the pool.
    """
    raw_section = tool.get("parameters")
    configs = self._coerce_named_json_objects(raw_section)
    if configs is None:
        # Not per-parameter configs; pass the raw object through.
        return self._coerce_json_object(raw_section) or {}

    resolved: JsonObject = {}
    for name, config in configs.items():
        auto_mode = config.get("auto", ParamsAutoGenerated.OPEN)
        if auto_mode not in (ParamsAutoGenerated.CLOSE, 0):
            # Auto-generated parameters carry no manual value.
            resolved[name] = None
            continue
        payload = self._coerce_json_object(config.get("value"))
        if payload is None:
            resolved[name] = None
            continue
        if payload.get("type") == "variable":
            selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(payload.get("value"))
            variable = variable_pool.get(selector)
            if variable is None:
                raise AgentVariableNotFoundError(str(selector))
            resolved[name] = variable.value
        else:
            resolved[name] = payload.get("value", "")
    return resolved
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
    """Flatten a tool's ``settings`` section to ``{name: value}``.

    Returns an empty mapping when the section is absent or not shaped
    as named setting configs.
    """
    named = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
    if named is None:
        return {}
    flattened: JsonObject = {}
    for name, setting in named.items():
        flattened[name] = setting.get("value")
    return flattened
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
    """Validate *value* as a JSON object; return None when it is not one."""
    try:
        coerced = _JSON_OBJECT_ADAPTER.validate_python(value)
    except ValidationError:
        return None
    return coerced
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
    """Map a raw ``type`` field to a ToolProviderType, defaulting to BUILT_IN."""
    if isinstance(value, ToolProviderType):
        return value
    return ToolProviderType(value) if isinstance(value, str) else ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
    """Coerce *value* to a ``{str: JSON object}`` mapping.

    Returns None when *value* is not a dict, when any key is not a
    string, or when any item fails JSON-object validation — callers use
    None to mean "not shaped like named configs".
    """
    if not isinstance(value, dict):
        return None
    named: dict[str, JsonObject] = {}
    for name, raw in value.items():
        if not isinstance(name, str):
            return None
        obj = cls._coerce_json_object(raw)
        if obj is None:
            return None
        named[name] = obj
    return named

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, TypeAlias, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -32,6 +32,13 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
# JSON-like scalar leaves accepted by the special-value serializer.
SpecialValueScalar: TypeAlias = str | int | float | bool | None
# Input shape: scalars, File objects, and nested mappings/lists thereof.
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
# Output shape after serialization: Files become plain dicts, mappings become dicts.
SerializedSpecialValue: TypeAlias = (
    SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
# Minimal graph config produced for single-node test runs.
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder:
@staticmethod
@@ -276,10 +283,10 @@ class WorkflowEntry:
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: dict[str, Any],
node_data: Mapping[str, object],
node_width: int = 114,
node_height: int = 514,
) -> dict[str, Any]:
) -> SingleNodeGraphConfig:
"""
Create a minimal graph structure for testing a single node in isolation.
@@ -289,14 +296,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 514)
:return: graph dictionary with start node and target node
"""
node_config = {
node_config: dict[str, object] = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": node_data,
"data": dict(node_data),
}
start_node_config = {
start_node_config: dict[str, object] = {
"id": "start",
"width": node_width,
"height": node_height,
@@ -321,7 +328,12 @@ class WorkflowEntry:
@classmethod
def run_free_node(
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@@ -339,6 +351,8 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
@@ -369,7 +383,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -405,30 +419,34 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
@staticmethod
def _handle_special_values(value: Any):
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
if value is None:
return value
if isinstance(value, dict):
res = {}
if isinstance(value, Mapping):
res: dict[str, SerializedSpecialValue] = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list = []
res_list: list[SerializedSpecialValue] = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return dict(value.to_dict())
return value
@classmethod

View File

@@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return base64.b64encode(data).decode("utf-8")

View File

@@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason)

View File

@@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_executions = graph_execution.node_executions
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
@@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self.execution_id
yield event
yield event.model_copy(update={"id": self.execution_id})
else:
yield event
except Exception as e:

View File

@@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
it = iter(doc.element.body)
doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
part = next(it, None)
i = 0
while part is not None:

View File

@@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Literal, NotRequired
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
    """Structured-output configuration carried on LLM node data."""

    # JSON schema object describing the expected structured output.
    schema: Mapping[str, object]
    # Optional identifier for the schema.
    name: NotRequired[str]
    # Optional free-text description of the schema.
    description: NotRequired[str]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = Field(default_factory=dict)
completion_params: dict[str, object] = Field(default_factory=dict)
class ContextConfig(BaseModel):
@@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before")
@classmethod
def convert_none_configs(cls, v: Any):
def convert_none_configs(cls, v: object):
if v is None:
return VisionConfigOptions()
return v
@@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
def convert_none_jinja2_variables(cls, v: object):
if v is None:
return []
return v
@@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
structured_output: StructuredOutputConfig | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
@@ -90,7 +97,7 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
def convert_none_prompt_config(cls, v: object):
if v is None:
return PromptConfig()
return v

View File

@@ -9,6 +9,7 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from pydantic import TypeAdapter
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
@@ -74,6 +75,7 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
StructuredOutputConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -88,6 +90,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
class LLMNode(Node[LLMNodeData]):
@@ -354,7 +357,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output: StructuredOutputConfig | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -367,8 +370,10 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
structured_output=structured_output,
)
request_start_time = time.perf_counter()
@@ -920,6 +925,12 @@ class LLMNode(Node[LLMNodeData]):
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
structured_output = (
dict(invoke_result.structured_output)
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
else None
)
event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
@@ -928,7 +939,7 @@ class LLMNode(Node[LLMNodeData]):
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None),
structured_output=structured_output,
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
@@ -962,27 +973,18 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
structured_output: StructuredOutputConfig,
) -> dict[str, object]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
dict[str, object]: The structured output schema
"""
if not structured_output:
schema = structured_output.get("schema")
if not schema:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
return _JSON_OBJECT_ADAPTER.validate_python(schema)
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(

View File

@@ -1,7 +1,10 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from __future__ import annotations
from pydantic import AfterValidator, BaseModel, Field, field_validator
from enum import StrEnum
from typing import Annotated, Literal, TypeAlias
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
@@ -9,6 +12,14 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
# JSON-like values a loop variable may hold (scalars plus nested containers).
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
# Mapping of loop-variable name to its value.
LoopValueMapping: TypeAlias = dict[str, LoopValue]
# Variable-pool selector path, e.g. ["node_id", "field"] — TODO confirm segment meaning.
VariableSelector: TypeAlias = list[str]
# Pre-built adapters so validation does not rebuild schemas per call.
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
@@ -37,7 +48,29 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Any | list[str] | None = None
value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
return None
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _LOOP_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
    """Return the value as a validated selector; raise for non-variable inputs."""
    if self.value_type == "variable":
        return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
    raise ValueError(f"Expected variable loop input, got {self.value_type}")
def require_constant_value(self) -> LoopValue:
    """Return the value as a validated constant; raise for non-constant inputs."""
    if self.value_type == "constant":
        return _LOOP_VALUE_ADAPTER.validate_python(self.value)
    raise ValueError(f"Expected constant loop input, got {self.value_type}")
class LoopNodeData(BaseLoopNodeData):
@@ -46,14 +79,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
outputs: LoopValueMapping = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
return {}
return v
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
class LoopStartNodeData(BaseNodeData):
@@ -77,8 +110,8 @@ class LoopState(BaseLoopState):
Loop State.
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any = None
outputs: list[LoopValue] = Field(default_factory=list)
current_output: LoopValue | None = None
class MetaData(BaseLoopState.MetaData):
"""
@@ -87,7 +120,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any:
def get_last_output(self) -> LoopValue | None:
"""
Get last output.
"""
@@ -95,7 +128,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any:
def get_current_output(self) -> LoopValue | None:
"""
Get current output.
"""

View File

@@ -3,7 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
)
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
inputs: dict[str, object] = {"loop_count": loop_count}
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
loop_variable_selectors: dict[str, list[str]] = {}
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
),
}
for loop_variable in self.node_data.loop_variables:
@@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
single_loop_variable: dict[str, LoopValue] = {}
for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
@@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
node_id: str,
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
variable_mapping = {}
variable_mapping: dict[str, Sequence[str]] = {}
# Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
continue
# variable selector to variable mapping
@@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
selector = loop_variable.require_variable_selector()
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
@@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
@@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
loop_node_ids: set[str] = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
raw_nodes = graph_config.get("nodes")
if not isinstance(raw_nodes, list):
return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
@@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
"""Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal
from typing import Annotated, Literal
from pydantic import (
BaseModel,
@@ -6,6 +6,7 @@ from pydantic import (
Field,
field_validator,
)
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
def validate_name(cls, value: object) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
@@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
return element_type
class JsonSchemaArrayItems(TypedDict):
    """Item schema for an array-typed parameter."""

    # JSON-schema type name of the array elements.
    type: str
class ParameterJsonSchemaProperty(TypedDict, total=False):
    """Per-parameter property entry in the extractor's JSON schema."""

    # Human-readable description of the parameter.
    description: str
    # JSON-schema type name for the parameter.
    type: str
    # Element schema, present only for array-typed parameters.
    items: JsonSchemaArrayItems
    # Allowed values, present only when the parameter defines options.
    enum: list[str]
class ParameterJsonSchema(TypedDict):
    """Top-level JSON schema handed to the function-calling extractor."""

    # The extractor schema is always a JSON object at the top level.
    type: Literal["object"]
    # Property schemas keyed by parameter name.
    properties: dict[str, ParameterJsonSchemaProperty]
    # Names listed as required in the schema.
    required: list[str]
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
@@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def set_reasoning_mode(cls, v: object) -> str:
return str(v) if v else "function_call"
def get_parameter_json_schema(self):
def get_parameter_json_schema(self) -> ParameterJsonSchema:
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
@@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
parameter_schema["type"] = parameter.type.value
if parameter.options:
parameter_schema["enum"] = parameter.options

View File

@@ -5,6 +5,8 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from pydantic import TypeAdapter
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -63,6 +65,7 @@ from .prompts import (
)
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
@@ -70,7 +73,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
def extract_json(text):
def extract_json(text: str) -> str | None:
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
@@ -392,10 +395,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
# generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
)
return prompt_messages, [tool]
@@ -602,19 +610,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
"""
Transform result into standard format.
"""
transformed_result: dict[str, Any] = {}
transformed_result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
if isinstance(param_value, (bool, int, float, str)):
numeric_value: bool | int | float | str = param_value
transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
@@ -661,7 +671,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
"""
Extract complete json response.
"""
@@ -672,11 +682,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
"""
Extract json from tool call.
"""
@@ -690,16 +700,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData):
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
"""
Generate default result.
"""
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0

View File

@@ -1,12 +1,24 @@
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel, field_validator
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
# Scalar values permitted in a tool's static configuration mapping.
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
# Values a "constant" tool input may carry; None is filtered out upstream — TODO confirm against filter_none_tool_inputs.
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
# A variable reference expressed as a selector path (list of path segments).
VariableSelector: TypeAlias = list[str]

# Module-level TypeAdapters so validation does not rebuild schemas on every call.
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class ToolEntity(BaseModel):
provider_id: str
@@ -14,52 +26,41 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
tool_configurations: ToolConfigurations
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("type", mode="before")
@field_validator("value", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
@@ -69,7 +70,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before")
@classmethod
def filter_none_tool_inputs(cls, value):
def filter_none_tool_inputs(cls, value: object) -> object:
if not isinstance(value, dict):
return value
@@ -80,8 +81,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
}
@staticmethod
def _has_valid_value(tool_input):
def _has_valid_value(tool_input: object) -> bool:
"""Check if the value is valid"""
if isinstance(tool_input, dict):
return tool_input.get("value") is not None
return getattr(tool_input, "value", None) is not None
if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False

View File

@@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
@@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
variable_selector = input.require_variable_selector()
selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
case "constant":
pass

View File

@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import SegmentType, VariableBase
from dify_graph.variables import Segment, SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode
@@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return node execution state keyed by node id for resume support."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...
@@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for per-node execution state used during resume."""
execution_id: str | None
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""

View File

@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
from dify_graph.file.models import File
if TYPE_CHECKING:
pass
from dify_graph.variables.segments import Segment
class ArrayValidation(StrEnum):
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
@staticmethod
def get_zero_value(t: SegmentType):
def get_zero_value(t: SegmentType) -> Segment:
# Lazy import to avoid circular dependency
from factories import variable_factory

View File

@@ -1,3 +1,5 @@
from typing import Protocol, cast
from fastopenapi.routers import FlaskRouter
from flask_cors import CORS
@@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
DOCS_PREFIX = "/fastopenapi"
class SupportsIncludeRouter(Protocol):
    """Structural type for router objects that can mount a sub-router under a URL prefix."""

    def include_router(self, router: object, *, prefix: str = "") -> None: ...
def init_app(app: DifyApp) -> None:
docs_enabled = dify_config.SWAGGER_UI_ENABLED
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
@@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
_ = remote_files
_ = setup
router.include_router(console_router, prefix="/console/api")
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
CORS(
app,
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},

View File

@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
# Define the constant
SEGMENT_TO_VARIABLE_MAP = {
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
ArrayAnySegment: ArrayAnyVariable,
ArrayBooleanSegment: ArrayBooleanVariable,
ArrayFileSegment: ArrayFileVariable,
@@ -296,13 +296,11 @@ def segment_to_variable(
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
VariableBase,
variable_class(
id=id,
name=name,
description=description,
value=segment.value,
selector=list(selector),
),
return variable_class(
id=id,
name=name,
description=description,
value_type=segment.value_type,
value=segment.value,
selector=list(selector),
)

View File

@@ -32,6 +32,11 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _stream_with_request_context(response: object) -> Any:
    """Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
    # Single cast confines the loose typing to this helper; callers stay cast-free.
    return cast(Any, stream_with_context)(response)
def escape_like_pattern(pattern: str) -> str:
"""
Escape special characters in a string for safe use in SQL LIKE patterns.
@@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
return sha256(hash_text.encode()).hexdigest()
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
if isinstance(response, dict):
def compact_generate_response(
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
) -> Response:
if isinstance(response, Mapping):
return Response(
response=json.dumps(jsonable_encoder(response)),
status=200,
content_type="application/json; charset=utf-8",
)
else:
stream_response = response
def generate() -> Generator:
yield from response
def generate() -> Generator[str, None, None]:
yield from stream_response
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
def length_prefixed_response(
magic_number: int,
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
) -> Response:
"""
This function is used to return a response with a length prefix.
Magic number is a one byte number that indicates the type of the response.
@@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
if isinstance(response, dict):
if isinstance(response, Mapping):
return Response(
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
status=200,
@@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
mimetype="application/json",
)
def generate() -> Generator:
for chunk in response:
stream_response = response
def generate() -> Generator[bytes, None, None]:
for chunk in stream_response:
if isinstance(chunk, str):
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
else:
yield pack_response_with_length_prefix(chunk)
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
return Response(
_stream_with_request_context(generate()),
status=200,
mimetype="text/event-stream",
)
class TokenManager:

View File

@@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
@wraps(func)
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
pass
elif current_user is not None and not current_user.is_authenticated:
return current_app.ensure_sync(func)(*args, **kwargs)
user = _get_user()
if user is None or not user.is_authenticated:
return current_app.login_manager.unauthorized() # type: ignore
# we put csrf validation here for less conflicts
# TODO: maybe find a better place for it.
check_csrf_token(request, current_user.id)
check_csrf_token(request, user.id)
return current_app.ensure_sync(func)(*args, **kwargs)
return decorated_view

View File

@@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
import sys
from importlib import import_module
from typing import Any
def cached_import(module_path: str, class_name: str):
def cached_import(module_path: str, class_name: str) -> Any:
"""
Import a module and return the named attribute/class from it, with caching.
@@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
Returns:
The imported attribute/class
"""
if not (
(module := sys.modules.get(module_path))
and (spec := getattr(module, "__spec__", None))
and getattr(spec, "_initializing", False) is False
):
module = sys.modules.get(module_path)
spec = getattr(module, "__spec__", None) if module is not None else None
if module is None or getattr(spec, "_initializing", False):
module = import_module(module_path)
return getattr(module, class_name)
def import_string(dotted_path: str):
def import_string(dotted_path: str) -> Any:
"""
Import a dotted module path and return the attribute/class designated by
the last name in the path. Raise ImportError if the import failed.

View File

@@ -1,7 +1,48 @@
import sys
import urllib.parse
from dataclasses import dataclass
from typing import NotRequired
import httpx
from pydantic import TypeAdapter
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
# Loosely-typed JSON payloads returned by the OAuth provider endpoints.
JsonObject = dict[str, object]
JsonObjectList = list[JsonObject]

# Runtime validators that narrow untyped ``response.json()`` results.
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)


class AccessTokenResponse(TypedDict, total=False):
    # Token-endpoint payload; "access_token" can be absent, which callers treat as an error.
    access_token: str


class GitHubEmailRecord(TypedDict, total=False):
    # One entry of GitHub's /user/emails response.
    email: str
    primary: bool


class GitHubRawUserInfo(TypedDict):
    # Subset of GitHub's /user response consumed by this module.
    id: int | str
    login: str
    name: NotRequired[str]
    email: NotRequired[str]


class GoogleRawUserInfo(TypedDict):
    # Subset of Google's OpenID userinfo response consumed by this module.
    sub: str
    email: str


ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
@dataclass
@@ -11,26 +52,38 @@ class OAuthUserInfo:
email: str
def _json_object(response: httpx.Response) -> JsonObject:
    """Decode *response* as JSON and validate that it is a string-keyed object."""
    payload = response.json()
    return JSON_OBJECT_ADAPTER.validate_python(payload)
def _json_list(response: httpx.Response) -> JsonObjectList:
    """Decode *response* as JSON and validate that it is a list of string-keyed objects."""
    payload = response.json()
    return JSON_OBJECT_LIST_ADAPTER.validate_python(payload)
class OAuth:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
    """Capture the OAuth client settings used by the provider endpoint calls."""
    self.redirect_uri = redirect_uri
    self.client_secret = client_secret
    self.client_id = client_id
def get_authorization_url(self):
def get_authorization_url(self, invite_token: str | None = None) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
raise NotImplementedError()
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
raise NotImplementedError()
def get_user_info(self, token: str) -> OAuthUserInfo:
raw_info = self.get_raw_user_info(token)
return self._transform_user_info(raw_info)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
raise NotImplementedError()
@@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
_USER_INFO_URL = "https://api.github.com/user"
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"redirect_uri": self.redirect_uri,
@@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"token {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
user_info = response.json()
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
email_info = email_response.json()
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
primary_email = next((email for email in email_info if email.get("primary") is True), None)
return {**user_info, "email": primary_email.get("email", "")}
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
email = raw_info.get("email")
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
email = payload.get("email")
if not email:
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
class GoogleOAuth(OAuth):
@@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
_TOKEN_URL = "https://oauth2.googleapis.com/token"
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
def get_authorization_url(self, invite_token: str | None = None):
def get_authorization_url(self, invite_token: str | None = None) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
params["state"] = invite_token
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> str:
data = {
"client_id": self.client_id,
"client_secret": self.client_secret,
@@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
headers = {"Accept": "application/json"}
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
response_json = response.json()
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
access_token = response_json.get("access_token")
if not access_token:
@@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
return access_token
def get_raw_user_info(self, token: str):
def get_raw_user_info(self, token: str) -> JsonObject:
headers = {"Authorization": f"Bearer {token}"}
response = httpx.get(self._USER_INFO_URL, headers=headers)
response.raise_for_status()
return response.json()
return _json_object(response)
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])

View File

@@ -1,25 +1,57 @@
import sys
import urllib.parse
from typing import Any
from typing import Any, Literal
import httpx
from flask_login import current_user
from pydantic import TypeAdapter
from sqlalchemy import select
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.source import DataSourceOauthBinding
if sys.version_info >= (3, 12):
from typing import TypedDict
else:
from typing_extensions import TypedDict
class NotionPageSummary(TypedDict):
    # One page or database entry surfaced from the Notion search APIs.
    page_id: str
    page_name: str
    page_icon: dict[str, str] | None
    parent_id: str
    type: Literal["page", "database"]


class NotionSourceInfo(TypedDict):
    # Shape of the source_info JSON persisted on a DataSourceOauthBinding.
    workspace_name: str | None
    workspace_icon: str | None
    workspace_id: str | None
    pages: list[NotionPageSummary]
    total: int  # always len(pages); kept denormalized for readers of the stored JSON


# Validate payloads when writing to / reading from the JSON storage column.
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
class OAuthDataSource:
client_id: str
client_secret: str
redirect_uri: str
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
    """Store the data-source OAuth client settings for later endpoint calls."""
    self.redirect_uri = redirect_uri
    self.client_secret = client_secret
    self.client_id = client_id
def get_authorization_url(self):
def get_authorization_url(self) -> str:
raise NotImplementedError()
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
raise NotImplementedError()
@@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
def get_authorization_url(self):
def get_authorization_url(self) -> str:
params = {
"client_id": self.client_id,
"response_type": "code",
@@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
}
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
def get_access_token(self, code: str):
def get_access_token(self, code: str) -> None:
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
headers = {"Accept": "application/json"}
auth = (self.client_id, self.client_secret)
@@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
workspace_id = response_json.get("workspace_id")
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def save_internal_access_token(self, access_token: str):
def save_internal_access_token(self, access_token: str) -> None:
workspace_name = self.notion_workspace_name(access_token)
workspace_icon = None
workspace_id = current_user.current_tenant_id
# get all authorized pages
pages = self.get_authorized_pages(access_token)
source_info = {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}
source_info = self._build_source_info(
workspace_name=workspace_name,
workspace_icon=workspace_icon,
workspace_id=workspace_id,
pages=pages,
)
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
)
)
if data_source_binding:
data_source_binding.source_info = source_info
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
@@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
new_data_source_binding = DataSourceOauthBinding(
tenant_id=current_user.current_tenant_id,
access_token=access_token,
source_info=source_info,
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
provider="notion",
)
db.session.add(new_data_source_binding)
db.session.commit()
def sync_data_source(self, binding_id: str):
def sync_data_source(self, binding_id: str) -> None:
# save data source binding
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding).where(
@@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
if data_source_binding:
# get all authorized pages
pages = self.get_authorized_pages(data_source_binding.access_token)
source_info = data_source_binding.source_info
new_source_info = {
"workspace_name": source_info["workspace_name"],
"workspace_icon": source_info["workspace_icon"],
"workspace_id": source_info["workspace_id"],
"pages": pages,
"total": len(pages),
}
data_source_binding.source_info = new_source_info
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
new_source_info = self._build_source_info(
workspace_name=source_info["workspace_name"],
workspace_icon=source_info["workspace_icon"],
workspace_id=source_info["workspace_id"],
pages=pages,
)
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
data_source_binding.disabled = False
data_source_binding.updated_at = naive_utc_now()
db.session.commit()
else:
raise ValueError("Data source binding not found")
def get_authorized_pages(self, access_token: str):
pages = []
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
pages: list[NotionPageSummary] = []
page_results = self.notion_page_search(access_token)
database_results = self.notion_database_search(access_token)
# get page detail
@@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "page",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
# get database detail
for database_result in database_results:
page_id = database_result["id"]
@@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
"parent_id": parent_id,
"type": "database",
}
pages.append(page)
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
return pages
def notion_page_search(self, access_token: str):
results = []
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
return results
def notion_block_parent_page_id(self, access_token: str, block_id: str):
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
return self.notion_block_parent_page_id(access_token, parent[parent_type])
return parent[parent_type]
def notion_workspace_name(self, access_token: str):
def notion_workspace_name(self, access_token: str) -> str:
headers = {
"Authorization": f"Bearer {access_token}",
"Notion-Version": "2022-06-28",
@@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
return user_info["workspace_name"]
return "workspace"
def notion_database_search(self, access_token: str):
results = []
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
results: list[dict[str, Any]] = []
next_cursor = None
has_more = True
@@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
next_cursor = response_json.get("next_cursor", None)
return results
@staticmethod
def _build_source_info(
*,
workspace_name: str | None,
workspace_icon: str | None,
workspace_id: str | None,
pages: list[NotionPageSummary],
) -> NotionSourceInfo:
return {
"workspace_name": workspace_name,
"workspace_icon": workspace_icon,
"workspace_id": workspace_id,
"pages": pages,
"total": len(pages),
}

View File

@@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
from .model import Account
from .types import EnumText, LongText, StringUUID
TriggerJsonObject = dict[str, object]
TriggerCredentials = dict[str, str]
class WorkflowTriggerLogDict(TypedDict):
id: str
@@ -89,10 +92,14 @@ class TriggerSubscription(TypeBase):
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
)
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
parameters: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription parameters JSON"
)
properties: Mapped[TriggerJsonObject] = mapped_column(
sa.JSON, nullable=False, comment="Subscription properties JSON"
)
credentials: Mapped[dict[str, Any]] = mapped_column(
credentials: Mapped[TriggerCredentials] = mapped_column(
sa.JSON, nullable=False, comment="Subscription credentials JSON"
)
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
@@ -200,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
)
@property
def oauth_params(self) -> Mapping[str, Any]:
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
def oauth_params(self) -> Mapping[str, object]:
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
class WorkflowTriggerLog(TypeBase):

View File

@@ -19,7 +19,7 @@ from sqlalchemy import (
orm,
select,
)
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
from sqlalchemy.orm import Mapped, mapped_column
from typing_extensions import deprecated
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
from dify_graph.file.constants import maybe_file_object
from dify_graph.file.models import File
from dify_graph.variables import utils as variable_utils
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
from extensions.ext_storage import Storage
from factories.variable_factory import TypeMismatchError, build_segment_with_type
from libs.datetime_utils import naive_utc_now
@@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
logger = logging.getLogger(__name__)
SerializedWorkflowValue = dict[str, Any]
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
class WorkflowContentDict(TypedDict):
graph: Mapping[str, Any]
@@ -405,7 +408,7 @@ class Workflow(Base): # bug
def rag_pipeline_user_input_form(self) -> list:
# get user_input_form from start node
variables: list[Any] = self.rag_pipeline_variables
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
return variables
@@ -448,17 +451,13 @@ class Workflow(Base): # bug
def environment_variables(
self,
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
# TODO: find some way to init `self._environment_variables` when instance created.
if self._environment_variables is None:
self._environment_variables = "{}"
# Use workflow.tenant_id to avoid relying on request user in background threads
tenant_id = self.tenant_id
if not tenant_id:
return []
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
results = [
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
]
@@ -536,11 +535,7 @@ class Workflow(Base): # bug
@property
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
return results
@@ -552,19 +547,20 @@ class Workflow(Base): # bug
)
@property
def rag_pipeline_variables(self) -> list[dict]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._rag_pipeline_variables is None:
self._rag_pipeline_variables = "{}"
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
results = list(variables_dict.values())
return results
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
@rag_pipeline_variables.setter
def rag_pipeline_variables(self, values: list[dict]) -> None:
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
self._rag_pipeline_variables = json.dumps(
{item["variable"]: item for item in values},
{
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
for rag_pipeline_variable in (
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
for item in values
)
},
ensure_ascii=False,
)
@@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
__tablename__ = "workflow_node_executions"
@declared_attr.directive
@classmethod
def __table_args__(cls) -> Any:
return (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
# The first argument is the index name,
# which we leave as `None`` to allow auto-generation by the ORM.
None,
cls.tenant_id,
cls.workflow_id,
cls.node_id,
# MyPy may flag the following line because it doesn't recognize that
# the `declared_attr` decorator passes the receiving class as the first
# argument to this method, allowing us to reference class attributes.
cls.created_at.desc(),
),
)
__table_args__ = (
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
Index(
"workflow_node_execution_workflow_run_id_idx",
"workflow_run_id",
),
Index(
"workflow_node_execution_node_run_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_id",
),
Index(
"workflow_node_execution_id_idx",
"tenant_id",
"app_id",
"workflow_id",
"triggered_from",
"node_execution_id",
),
Index(
None,
"tenant_id",
"workflow_id",
"node_id",
sa.desc("created_at"),
),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
tenant_id: Mapped[str] = mapped_column(StringUUID)

View File

@@ -1,4 +1,3 @@
configs/middleware/cache/redis_pubsub_config.py
controllers/console/app/annotation.py
controllers/console/app/app.py
controllers/console/app/app_import.py
@@ -14,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
@@ -109,37 +93,6 @@ core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
dify_graph/variables/types.py
extensions/ext_fastopenapi.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
@@ -156,19 +109,7 @@ extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
factories/variable_factory.py
libs/external_api.py
libs/gmpy2_pkcs10aep_cipher.py
libs/helper.py
libs/login.py
libs/module_loading.py
libs/oauth.py
libs/oauth_data_source.py
models/trigger.py
models/workflow.py
repositories/sqlalchemy_api_workflow_node_execution_repository.py
repositories/sqlalchemy_api_workflow_run_repository.py
repositories/sqlalchemy_execution_extra_content_repository.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py

View File

@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
import json
from collections.abc import Sequence
from datetime import datetime
from typing import cast
from typing import Protocol, cast
from sqlalchemy import asc, delete, desc, func, select
from sqlalchemy.engine import CursorResult
@@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
)
class _WorkflowNodeExecutionSnapshotRow(Protocol):
id: str
node_execution_id: str | None
node_id: str
node_type: str
title: str
index: int
status: WorkflowNodeExecutionStatus
elapsed_time: float | None
created_at: datetime
finished_at: datetime | None
execution_metadata: str | None
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
"""
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
@@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
- Thread-safe database operations using session-per-request pattern
"""
_session_maker: sessionmaker[Session]
def __init__(self, session_maker: sessionmaker[Session]):
"""
Initialize the repository with a sessionmaker.
@@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
)
with self._session_maker() as session:
rows = session.execute(stmt).all()
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
return [self._row_to_snapshot(row) for row in rows]
@staticmethod
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
metadata: dict[str, object] = {}
execution_metadata = getattr(row, "execution_metadata", None)
if execution_metadata:

View File

@@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
from core.plugin.impl.plugin import PluginInstaller
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.provider import Provider, ProviderCredential
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
from models.provider_ids import GenericProviderID
from services.enterprise.plugin_manager_service import (
PluginManagerService,
@@ -534,6 +534,13 @@ class PluginService:
plugin_id = plugin.plugin_id
logger.info("Deleting credentials for plugin: %s", plugin_id)
session.execute(
delete(TenantPreferredModelProvider).where(
TenantPreferredModelProvider.tenant_id == tenant_id,
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
)
)
# Delete provider credentials that match this plugin
credential_ids = session.scalars(
select(ProviderCredential.id).where(

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from copy import deepcopy
from unittest.mock import MagicMock, patch
import pytest
@@ -33,8 +33,8 @@ def _make_graph_state():
],
)
def test_run_uses_single_node_execution_branch(
single_iteration_run: Any,
single_loop_run: Any,
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
) -> None:
app_config = MagicMock()
app_config.app_id = "app"
@@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [],
"logical_operator": "and",
},
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
}
],
"edges": [],
}
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
@@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value)
return original_validate_python(value)
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
@@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []

View File

@@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
configuration = _build_provider_configuration()
session = Mock()
provider_record = SimpleNamespace(is_valid=False)
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
with _patched_session(session):
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
@@ -743,6 +743,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
configuration.create_provider_credential({"api_key": "raw"}, "Main")
assert provider_record.is_valid is True
assert provider_record.credential_id == "existing-cred"
session.commit.assert_called_once()
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
configuration = _build_provider_configuration()
session = Mock()
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
with _patched_session(session):
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
configuration.create_provider_credential({"api_key": "raw"}, "Main")
assert provider_record.is_valid is True
assert provider_record.credential_id is not None
session.commit.assert_called_once()

View File

@@ -12,7 +12,7 @@ NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
# console or api domain.
# example: http://udify.app/api
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
# Dev-only Hono proxy targets. Set the api prefixes above to https://localhost:5001/... to start the proxy with HTTPS.
# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly.
HONO_PROXY_HOST=127.0.0.1
HONO_PROXY_PORT=5001
HONO_CONSOLE_API_PROXY_TARGET=

View File

@@ -328,7 +328,7 @@ describe('createWorkflowStreamHandlers', () => {
vi.clearAllMocks()
})
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
const setupHandlers = (overrides: { isPublicAPI?: boolean, isTimedOut?: () => boolean } = {}) => {
let completionRes = ''
let currentTaskId: string | null = null
let isStopping = false
@@ -359,6 +359,7 @@ describe('createWorkflowStreamHandlers', () => {
const handlers = createWorkflowStreamHandlers({
getCompletionRes: () => completionRes,
getWorkflowProcessData: () => workflowProcessData,
isPublicAPI: overrides.isPublicAPI ?? false,
isTimedOut: overrides.isTimedOut ?? (() => false),
markEnded,
notify,
@@ -391,7 +392,7 @@ describe('createWorkflowStreamHandlers', () => {
}
it('should process workflow success and paused events', () => {
const setup = setupHandlers()
const setup = setupHandlers({ isPublicAPI: true })
const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>
act(() => {
@@ -546,7 +547,11 @@ describe('createWorkflowStreamHandlers', () => {
resultText: 'Hello',
status: WorkflowRunningStatus.Succeeded,
}))
expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
expect(sseGetMock).toHaveBeenCalledWith(
'/workflow/run-1/events',
{},
expect.objectContaining({ isPublicAPI: true }),
)
expect(setup.messageId()).toBe('run-1')
expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
expect(setup.setRespondingFalse).toHaveBeenCalled()
@@ -647,6 +652,7 @@ describe('createWorkflowStreamHandlers', () => {
const handlers = createWorkflowStreamHandlers({
getCompletionRes: () => '',
getWorkflowProcessData: () => existingProcess,
isPublicAPI: false,
isTimedOut: () => false,
markEnded: vi.fn(),
notify: setup.notify,

View File

@@ -351,6 +351,7 @@ describe('useResultSender', () => {
await waitFor(() => {
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
getCompletionRes: harness.runState.getCompletionRes,
isPublicAPI: true,
resetRunState: harness.runState.resetRunState,
setWorkflowProcessData: harness.runState.setWorkflowProcessData,
}))
@@ -373,6 +374,30 @@ describe('useResultSender', () => {
expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
})
it('should configure workflow handlers for installed apps as non-public', async () => {
const harness = createRunStateHarness()
const { result } = renderSender({
appSourceType: AppSourceTypeEnum.installedApp,
isWorkflow: true,
runState: harness.runState,
})
await act(async () => {
expect(await result.current.handleSend()).toBe(true)
})
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
isPublicAPI: false,
}))
expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
{ inputs: { name: 'Alice' } },
expect.any(Object),
AppSourceTypeEnum.installedApp,
'app-1',
)
})
it('should stringify non-Error workflow failures', async () => {
const harness = createRunStateHarness()
sendWorkflowMessageMock.mockRejectedValue('workflow failed')

View File

@@ -1,11 +1,11 @@
import type { ResultInputValue } from '../result-request'
import type { ResultRunStateController } from './use-result-run-state'
import type { PromptConfig } from '@/models/debug'
import type { AppSourceType } from '@/service/share'
import type { VisionFile, VisionSettings } from '@/types/app'
import { useCallback, useEffect, useRef } from 'react'
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
import {
AppSourceType,
sendCompletionMessage,
sendWorkflowMessage,
} from '@/service/share'
@@ -117,6 +117,7 @@ export const useResultSender = ({
const otherOptions = createWorkflowStreamHandlers({
getCompletionRes: runState.getCompletionRes,
getWorkflowProcessData: runState.getWorkflowProcessData,
isPublicAPI: appSourceType === AppSourceType.webApp,
isTimedOut: () => isTimeout,
markEnded: () => {
isEnd = true

View File

@@ -13,6 +13,7 @@ type Translate = (key: string, options?: Record<string, unknown>) => string
type CreateWorkflowStreamHandlersParams = {
getCompletionRes: () => string
getWorkflowProcessData: () => WorkflowProcess | undefined
isPublicAPI: boolean
isTimedOut: () => boolean
markEnded: () => void
notify: Notify
@@ -255,6 +256,7 @@ const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['out
export const createWorkflowStreamHandlers = ({
getCompletionRes,
getWorkflowProcessData,
isPublicAPI,
isTimedOut,
markEnded,
notify,
@@ -287,6 +289,7 @@ export const createWorkflowStreamHandlers = ({
}
const otherOptions: IOtherOptions = {
isPublicAPI,
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
const workflowProcessData = getWorkflowProcessData()
if (workflowProcessData?.tracing.length) {
@@ -378,6 +381,7 @@ export const createWorkflowStreamHandlers = ({
},
onWorkflowPaused: ({ data }) => {
tempMessageId = data.workflow_run_id
// WebApp workflows must keep using the public API namespace after pause/resume.
void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
},

View File

@@ -210,7 +210,6 @@
"@types/sortablejs": "1.15.9",
"@typescript-eslint/parser": "8.57.0",
"@typescript/native-preview": "7.0.0-dev.20260312.1",
"@vitejs/plugin-basic-ssl": "2.2.0",
"@vitejs/plugin-react": "6.0.0",
"@vitejs/plugin-rsc": "0.5.21",
"@vitest/coverage-v8": "4.1.0",

View File

@@ -34,16 +34,7 @@ const toUpstreamCookieName = (cookieName: string) => {
return `__Host-${cookieName}`
}
const toLocalCookieName = (cookieName: string, options: LocalCookieRewriteOptions) => {
if (options.localSecure)
return cookieName
return cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
}
type LocalCookieRewriteOptions = {
localSecure: boolean
}
const toLocalCookieName = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
if (!cookieHeader)
@@ -64,10 +55,7 @@ export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
.join('; ')
}
const rewriteSetCookieValueForLocal = (
setCookieValue: string,
options: LocalCookieRewriteOptions,
) => {
const rewriteSetCookieValueForLocal = (setCookieValue: string) => {
const [rawCookiePair, ...rawAttributes] = setCookieValue.split(';')
const separatorIndex = rawCookiePair.indexOf('=')
@@ -80,11 +68,11 @@ const rewriteSetCookieValueForLocal = (
.map(attribute => attribute.trim())
.filter(attribute =>
!COOKIE_DOMAIN_PATTERN.test(attribute)
&& (options.localSecure || !COOKIE_SECURE_PATTERN.test(attribute))
&& (options.localSecure || !COOKIE_PARTITIONED_PATTERN.test(attribute)),
&& !COOKIE_SECURE_PATTERN.test(attribute)
&& !COOKIE_PARTITIONED_PATTERN.test(attribute),
)
.map((attribute) => {
if (!options.localSecure && SAME_SITE_NONE_PATTERN.test(attribute))
if (SAME_SITE_NONE_PATTERN.test(attribute))
return 'SameSite=Lax'
if (COOKIE_PATH_PATTERN.test(attribute))
@@ -93,13 +81,10 @@ const rewriteSetCookieValueForLocal = (
return attribute
})
return [`${toLocalCookieName(cookieName, options)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
return [`${toLocalCookieName(cookieName)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
}
export const rewriteSetCookieHeadersForLocal = (
setCookieHeaders: string | string[] | undefined,
options: LocalCookieRewriteOptions,
): string[] | undefined => {
export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | string[]): string[] | undefined => {
if (!setCookieHeaders)
return undefined
@@ -107,7 +92,7 @@ export const rewriteSetCookieHeadersForLocal = (
? setCookieHeaders
: [setCookieHeaders]
return normalizedHeaders.map(setCookieValue => rewriteSetCookieValueForLocal(setCookieValue, options))
return normalizedHeaders.map(rewriteSetCookieValueForLocal)
}
export { DEFAULT_PROXY_TARGET }

View File

@@ -1,21 +0,0 @@
export type DevProxyProtocolEnv = Partial<Record<
| 'NEXT_PUBLIC_API_PREFIX'
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
string
>>
const isHttpsUrl = (value?: string) => {
if (!value)
return false
try {
return new URL(value).protocol === 'https:'
}
catch {
return false
}
}
export const shouldUseHttpsForDevProxy = (env: DevProxyProtocolEnv = {}) => {
return isHttpsUrl(env.NEXT_PUBLIC_API_PREFIX) || isHttpsUrl(env.NEXT_PUBLIC_PUBLIC_API_PREFIX)
}

View File

@@ -1,5 +1,5 @@
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from './server'
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server'
describe('dev proxy server', () => {
beforeEach(() => {
@@ -19,21 +19,6 @@ describe('dev proxy server', () => {
expect(targets.publicApiTarget).toBe('https://public.example.com')
})
// Scenario: the local dev proxy should switch to https when api prefixes are configured with https.
it('should enable https for the local dev proxy when api prefixes use https', () => {
// Assert
expect(shouldUseHttpsForDevProxy({
NEXT_PUBLIC_API_PREFIX: 'https://localhost:5001/console/api',
})).toBe(true)
expect(shouldUseHttpsForDevProxy({
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'https://localhost:5001/api',
})).toBe(true)
expect(shouldUseHttpsForDevProxy({
NEXT_PUBLIC_API_PREFIX: 'http://localhost:5001/console/api',
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'http://localhost:5001/api',
})).toBe(false)
})
// Scenario: target paths should not be duplicated when the incoming route already includes them.
it('should preserve prefixed targets when building upstream URLs', () => {
// Act
@@ -47,7 +32,6 @@ describe('dev proxy server', () => {
it('should only allow local development origins', () => {
// Assert
expect(isAllowedDevOrigin('http://localhost:3000')).toBe(true)
expect(isAllowedDevOrigin('https://localhost:3000')).toBe(true)
expect(isAllowedDevOrigin('http://127.0.0.1:3000')).toBe(true)
expect(isAllowedDevOrigin('https://example.com')).toBe(false)
})
@@ -102,39 +86,6 @@ describe('dev proxy server', () => {
])
})
// Scenario: secure local proxy responses should keep secure cross-site cookie attributes intact.
it('should preserve secure cookie attributes when the local proxy is https', async () => {
// Arrange
const fetchImpl = vi.fn<typeof fetch>().mockResolvedValue(new Response('ok', {
status: 200,
headers: [
['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None; Partitioned'],
['set-cookie', '__Host-csrf_token=csrf; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None'],
],
}))
const app = createDevProxyApp({
consoleApiTarget: 'https://cloud.dify.ai',
publicApiTarget: 'https://public.dify.ai',
fetchImpl,
})
// Act
const response = await app.request('https://127.0.0.1:5001/console/api/apps?page=1', {
headers: {
Origin: 'https://localhost:3000',
Cookie: 'access_token=abc',
},
})
// Assert
expect(response.headers.getSetCookie()).toEqual([
'__Host-access_token=abc; Path=/; Secure; SameSite=None; Partitioned',
'__Host-csrf_token=csrf; Path=/; Secure; SameSite=None',
])
expect(response.headers.get('access-control-allow-origin')).toBe('https://localhost:3000')
expect(response.headers.get('access-control-allow-credentials')).toBe('true')
})
// Scenario: preflight requests should advertise allowed headers for credentialed cross-origin calls.
it('should answer CORS preflight requests', async () => {
// Arrange

View File

@@ -2,16 +2,10 @@ import type { Context, Hono } from 'hono'
import { Hono as HonoApp } from 'hono'
import { DEFAULT_PROXY_TARGET, rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies'
export { shouldUseHttpsForDevProxy } from './protocol'
type DevProxyEnv = Partial<Record<
| 'HONO_CONSOLE_API_PROXY_TARGET'
| 'HONO_PUBLIC_API_PROXY_TARGET',
string
> & Record<
| 'NEXT_PUBLIC_API_PREFIX'
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
string | undefined
>>
export type DevProxyTargets = {
@@ -99,15 +93,11 @@ const createProxyRequestHeaders = (request: Request, targetUrl: URL) => {
return headers
}
const createUpstreamResponseHeaders = (
response: Response,
requestOrigin: string | null | undefined,
localSecure: boolean,
) => {
const createUpstreamResponseHeaders = (response: Response, requestOrigin?: string | null) => {
const headers = new Headers(response.headers)
RESPONSE_HEADERS_TO_DROP.forEach(header => headers.delete(header))
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie(), { localSecure })
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie())
rewrittenSetCookies?.forEach((cookie) => {
headers.append('set-cookie', cookie)
})
@@ -136,11 +126,7 @@ const proxyRequest = async (
}
const upstreamResponse = await fetchImpl(targetUrl, requestInit)
const responseHeaders = createUpstreamResponseHeaders(
upstreamResponse,
context.req.header('origin'),
requestUrl.protocol === 'https:',
)
const responseHeaders = createUpstreamResponseHeaders(upstreamResponse, context.req.header('origin'))
return new Response(upstreamResponse.body, {
status: upstreamResponse.status,

13
web/pnpm-lock.yaml generated
View File

@@ -512,9 +512,6 @@ importers:
'@typescript/native-preview':
specifier: 7.0.0-dev.20260312.1
version: 7.0.0-dev.20260312.1
'@vitejs/plugin-basic-ssl':
specifier: 2.2.0
version: 2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
'@vitejs/plugin-react':
specifier: 6.0.0
version: 6.0.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
@@ -3606,12 +3603,6 @@ packages:
resolution: {integrity: sha512-hBcWIOppZV14bi+eAmCZj8Elj8hVSUZJTpf1lgGBhVD85pervzQ1poM/qYfFUlPraYSZYP+ASg6To5BwYmUSGQ==}
engines: {node: '>=16'}
'@vitejs/plugin-basic-ssl@2.2.0':
resolution: {integrity: sha512-nmyQ1HGRkfUxjsv3jw0+hMhEdZdrtkvMTdkzRUaRWfiO6PCWw2V2Pz3gldCq96Tn9S8htcgdTxw/gmbLLEbfYw==}
engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0}
peerDependencies:
vite: ^6.0.0 || ^7.0.0 || ^8.0.0
'@vitejs/plugin-react@5.2.0':
resolution: {integrity: sha512-YmKkfhOAi3wsB1PhJq5Scj3GXMn3WvtQ/JC0xoopuHoXSdmtdStOpFrYaT1kie2YgFBcIe64ROzMYRjCrYOdYw==}
engines: {node: ^20.19.0 || >=22.12.0}
@@ -11039,10 +11030,6 @@ snapshots:
'@resvg/resvg-wasm': 2.4.0
satori: 0.16.0
'@vitejs/plugin-basic-ssl@2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
dependencies:
vite: '@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)'
'@vitejs/plugin-react@5.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
dependencies:
'@babel/core': 7.29.0

View File

@@ -1,10 +1,8 @@
import { createSecureServer } from 'node:http2'
import path from 'node:path'
import { fileURLToPath } from 'node:url'
import { serve } from '@hono/node-server'
import { getCertificate } from '@vitejs/plugin-basic-ssl'
import { loadEnv } from 'vite'
import { createDevProxyApp, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from '../plugins/dev-proxy/server'
import { createDevProxyApp, resolveDevProxyTargets } from '../plugins/dev-proxy/server'
const projectRoot = path.resolve(path.dirname(fileURLToPath(import.meta.url)), '..')
const mode = process.env.MODE || process.env.NODE_ENV || 'development'
@@ -13,33 +11,11 @@ const env = loadEnv(mode, projectRoot, '')
const host = env.HONO_PROXY_HOST || '127.0.0.1'
const port = Number(env.HONO_PROXY_PORT || 5001)
const app = createDevProxyApp(resolveDevProxyTargets(env))
const useHttps = shouldUseHttpsForDevProxy(env)
if (useHttps) {
const certificate = await getCertificate(
path.join(projectRoot, 'node_modules/.vite/basic-ssl'),
'localhost',
Array.from(new Set(['localhost', '127.0.0.1', host])),
)
serve({
fetch: app.fetch,
hostname: host,
port,
})
serve({
fetch: app.fetch,
hostname: host,
port,
createServer: createSecureServer,
serverOptions: {
allowHTTP1: true,
cert: certificate,
key: certificate,
},
})
}
else {
serve({
fetch: app.fetch,
hostname: host,
port,
})
}
console.log(`[dev-hono-proxy] listening on ${useHttps ? 'https' : 'http'}://${host}:${port}`)
console.log(`[dev-hono-proxy] listening on http://${host}:${port}`)

View File

@@ -1,12 +1,9 @@
import path from 'node:path'
import { fileURLToPath } from 'node:url'
import basicSsl from '@vitejs/plugin-basic-ssl'
import react from '@vitejs/plugin-react'
import vinext from 'vinext'
import { loadEnv } from 'vite'
import Inspect from 'vite-plugin-inspect'
import { defineConfig } from 'vite-plus'
import { shouldUseHttpsForDevProxy } from './plugins/dev-proxy/protocol'
import { createCodeInspectorPlugin, createForceInspectorClientInjectionPlugin } from './plugins/vite/code-inspector'
import { customI18nHmrPlugin } from './plugins/vite/custom-i18n-hmr'
import { nextStaticImageTestPlugin } from './plugins/vite/next-static-image-test'
@@ -24,8 +21,6 @@ export default defineConfig(({ mode }) => {
const isTest = mode === 'test'
const isStorybook = process.env.STORYBOOK === 'true'
|| process.argv.some(arg => arg.toLowerCase().includes('storybook'))
const env = loadEnv(mode, projectRoot, '')
const useHttpsForDevServer = shouldUseHttpsForDevProxy(env)
const isAppComponentsCoverage = coverageScope === 'app-components'
const excludedComponentCoverageFiles = isAppComponentsCoverage
? collectComponentCoverageExcludedFiles(path.join(projectRoot, 'app/components'), { pathPrefix: 'app/components' })
@@ -62,7 +57,6 @@ export default defineConfig(({ mode }) => {
react(),
vinext({ react: false }),
customI18nHmrPlugin({ injectTarget: browserInitializerInjectTarget }),
...(useHttpsForDevServer ? [basicSsl()] : []),
// reactGrabOpenFilePlugin({
// injectTarget: browserInitializerInjectTarget,
// projectRoot,