Compare commits

...

1 Commit

Author SHA1 Message Date
Byron Wang
e8668782d6 gemini 3 pro dify workflow-engine test 2025-11-19 16:17:47 +08:00
323 changed files with 27727 additions and 2752 deletions

View File

@@ -0,0 +1,30 @@
# Security Penetration Test Report
**Generated:** 2025-11-16 14:02:56 UTC
Executive Summary:
We conducted a thorough white-box security assessment of the API located in /workspace/api, focusing on authentication, authorization, business logic vulnerabilities, and IDOR in key endpoints such as /installed-apps.
Methodology:
- Full recursive file listing and static code analysis to identify HTTP routes and sensitive endpoint implementations.
- Focused static analysis on endpoints handling sensitive actions, authentication, and role-based authorization.
- Created specialized agents for authentication and business logic vulnerability testing.
- Dynamic testing attempted for IDOR and authorization bypass, limited by local API server unavailability.
- All findings documented with recommended next steps.
Findings:
- Discovered multiple /installed-apps endpoints with solid authentication and multi-layered authorization checks enforcing tenant and role ownership.
- No exploitable access control bypass or privilege escalation vulnerabilities confirmed.
- Dynamic vulnerability testing for IDOR was hampered by connection refusals, preventing full validation.
- Created a high-priority note recommending environment verification and retesting of dynamic attacks once the API server is accessible.
Recommendations:
- Verify and restore access to the local API server to enable full dynamic testing.
- Retry dynamic testing for IDOR and authorization-bypass attacks to confirm that no bypass is possible.
- Continue layered security reviews focusing on evolving business logic and role enforcement.
- Consider adding automated integration tests validating authorization policies (a sketch follows at the end of this report).
Conclusion:
The static analysis phase confirmed robust authentication and authorization controls in key sensitive endpoints; however, dynamic testing limitations prevent final validation. Once dynamic testing is possible, verify no IDOR or broken function-level authorization issues remain. This assessment provides a strong foundation for secure API usage and further iterative validation.
Severity: Medium (due to testing environment constraints limiting dynamic verification)
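
The following is a hedged sketch of the kind of authorization test recommended above; the endpoint path, port, and credential handling are illustrative assumptions, not details confirmed by this assessment.

```python
# Hedged sketch of a cross-tenant IDOR probe against /installed-apps.
# BASE_URL, the route shape, and the bearer-token scheme are assumptions.
import requests

BASE_URL = "http://localhost:5001"


def assert_cross_tenant_access_denied(tenant_a_token: str, tenant_b_app_id: str) -> None:
    """Tenant A's credentials must never read tenant B's installed app."""
    resp = requests.get(
        f"{BASE_URL}/console/api/installed-apps/{tenant_b_app_id}",
        headers={"Authorization": f"Bearer {tenant_a_token}"},
        timeout=10,
    )
    # Expect an authorization failure (or a masked 404), never a 200 with data.
    assert resp.status_code in (401, 403, 404)
```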

View File

@@ -161,7 +161,7 @@ WEB_API_CORS_ALLOW_ORIGINS=http://localhost:3000,*
CONSOLE_CORS_ALLOW_ORIGINS=http://localhost:3000,*
# Set COOKIE_DOMAIN when the console frontend and API are on different subdomains.
# Provide the registrable domain (e.g. example.com); leading dots are optional.
COOKIE_DOMAIN=
COOKIE_DOMAIN=localhost:5001
# Vector database configuration
# Supported values are `weaviate`, `qdrant`, `milvus`, `myscale`, `relyt`, `pgvector`, `pgvecto-rs`, `chroma`, `opensearch`, `oracle`, `tencent`, `elasticsearch`, `elasticsearch-ja`, `analyticdb`, `couchbase`, `vikingdb`, `oceanbase`, `opengauss`, `tablestore`,`vastbase`,`tidb`,`tidb_on_qdrant`,`baidu`,`lindorm`,`huawei_cloud`,`upstash`, `matrixone`.

11
api/bin/env Normal file
View File

@@ -0,0 +1,11 @@
#!/bin/sh
# add binaries to PATH if they aren't added yet
# affix colons on either side of $PATH to simplify matching
case ":${PATH}:" in
*:"$HOME/.local/bin":*)
;;
*)
# Prepending path in case a system-installed binary needs to be overridden
export PATH="$HOME/.local/bin:$PATH"
;;
esac

4
api/bin/env.fish Normal file
View File

@@ -0,0 +1,4 @@
if not contains "$HOME/.local/bin" $PATH
# Prepending path in case a system-installed binary needs to be overridden
set -x PATH "$HOME/.local/bin" $PATH
end

BIN
api/bin/uv Executable file

Binary file not shown.

BIN
api/bin/uvx Executable file

Binary file not shown.

View File

@@ -216,7 +216,6 @@ def setup_required(view: Callable[P, R]):
raise NotInitValidateError()
elif dify_config.EDITION == "SELF_HOSTED" and not db.session.query(DifySetup).first():
raise NotSetupError()
return view(*args, **kwargs)
return decorated

BIN
api/uv Executable file

Binary file not shown.

5488
api/uv.lock generated

File diff suppressed because it is too large

View File

@@ -0,0 +1,16 @@
# Dify Workflow Engine
A standalone SDK for executing Dify workflows.
## Installation
```bash
pip install -r requirements.txt
```
## Usage
```python
from core.workflow.workflow_entry import WorkflowEntry
# ...
```

View File

@@ -0,0 +1,7 @@
class DifyConfig:
WORKFLOW_CALL_MAX_DEPTH = 5
DEBUG = True
WORKFLOW_MAX_EXECUTION_STEPS = 100
WORKFLOW_MAX_EXECUTION_TIME = 600
dify_config = DifyConfig()

View File

@@ -0,0 +1,2 @@
class Agent:
pass

View File

@@ -0,0 +1,10 @@
from pydantic import BaseModel
class AgentEntity(BaseModel):
pass
class AgentNodeData(BaseModel):
agent_strategy_name: str
class AgentToolEntity(BaseModel):
pass

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class AgentPluginEntity(BaseModel):
pass
class AgentStrategyParameter(BaseModel):
pass

View File

@@ -0,0 +1,2 @@
class GenerateTaskStoppedError(Exception):
pass

View File

@@ -0,0 +1,6 @@
from enum import Enum
class InvokeFrom(Enum):
DEBUGGER = "debugger"
SERVICE_API = "service_api"
WEB_APP = "web_app"

View File

@@ -0,0 +1,2 @@
class CallbackHandler:
pass

View File

@@ -0,0 +1,2 @@
class DifyWorkflowCallbackHandler:
pass

View File

@@ -0,0 +1,2 @@
class Connection:
pass

View File

@@ -0,0 +1,2 @@
class DatasourceManager:
pass

View File

@@ -0,0 +1,4 @@
from pydantic import BaseModel
class DatasourceEntity(BaseModel):
pass

View File

@@ -0,0 +1,22 @@
from pydantic import BaseModel
class DatasourceEntity(BaseModel):
pass
class DatasourceType(BaseModel):
pass
class DatasourceMessage(BaseModel):
pass
class DatasourceParameter(BaseModel):
pass
class DatasourceProviderType(BaseModel):
pass
class GetOnlineDocumentPageContentRequest(BaseModel):
pass
class OnlineDriveDownloadFileRequest(BaseModel):
pass

View File

@@ -0,0 +1,2 @@
class OnlineDocumentDatasourcePlugin:
pass

View File

@@ -0,0 +1,2 @@
class OnlineDriveDatasourcePlugin:
pass

View File

@@ -0,0 +1,2 @@
class DatasourceFileMessageTransformer:
pass

View File

@@ -0,0 +1,6 @@
from .models import File, FileAttribute, FileTransferMethod, FileType
class FileManager:
pass
file_manager = FileManager()

View File

@@ -0,0 +1,13 @@
from enum import StrEnum
class FileType(StrEnum):
IMAGE = "image"
AUDIO = "audio"
VIDEO = "video"
DOCUMENT = "document"
CUSTOM = "custom"
class FileTransferMethod(StrEnum):
REMOTE_URL = "remote_url"
LOCAL_FILE = "local_file"
TOOL_FILE = "tool_file"

View File

@@ -0,0 +1,14 @@
from pydantic import BaseModel
class File(BaseModel):
def to_dict(self):
return {}
class FileAttribute(BaseModel):
pass
class FileTransferMethod(BaseModel):
pass
class FileType(BaseModel):
pass

View File

@@ -0,0 +1,2 @@
class SSKey:
pass

View File

@@ -0,0 +1,5 @@
class CodeExecutor:
pass
class CodeLanguage:
pass

View File

@@ -0,0 +1,14 @@
class CodeExecutor:
pass
class CodeLanguage:
PYTHON3 = "python3"
JAVASCRIPT = "javascript"
JSON = "json"
STRING = "string"
NUMBER = "number"
OBJECT = "object"
ARRAY = "array"
class CodeExecutionError(Exception):
pass

View File

@@ -0,0 +1,2 @@
class CodeNodeProvider:
pass

View File

@@ -0,0 +1,2 @@
class JavascriptCodeProvider:
pass

View File

@@ -0,0 +1,2 @@
class Python3CodeProvider:
pass

View File

@@ -0,0 +1,2 @@
class SSRFProxy:
pass

View File

@@ -0,0 +1,2 @@
class TokenBufferMemory:
pass

View File

@@ -0,0 +1,2 @@
class TokenBufferMemory:
pass

View File

@@ -0,0 +1,5 @@
class ModelManager:
pass
class ModelInstance:
pass

View File

@@ -0,0 +1,2 @@
class ModelPropertyKey:
pass

View File

@@ -0,0 +1,12 @@
from pydantic import BaseModel
class LLMResult(BaseModel):
pass
class LLMUsage(BaseModel):
@classmethod
def empty_usage(cls):
return cls()
class LLMUsageMetadata(BaseModel):
pass

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class ModelType(BaseModel):
pass
class AIModelEntity(BaseModel):
pass

View File

@@ -0,0 +1 @@
from .encoders import jsonable_encoder

View File

@@ -0,0 +1,2 @@
def jsonable_encoder(obj):
return obj

View File

@@ -0,0 +1,2 @@
class PluginManager:
pass

View File

@@ -0,0 +1,2 @@
class InvokeCredentials:
pass

View File

@@ -0,0 +1,5 @@
class PluginDaemonClientSideError(Exception):
pass
class PluginInvokeError(Exception):
pass

View File

@@ -0,0 +1,2 @@
class PluginInstaller:
pass

View File

@@ -0,0 +1,2 @@
class PromptTemplate:
pass

View File

@@ -0,0 +1,4 @@
from pydantic import BaseModel
class PromptEntity(BaseModel):
pass

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class AdvancedPromptEntity(BaseModel):
pass
class MemoryConfig(BaseModel):
pass

View File

@@ -0,0 +1,2 @@
class ProviderManager:
pass

View File

@@ -0,0 +1 @@
# Mock core.rag

View File

@@ -0,0 +1,4 @@
from pydantic import BaseModel
class RetrievalResource(BaseModel):
pass

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
class CitationMetadata(BaseModel):
pass
class RetrievalSourceMetadata(BaseModel):
pass

View File

@@ -0,0 +1,4 @@
from .entities import index_processor_entities
class IndexProcessorBase:
pass

View File

@@ -0,0 +1,2 @@
class IndexProcessorBase:
pass

View File

@@ -0,0 +1,2 @@
class IndexProcessorFactory:
pass

View File

@@ -0,0 +1 @@
from .retrieval_service import RetrievalService

View File

@@ -0,0 +1,4 @@
from pydantic import BaseModel
class DatasetRetrieval(BaseModel):
pass

View File

@@ -0,0 +1,7 @@
from pydantic import BaseModel
from typing import ClassVar
class RetrievalMethod(BaseModel):
SEMANTIC_SEARCH: ClassVar[str] = "SEMANTIC_SEARCH"
KEYWORD_SEARCH: ClassVar[str] = "KEYWORD_SEARCH"
HYBRID_SEARCH: ClassVar[str] = "HYBRID_SEARCH"

View File

@@ -0,0 +1,2 @@
class RetrievalService:
pass

View File

@@ -0,0 +1,2 @@
class Tool:
pass

View File

@@ -0,0 +1,2 @@
class ToolManager:
pass

View File

@@ -0,0 +1,4 @@
from pydantic import BaseModel
class ToolEntity(BaseModel):
pass

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
class ToolEntity(BaseModel):
pass
class ToolIdentity(BaseModel):
pass
class ToolInvokeMessage(BaseModel):
pass
class ToolParameter(BaseModel):
pass
class ToolProviderType(BaseModel):
pass
class ToolSelector(BaseModel):
pass

View File

@@ -0,0 +1,14 @@
class ToolProviderCredentialValidationError(Exception):
pass
class ToolNotFoundError(Exception):
pass
class ToolParameterValidationError(Exception):
pass
class ToolInvokeError(Exception):
pass
class ToolEngineInvokeError(Exception):
pass

View File

@@ -0,0 +1,2 @@
class ToolEngine:
pass

View File

@@ -0,0 +1,2 @@
class ToolManager:
pass

View File

@@ -0,0 +1,2 @@
class ToolUtils:
pass

View File

@@ -0,0 +1,5 @@
class ToolMessageTransformer:
pass
class ToolFileMessageTransformer:
pass

View File

@@ -0,0 +1,2 @@
class WorkflowAsTool:
pass

View File

@@ -0,0 +1,2 @@
class WorkflowTool:
pass

View File

@@ -0,0 +1,12 @@
from .variables import Variable, IntegerVariable, ArrayAnyVariable
from .segments import (
Segment,
ArrayFileSegment,
ArrayNumberSegment,
ArrayStringSegment,
NoneSegment,
FileSegment,
ArrayObjectSegment,
SegmentGroup
)
from .types import SegmentType, ArrayValidation

View File

@@ -0,0 +1 @@
SELECTORS_LENGTH = 10

View File

@@ -0,0 +1,40 @@
from pydantic import BaseModel
class Segment(BaseModel):
pass
class ArrayFileSegment(Segment):
pass
class ArrayNumberSegment(Segment):
pass
class ArrayStringSegment(Segment):
pass
class NoneSegment(Segment):
pass
class FileSegment(Segment):
pass
class ArrayObjectSegment(Segment):
pass
class ArrayBooleanSegment(Segment):
pass
class BooleanSegment(Segment):
pass
class ObjectSegment(Segment):
pass
class ArrayAnySegment(Segment):
pass
class StringSegment(Segment):
pass
class SegmentGroup(Segment):
pass

View File

@@ -0,0 +1,19 @@
from pydantic import BaseModel
from typing import ClassVar
class SegmentType(BaseModel):
STRING: ClassVar[str] = "string"
NUMBER: ClassVar[str] = "number"
OBJECT: ClassVar[str] = "object"
ARRAY_STRING: ClassVar[str] = "array[string]"
ARRAY_NUMBER: ClassVar[str] = "array[number]"
ARRAY_OBJECT: ClassVar[str] = "array[object]"
BOOLEAN: ClassVar[str] = "boolean"
ARRAY_BOOLEAN: ClassVar[str] = "array[boolean]"
SECRET: ClassVar[str] = "secret"
FILE: ClassVar[str] = "file"
ARRAY_FILE: ClassVar[str] = "array[file]"
GROUP: ClassVar[str] = "group"
class ArrayValidation(BaseModel):
pass

View File

@@ -0,0 +1,16 @@
from pydantic import BaseModel
class Variable(BaseModel):
pass
class IntegerVariable(Variable):
pass
class ArrayAnyVariable(Variable):
pass
class RAGPipelineVariableInput(BaseModel):
pass
class VariableUnion(BaseModel):
pass

View File

@@ -0,0 +1,132 @@
# Workflow
## Project Overview
This is the workflow graph engine module of Dify, implementing a queue-based distributed workflow execution system. The engine handles agentic AI workflows with support for parallel execution, node iteration, conditional logic, and external command control.
## Architecture
### Core Components
The graph engine follows a layered architecture with strict dependency rules:
1. **Graph Engine** (`graph_engine/`) - Orchestrates workflow execution
- **Manager** - External control interface for stop/pause/resume commands
- **Worker** - Node execution runtime
- **Command Processing** - Handles control commands (abort, pause, resume)
- **Event Management** - Event propagation and layer notifications
- **Graph Traversal** - Edge processing and skip propagation
- **Response Coordinator** - Path tracking and session management
- **Layers** - Pluggable middleware (debug logging, execution limits)
- **Command Channels** - Communication channels (InMemory, Redis)
1. **Graph** (`graph/`) - Graph structure and runtime state
- **Graph Template** - Workflow definition
- **Edge** - Node connections with conditions
- **Runtime State Protocol** - State management interface
1. **Nodes** (`nodes/`) - Node implementations
- **Base** - Abstract node classes and variable parsing
- **Specific Nodes** - LLM, Agent, Code, HTTP Request, Iteration, Loop, etc.
1. **Events** (`node_events/`) - Event system
- **Base** - Event protocols
- **Node Events** - Node lifecycle events
1. **Entities** (`entities/`) - Domain models
- **Variable Pool** - Variable storage
- **Graph Init Params** - Initialization configuration
## Key Design Patterns
### Command Channel Pattern
External workflow control via Redis or in-memory channels:
```python
# Send stop command to running workflow
channel = RedisChannel(redis_client, f"workflow:{task_id}:commands")
channel.send_command(AbortCommand(reason="User requested"))
```
### Layer System
Extensible middleware for cross-cutting concerns:
```python
engine = GraphEngine(graph)
engine.layer(DebugLoggingLayer(level="INFO"))
engine.layer(ExecutionLimitsLayer(max_nodes=100))
```
### Event-Driven Architecture
All node executions emit events for monitoring and integration:
- `NodeRunStartedEvent` - Node execution begins
- `NodeRunSucceededEvent` - Node completes successfully
- `NodeRunFailedEvent` - Node encounters error
- `GraphRunStartedEvent/GraphRunCompletedEvent` - Workflow lifecycle
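
A hedged sketch of dispatching on these events; the import path and the `node_id` attribute are assumptions based on the event names above, not verified against the engine's actual module layout.

```python
# Hedged sketch: a minimal observer that dispatches on the event types above.
from core.workflow.graph_events import (  # import path assumed
    GraphRunStartedEvent,
    NodeRunFailedEvent,
    NodeRunStartedEvent,
    NodeRunSucceededEvent,
)


def handle_event(event) -> None:
    match event:
        case GraphRunStartedEvent():
            print("graph run started")
        case NodeRunStartedEvent():
            print(f"node started: {event.node_id}")  # attribute name assumed
        case NodeRunSucceededEvent():
            print(f"node succeeded: {event.node_id}")
        case NodeRunFailedEvent():
            print(f"node failed: {event.node_id}")
```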
### Variable Pool
Centralized variable storage with namespace isolation:
```python
# Variables scoped by node_id
pool.add(["node1", "output"], value)
result = pool.get(["node1", "output"])
```
## Import Architecture Rules
The codebase enforces strict layering via import-linter:
1. **Workflow Layers** (top to bottom):
- graph_engine → graph_events → graph → nodes → node_events → entities
1. **Graph Engine Internal Layers**:
- orchestration → command_processing → event_management → graph_traversal → domain
1. **Domain Isolation**:
- Domain models cannot import from infrastructure layers
1. **Command Channel Independence**:
- InMemory and Redis channels must remain independent
## Common Tasks
### Adding a New Node Type
1. Create node class in `nodes/<node_type>/`
1. Inherit from `BaseNode` or appropriate base class
1. Implement `_run()` method
1. Register in `nodes/node_mapping.py`
1. Add tests in `tests/unit_tests/core/workflow/nodes/` (a minimal node sketch follows this list)
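
A minimal sketch of steps 1–3; the base class, `NodeRunResult`, and the variable-pool access pattern are assumptions based on this README, and the names are illustrative.

```python
# Hedged sketch of a new node type (import paths and result shape assumed).
from core.workflow.enums import WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node


class UppercaseNode(Node):
    """Illustrative node that upper-cases a single input variable."""

    def _run(self) -> NodeRunResult:
        # Read an input from the variable pool and emit a single output.
        text = self.graph_runtime_state.variable_pool.get(["start", "text"])
        return NodeRunResult(
            status=WorkflowNodeExecutionStatus.SUCCEEDED,
            outputs={"result": str(text).upper()},
        )
```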
### Implementing a Custom Layer
1. Create class inheriting from `Layer` base
1. Override lifecycle methods: `on_graph_start()`, `on_event()`, `on_graph_end()`
1. Add to engine via `engine.layer()` (see the sketch below)
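
A hedged sketch of such a layer; the base-class import path and exact lifecycle signatures are assumed from the steps above.

```python
# Hedged sketch of a custom layer measuring graph wall-clock time.
import time

from core.workflow.graph_engine.layers import Layer  # import path assumed


class TimingLayer(Layer):
    def on_graph_start(self) -> None:
        self._started = time.monotonic()

    def on_event(self, event) -> None:
        pass  # individual events are not needed for timing

    def on_graph_end(self, error=None) -> None:  # signature assumed
        elapsed = time.monotonic() - self._started
        print(f"graph finished in {elapsed:.2f}s (error={error})")
```

Register it the same way as the built-in layers: `engine.layer(TimingLayer())`.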
### Debugging Workflow Execution
Enable debug logging layer:
```python
debug_layer = DebugLoggingLayer(
level="DEBUG",
include_inputs=True,
include_outputs=True
)
```

View File

@@ -0,0 +1,4 @@
SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"

View File

@@ -0,0 +1,39 @@
import abc
from typing import Protocol
from core.variables import Variable
class ConversationVariableUpdater(Protocol):
"""
ConversationVariableUpdater defines an abstraction for updating conversation variable values.
It is intended for use by `v1.VariableAssignerNode` and `v2.VariableAssignerNode` when updating
conversation variables.
Implementations may choose to batch updates. If batching is used, the `flush` method
should be implemented to persist buffered changes, and `update`
should handle buffering accordingly.
Note: Since implementations may buffer updates, instances of ConversationVariableUpdater
are not thread-safe. Each VariableAssignerNode should create its own instance during execution.
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
"""
pass
@abc.abstractmethod
def flush(self):
"""
Flushes all pending updates to the underlying storage system.
If the implementation does not buffer updates, this method can be a no-op.
"""
pass
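
Below is a hedged in-memory sketch of the buffering behaviour this protocol describes; the backing dict and the assumption that `Variable` exposes an `id` are illustrative, not part of this module.

```python
# Hedged sketch: a buffering updater satisfying ConversationVariableUpdater.
class InMemoryConversationVariableUpdater:
    def __init__(self, storage: dict[tuple[str, str], Variable]):
        self._storage = storage
        self._buffer: list[tuple[str, Variable]] = []

    def update(self, conversation_id: str, variable: Variable):
        # Buffer the change; nothing is persisted until flush() is called.
        self._buffer.append((conversation_id, variable))

    def flush(self):
        # Persist buffered updates keyed by (conversation_id, variable.id).
        # Assumes the Variable model carries an `id` attribute.
        for conversation_id, variable in self._buffer:
            self._storage[(conversation_id, variable.id)] = variable
        self._buffer.clear()
```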

View File

@@ -0,0 +1,17 @@
from ..runtime.graph_runtime_state import GraphRuntimeState
from ..runtime.variable_pool import VariablePool
from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_pause import WorkflowPauseEntity
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"GraphRuntimeState",
"VariablePool",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowPauseEntity",
]

View File

@@ -0,0 +1,8 @@
from pydantic import BaseModel
class AgentNodeStrategyInit(BaseModel):
"""Agent node strategy initialization data."""
name: str
icon: str | None = None

View File

@@ -0,0 +1,20 @@
from collections.abc import Mapping
from typing import Any
from pydantic import BaseModel, Field
class GraphInitParams(BaseModel):
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")
workflow_id: str = Field(..., description="workflow id")
graph_config: Mapping[str, Any] = Field(..., description="graph config")
user_id: str = Field(..., description="user id")
user_from: str = Field(
..., description="user from, account or end-user"
) # Should be UserFrom enum: 'account' | 'end-user'
invoke_from: str = Field(
..., description="invoke from, service-api, web-app, explore or debugger"
) # Should be InvokeFrom enum: 'service-api' | 'web-app' | 'explore' | 'debugger'
call_depth: int = Field(..., description="call depth")

View File

@@ -0,0 +1,49 @@
from enum import StrEnum, auto
from typing import Annotated, Any, ClassVar, TypeAlias
from pydantic import BaseModel, Discriminator, Tag
class _PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
SCHEDULED_PAUSE = auto()
class _PauseReasonBase(BaseModel):
TYPE: ClassVar[_PauseReasonType]
class HumanInputRequired(_PauseReasonBase):
TYPE = _PauseReasonType.HUMAN_INPUT_REQUIRED
class SchedulingPause(_PauseReasonBase):
TYPE = _PauseReasonType.SCHEDULED_PAUSE
message: str
def _get_pause_reason_discriminator(v: Any) -> _PauseReasonType | None:
if isinstance(v, _PauseReasonBase):
return v.TYPE
elif isinstance(v, dict):
reason_type_str = v.get("TYPE")
if reason_type_str is None:
return None
try:
reason_type = _PauseReasonType(reason_type_str)
except ValueError:
return None
return reason_type
else:
# return None if the discriminator value isn't found
return None
PauseReason: TypeAlias = Annotated[
(
Annotated[HumanInputRequired, Tag(_PauseReasonType.HUMAN_INPUT_REQUIRED)]
| Annotated[SchedulingPause, Tag(_PauseReasonType.SCHEDULED_PAUSE)]
),
Discriminator(_get_pause_reason_discriminator),
]
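
Because `PauseReason` uses a callable discriminator, a serialized dict can be parsed back into the matching subclass; a short usage sketch using pydantic's `TypeAdapter`:

```python
# Usage sketch: round-tripping a pause reason through the discriminated union.
from pydantic import TypeAdapter

adapter = TypeAdapter(PauseReason)

reason = adapter.validate_python({"TYPE": "scheduled_pause", "message": "resume at 09:00"})
assert isinstance(reason, SchedulingPause)

reason = adapter.validate_python(HumanInputRequired())
assert isinstance(reason, HumanInputRequired)
```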

View File

@@ -0,0 +1,72 @@
"""
Domain entities for workflow execution.
Models are independent of the storage mechanism and don't contain
implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
from libs.datetime_utils import naive_utc_now
class WorkflowExecution(BaseModel):
"""
Domain model for workflow execution based on WorkflowRun but without
user, tenant, and app attributes.
"""
id_: str = Field(...)
workflow_id: str = Field(...)
workflow_version: str = Field(...)
workflow_type: WorkflowType = Field(...)
graph: Mapping[str, Any] = Field(...)
inputs: Mapping[str, Any] = Field(...)
outputs: Mapping[str, Any] | None = None
status: WorkflowExecutionStatus = WorkflowExecutionStatus.RUNNING
error_message: str = Field(default="")
total_tokens: int = Field(default=0)
total_steps: int = Field(default=0)
exceptions_count: int = Field(default=0)
started_at: datetime = Field(...)
finished_at: datetime | None = None
@property
def elapsed_time(self) -> float:
"""
Calculate elapsed time in seconds.
If workflow is not finished, use current time.
"""
end_time = self.finished_at or naive_utc_now()
return (end_time - self.started_at).total_seconds()
@classmethod
def new(
cls,
*,
id_: str,
workflow_id: str,
workflow_type: WorkflowType,
workflow_version: str,
graph: Mapping[str, Any],
inputs: Mapping[str, Any],
started_at: datetime,
) -> "WorkflowExecution":
return WorkflowExecution(
id_=id_,
workflow_id=workflow_id,
workflow_type=workflow_type,
workflow_version=workflow_version,
graph=graph,
inputs=inputs,
status=WorkflowExecutionStatus.RUNNING,
started_at=started_at,
)
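
A brief usage sketch of the factory and `elapsed_time` (values are illustrative; `WorkflowType` and `naive_utc_now` are the imports already used by this module):

```python
# Usage sketch with illustrative values.
execution = WorkflowExecution.new(
    id_="run-1",
    workflow_id="wf-1",
    workflow_type=WorkflowType.WORKFLOW,
    workflow_version="draft",
    graph={"nodes": [], "edges": []},
    inputs={"query": "hello"},
    started_at=naive_utc_now(),
)
assert execution.status == WorkflowExecutionStatus.RUNNING
# While unfinished, elapsed_time is measured against the current time.
print(f"{execution.elapsed_time:.3f}s elapsed")
```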

View File

@@ -0,0 +1,147 @@
"""
Domain entities for workflow node execution.
This module contains the domain model for workflow node execution, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from collections.abc import Mapping
from datetime import datetime
from typing import Any
from pydantic import BaseModel, Field, PrivateAttr
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
class WorkflowNodeExecution(BaseModel):
"""
Domain model for workflow node execution.
This model represents the core business entity of a node execution,
without implementation details like tenant_id, app_id, etc.
Note: User/context-specific fields (triggered_from, created_by, created_by_role)
have been moved to the repository implementation to keep the domain model clean.
These fields are still accepted in the constructor for backward compatibility,
but they are not stored in the model.
"""
# --------- Core identification fields ---------
# Unique identifier for this execution record, used when persisting to storage.
# Value is a UUID string (e.g., '09b3e04c-f9ae-404c-ad82-290b8d7bd382').
id: str
# Optional secondary ID for cross-referencing purposes.
#
# NOTE: For referencing the persisted record, use `id` rather than `node_execution_id`.
# While `node_execution_id` may sometimes be a UUID string, this is not guaranteed.
# In most scenarios, `id` should be used as the primary identifier.
node_execution_id: str | None = None
workflow_id: str # ID of the workflow this node belongs to
workflow_execution_id: str | None = None # ID of the specific workflow run (null for single-step debugging)
# --------- Core identification fields ends ---------
# Execution positioning and flow
index: int # Sequence number for ordering in trace visualization
predecessor_node_id: str | None = None # ID of the node that executed before this one
node_id: str # ID of the node being executed
node_type: NodeType # Type of node (e.g., start, llm, knowledge)
title: str # Display title of the node
# Execution data
# The `inputs` and `outputs` fields hold the full content
inputs: Mapping[str, Any] | None = None # Input variables used by this node
process_data: Mapping[str, Any] | None = None # Intermediate processing data
outputs: Mapping[str, Any] | None = None # Output variables produced by this node
# Execution state
status: WorkflowNodeExecutionStatus = WorkflowNodeExecutionStatus.RUNNING # Current execution status
error: str | None = None # Error message if execution failed
elapsed_time: float = Field(default=0.0) # Time taken for execution in seconds
# Additional metadata
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None # Execution metadata (tokens, cost, etc.)
# Timing information
created_at: datetime # When execution started
finished_at: datetime | None = None # When execution completed
_truncated_inputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_outputs: Mapping[str, Any] | None = PrivateAttr(None)
_truncated_process_data: Mapping[str, Any] | None = PrivateAttr(None)
def get_truncated_inputs(self) -> Mapping[str, Any] | None:
return self._truncated_inputs
def get_truncated_outputs(self) -> Mapping[str, Any] | None:
return self._truncated_outputs
def get_truncated_process_data(self) -> Mapping[str, Any] | None:
return self._truncated_process_data
def set_truncated_inputs(self, truncated_inputs: Mapping[str, Any] | None):
self._truncated_inputs = truncated_inputs
def set_truncated_outputs(self, truncated_outputs: Mapping[str, Any] | None):
self._truncated_outputs = truncated_outputs
def set_truncated_process_data(self, truncated_process_data: Mapping[str, Any] | None):
self._truncated_process_data = truncated_process_data
def get_response_inputs(self) -> Mapping[str, Any] | None:
inputs = self.get_truncated_inputs()
if inputs:
return inputs
return self.inputs
@property
def inputs_truncated(self):
return self._truncated_inputs is not None
@property
def outputs_truncated(self):
return self._truncated_outputs is not None
@property
def process_data_truncated(self):
return self._truncated_process_data is not None
def get_response_outputs(self) -> Mapping[str, Any] | None:
outputs = self.get_truncated_outputs()
if outputs is not None:
return outputs
return self.outputs
def get_response_process_data(self) -> Mapping[str, Any] | None:
process_data = self.get_truncated_process_data()
if process_data is not None:
return process_data
return self.process_data
def update_from_mapping(
self,
inputs: Mapping[str, Any] | None = None,
process_data: Mapping[str, Any] | None = None,
outputs: Mapping[str, Any] | None = None,
metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] | None = None,
):
"""
Update the model from mappings.
Args:
inputs: The inputs to update
process_data: The process data to update
outputs: The outputs to update
metadata: The metadata to update
"""
if inputs is not None:
self.inputs = dict(inputs)
if process_data is not None:
self.process_data = dict(process_data)
if outputs is not None:
self.outputs = dict(outputs)
if metadata is not None:
self.metadata = dict(metadata)
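
A usage sketch of the truncation helpers described above (values illustrative; `NodeType` comes from the enums import at the top of this module):

```python
# Usage sketch: full values stay on the model, truncated copies drive responses.
from datetime import datetime

node_exec = WorkflowNodeExecution(
    id="ne-1",
    workflow_id="wf-1",
    index=1,
    node_id="llm-1",
    node_type=NodeType.LLM,
    title="LLM",
    created_at=datetime.utcnow(),
)
node_exec.update_from_mapping(outputs={"text": "a" * 10_000})
node_exec.set_truncated_outputs({"text": "a" * 100 + "..."})

assert node_exec.outputs_truncated
assert node_exec.get_response_outputs() == {"text": "a" * 100 + "..."}
assert node_exec.outputs == {"text": "a" * 10_000}
```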

View File

@@ -0,0 +1,61 @@
"""
Domain entities for workflow pause management.
This module contains the domain model for workflow pause, which is used
by the core workflow module. These models are independent of the storage mechanism
and don't contain implementation details like tenant_id, app_id, etc.
"""
from abc import ABC, abstractmethod
from datetime import datetime
class WorkflowPauseEntity(ABC):
"""
Abstract base class for workflow pause entities.
This domain model represents a paused workflow execution state,
without implementation details like tenant_id, app_id, etc.
It provides the interface for managing workflow pause/resume operations
and state persistence through file storage.
The `WorkflowPauseEntity` is never reused. If a workflow execution pauses multiple times,
it will generate multiple `WorkflowPauseEntity` records.
"""
@property
@abstractmethod
def id(self) -> str:
"""The identifier of current WorkflowPauseEntity"""
pass
@property
@abstractmethod
def workflow_execution_id(self) -> str:
"""The identifier of the workflow execution record the pause associated with.
Correspond to `WorkflowExecution.id`.
"""
@abstractmethod
def get_state(self) -> bytes:
"""
Retrieve the serialized workflow state from storage.
This method should load and return the workflow execution state
that was saved when the workflow was paused. The state contains
all necessary information to resume the workflow execution.
Returns:
bytes: The serialized workflow state containing
execution context, variable values, node states, etc.
"""
...
@property
@abstractmethod
def resumed_at(self) -> datetime | None:
"""`resumed_at` return the resumption time of the current pause, or `None` if
the pause is not resumed yet.
"""
pass
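
A hedged in-memory sketch of a concrete pause entity, mainly useful for tests; real implementations persist the state through file storage as the docstring describes.

```python
# Hedged sketch: an in-memory WorkflowPauseEntity (not part of this module).
class InMemoryWorkflowPause(WorkflowPauseEntity):
    def __init__(self, id_: str, workflow_execution_id: str, state: bytes):
        self._id = id_
        self._workflow_execution_id = workflow_execution_id
        self._state = state
        self._resumed_at: datetime | None = None

    @property
    def id(self) -> str:
        return self._id

    @property
    def workflow_execution_id(self) -> str:
        return self._workflow_execution_id

    def get_state(self) -> bytes:
        # Serialized execution state captured when the workflow paused.
        return self._state

    @property
    def resumed_at(self) -> datetime | None:
        return self._resumed_at

    def mark_resumed(self) -> None:
        # Illustrative helper, not part of the abstract interface.
        self._resumed_at = datetime.now()
```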

View File

@@ -0,0 +1,262 @@
from enum import StrEnum
class NodeState(StrEnum):
"""State of a node or edge during workflow execution."""
UNKNOWN = "unknown"
TAKEN = "taken"
SKIPPED = "skipped"
class SystemVariableKey(StrEnum):
"""
System Variables.
"""
QUERY = "query"
FILES = "files"
CONVERSATION_ID = "conversation_id"
USER_ID = "user_id"
DIALOGUE_COUNT = "dialogue_count"
APP_ID = "app_id"
WORKFLOW_ID = "workflow_id"
WORKFLOW_EXECUTION_ID = "workflow_run_id"
TIMESTAMP = "timestamp"
# RAG Pipeline
DOCUMENT_ID = "document_id"
ORIGINAL_DOCUMENT_ID = "original_document_id"
BATCH = "batch"
DATASET_ID = "dataset_id"
DATASOURCE_TYPE = "datasource_type"
DATASOURCE_INFO = "datasource_info"
INVOKE_FROM = "invoke_from"
class NodeType(StrEnum):
START = "start"
END = "end"
ANSWER = "answer"
LLM = "llm"
KNOWLEDGE_RETRIEVAL = "knowledge-retrieval"
KNOWLEDGE_INDEX = "knowledge-index"
IF_ELSE = "if-else"
CODE = "code"
TEMPLATE_TRANSFORM = "template-transform"
QUESTION_CLASSIFIER = "question-classifier"
HTTP_REQUEST = "http-request"
TOOL = "tool"
DATASOURCE = "datasource"
VARIABLE_AGGREGATOR = "variable-aggregator"
LEGACY_VARIABLE_AGGREGATOR = "variable-assigner" # TODO: Merge this into VARIABLE_AGGREGATOR in the database.
LOOP = "loop"
LOOP_START = "loop-start"
LOOP_END = "loop-end"
ITERATION = "iteration"
ITERATION_START = "iteration-start" # Fake start node for iteration.
PARAMETER_EXTRACTOR = "parameter-extractor"
VARIABLE_ASSIGNER = "assigner"
DOCUMENT_EXTRACTOR = "document-extractor"
LIST_OPERATOR = "list-operator"
AGENT = "agent"
TRIGGER_WEBHOOK = "trigger-webhook"
TRIGGER_SCHEDULE = "trigger-schedule"
TRIGGER_PLUGIN = "trigger-plugin"
HUMAN_INPUT = "human-input"
@property
def is_trigger_node(self) -> bool:
"""Check if this node type is a trigger node."""
return self in [
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
@property
def is_start_node(self) -> bool:
"""Check if this node type can serve as a workflow entry point."""
return self in [
NodeType.START,
NodeType.DATASOURCE,
NodeType.TRIGGER_WEBHOOK,
NodeType.TRIGGER_SCHEDULE,
NodeType.TRIGGER_PLUGIN,
]
class NodeExecutionType(StrEnum):
"""Node execution type classification."""
EXECUTABLE = "executable" # Regular nodes that execute and produce outputs
RESPONSE = "response" # Response nodes that stream outputs (Answer, End)
BRANCH = "branch" # Nodes that can choose different branches (if-else, question-classifier)
CONTAINER = "container" # Container nodes that manage subgraphs (iteration, loop, graph)
ROOT = "root" # Nodes that can serve as execution entry points
class ErrorStrategy(StrEnum):
FAIL_BRANCH = "fail-branch"
DEFAULT_VALUE = "default-value"
class FailBranchSourceHandle(StrEnum):
FAILED = "fail-branch"
SUCCESS = "success-branch"
class WorkflowType(StrEnum):
"""
Workflow Type Enum for domain layer
"""
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
class WorkflowExecutionStatus(StrEnum):
# State diagram for the workflow status:
# (@) means start, (*) means end
#
# ┌------------------>------------------------->------------------->--------------┐
# | |
# | ┌-----------------------<--------------------┐ |
# ^ | | |
# | | ^ |
# | V | |
# ┌-----------┐ ┌-----------------------┐ ┌-----------┐ V
# | Scheduled |------->| Running |---------------------->| paused | |
# └-----------┘ └-----------------------┘ └-----------┘ |
# | | | | | | |
# | | | | | | |
# ^ | | | V V |
# | | | | | ┌---------┐ |
# (@) | | | └------------------------>| Stopped |<----┘
# | | | └---------┘
# | | | |
# | | V V
# | | ┌-----------┐ |
# | | | Succeeded |------------->--------------┤
# | | └-----------┘ |
# | V V
# | +--------┐ |
# | | Failed |---------------------->----------------┤
# | └--------┘ |
# V V
# ┌---------------------┐ |
# | Partially Succeeded |---------------------->-----------------┘--------> (*)
# └---------------------┘
#
# Mermaid diagram:
#
# ---
# title: State diagram for Workflow run state
# ---
# stateDiagram-v2
# scheduled: Scheduled
# running: Running
# succeeded: Succeeded
# failed: Failed
# partial_succeeded: Partial Succeeded
# paused: Paused
# stopped: Stopped
#
# [*] --> scheduled:
# scheduled --> running: Start Execution
# running --> paused: Human input required
# paused --> running: human input added
# paused --> stopped: User stops execution
# running --> succeeded: Execution finishes without any error
# running --> failed: Execution finishes with errors
# running --> stopped: User stops execution
# running --> partial_succeeded: some errors occurred and were handled during execution
#
# scheduled --> stopped: User stops execution
#
# succeeded --> [*]
# failed --> [*]
# partial_succeeded --> [*]
# stopped --> [*]
# `SCHEDULED` means that the workflow is scheduled to run, but has not
# started running yet (e.g., due to worker saturation).
#
# This enum value is currently unused.
SCHEDULED = "scheduled"
# `RUNNING` means the workflow is executing.
RUNNING = "running"
# `SUCCEEDED` means the execution of the workflow succeeded without any error.
SUCCEEDED = "succeeded"
# `FAILED` means the execution of the workflow failed with errors.
FAILED = "failed"
# `STOPPED` means the execution of the workflow was stopped, either manually
# by the user, or automatically by the Dify application (e.g., the moderation
# mechanism.)
STOPPED = "stopped"
# `PARTIAL_SUCCEEDED` indicates that some errors occurred during the workflow
# execution, but they were successfully handled (e.g., by using an error
# strategy such as "fail branch" or "default value").
PARTIAL_SUCCEEDED = "partial-succeeded"
# `PAUSED` indicates that the workflow execution is temporarily paused
# (e.g., awaiting human input) and is expected to resume later.
PAUSED = "paused"
def is_ended(self) -> bool:
return self in _END_STATE
_END_STATE = frozenset(
[
WorkflowExecutionStatus.SUCCEEDED,
WorkflowExecutionStatus.FAILED,
WorkflowExecutionStatus.PARTIAL_SUCCEEDED,
WorkflowExecutionStatus.STOPPED,
]
)
class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
Node Run Metadata Key.
"""
TOTAL_TOKENS = "total_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"
AGENT_LOG = "agent_log"
TRIGGER_INFO = "trigger_info"
ITERATION_ID = "iteration_id"
ITERATION_INDEX = "iteration_index"
LOOP_ID = "loop_id"
LOOP_INDEX = "loop_index"
PARALLEL_ID = "parallel_id"
PARALLEL_START_NODE_ID = "parallel_start_node_id"
PARENT_PARALLEL_ID = "parent_parallel_id"
PARENT_PARALLEL_START_NODE_ID = "parent_parallel_start_node_id"
PARALLEL_MODE_RUN_ID = "parallel_mode_run_id"
ITERATION_DURATION_MAP = "iteration_duration_map" # single iteration duration if iteration node runs
LOOP_DURATION_MAP = "loop_duration_map" # single loop duration if loop node runs
ERROR_STRATEGY = "error_strategy"  # nodes in continue-on-error mode return this field
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
DATASOURCE_INFO = "datasource_info"
class WorkflowNodeExecutionStatus(StrEnum):
PENDING = "pending" # Node is scheduled but not yet executing
RUNNING = "running"
SUCCEEDED = "succeeded"
FAILED = "failed"
EXCEPTION = "exception"
STOPPED = "stopped"
PAUSED = "paused"
# Legacy statuses - kept for backward compatibility
RETRY = "retry" # Legacy: replaced by retry mechanism in error handling

View File

@@ -0,0 +1,16 @@
from core.workflow.nodes.base.node import Node
class WorkflowNodeRunFailedError(Exception):
def __init__(self, node: Node, err_msg: str):
self._node = node
self._error = err_msg
super().__init__(f"Node {node.title} run failed: {err_msg}")
@property
def node(self) -> Node:
return self._node
@property
def error(self) -> str:
return self._error

View File

@@ -0,0 +1,11 @@
from .edge import Edge
from .graph import Graph, GraphBuilder, NodeFactory
from .graph_template import GraphTemplate
__all__ = [
"Edge",
"Graph",
"GraphBuilder",
"GraphTemplate",
"NodeFactory",
]

View File

@@ -0,0 +1,15 @@
import uuid
from dataclasses import dataclass, field
from core.workflow.enums import NodeState
@dataclass
class Edge:
"""Edge connecting two nodes in a workflow graph."""
id: str = field(default_factory=lambda: str(uuid.uuid4()))
tail: str = "" # tail node id (source)
head: str = "" # head node id (target)
source_handle: str = "source" # source handle for conditional branching
state: NodeState = field(default=NodeState.UNKNOWN) # edge execution state

View File

@@ -0,0 +1,465 @@
import logging
from collections import defaultdict
from collections.abc import Mapping, Sequence
from typing import Protocol, cast, final
from core.workflow.enums import ErrorStrategy, NodeExecutionType, NodeState, NodeType
from core.workflow.nodes.base.node import Node
from libs.typing import is_str, is_str_dict
from .edge import Edge
from .validation import get_graph_validator
logger = logging.getLogger(__name__)
class NodeFactory(Protocol):
"""
Protocol for creating Node instances from node data dictionaries.
This protocol decouples the Graph class from specific node mapping implementations,
allowing for different node creation strategies while maintaining type safety.
"""
def create_node(self, node_config: dict[str, object]) -> Node:
"""
Create a Node instance from node configuration data.
:param node_config: node configuration dictionary containing type and other data
:return: initialized Node instance
:raises ValueError: if node type is unknown or configuration is invalid
"""
...
@final
class Graph:
"""Graph representation with nodes and edges for workflow execution."""
def __init__(
self,
*,
nodes: dict[str, Node] | None = None,
edges: dict[str, Edge] | None = None,
in_edges: dict[str, list[str]] | None = None,
out_edges: dict[str, list[str]] | None = None,
root_node: Node,
):
"""
Initialize Graph instance.
:param nodes: graph nodes mapping (node id: node object)
:param edges: graph edges mapping (edge id: edge object)
:param in_edges: incoming edges mapping (node id: list of edge ids)
:param out_edges: outgoing edges mapping (node id: list of edge ids)
:param root_node: root node object
"""
self.nodes = nodes or {}
self.edges = edges or {}
self.in_edges = in_edges or {}
self.out_edges = out_edges or {}
self.root_node = root_node
@classmethod
def _parse_node_configs(cls, node_configs: list[dict[str, object]]) -> dict[str, dict[str, object]]:
"""
Parse node configurations and build a mapping of node IDs to configs.
:param node_configs: list of node configuration dictionaries
:return: mapping of node ID to node config
"""
node_configs_map: dict[str, dict[str, object]] = {}
for node_config in node_configs:
node_id = node_config.get("id")
if not node_id or not isinstance(node_id, str):
continue
node_configs_map[node_id] = node_config
return node_configs_map
@classmethod
def _find_root_node_id(
cls,
node_configs_map: Mapping[str, Mapping[str, object]],
edge_configs: Sequence[Mapping[str, object]],
root_node_id: str | None = None,
) -> str:
"""
Find the root node ID if not specified.
:param node_configs_map: mapping of node ID to node config
:param edge_configs: list of edge configurations
:param root_node_id: explicitly specified root node ID
:return: determined root node ID
"""
if root_node_id:
if root_node_id not in node_configs_map:
raise ValueError(f"Root node id {root_node_id} not found in the graph")
return root_node_id
# Find nodes with no incoming edges
nodes_with_incoming: set[str] = set()
for edge_config in edge_configs:
target = edge_config.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidates = [nid for nid in node_configs_map if nid not in nodes_with_incoming]
# Prefer START node if available
start_node_id = None
for nid in root_candidates:
node_data = node_configs_map[nid].get("data")
if not is_str_dict(node_data):
continue
node_type = node_data.get("type")
if not isinstance(node_type, str):
continue
if NodeType(node_type).is_start_node:
start_node_id = nid
break
root_node_id = start_node_id or (root_candidates[0] if root_candidates else None)
if not root_node_id:
raise ValueError("Unable to determine root node ID")
return root_node_id
@classmethod
def _build_edges(
cls, edge_configs: list[dict[str, object]]
) -> tuple[dict[str, Edge], dict[str, list[str]], dict[str, list[str]]]:
"""
Build edge objects and mappings from edge configurations.
:param edge_configs: list of edge configurations
:return: tuple of (edges dict, in_edges dict, out_edges dict)
"""
edges: dict[str, Edge] = {}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
edge_counter = 0
for edge_config in edge_configs:
source = edge_config.get("source")
target = edge_config.get("target")
if not is_str(source) or not is_str(target):
continue
# Create edge
edge_id = f"edge_{edge_counter}"
edge_counter += 1
source_handle = edge_config.get("sourceHandle", "source")
if not is_str(source_handle):
continue
edge = Edge(
id=edge_id,
tail=source,
head=target,
source_handle=source_handle,
)
edges[edge_id] = edge
out_edges[source].append(edge_id)
in_edges[target].append(edge_id)
return edges, dict(in_edges), dict(out_edges)
@classmethod
def _create_node_instances(
cls,
node_configs_map: dict[str, dict[str, object]],
node_factory: "NodeFactory",
) -> dict[str, Node]:
"""
Create node instances from configurations using the node factory.
:param node_configs_map: mapping of node ID to node config
:param node_factory: factory for creating node instances
:return: mapping of node ID to node instance
"""
nodes: dict[str, Node] = {}
for node_id, node_config in node_configs_map.items():
try:
node_instance = node_factory.create_node(node_config)
except Exception:
logger.exception("Failed to create node instance for node_id %s", node_id)
raise
nodes[node_id] = node_instance
return nodes
@classmethod
def new(cls) -> "GraphBuilder":
"""Create a fluent builder for assembling a graph programmatically."""
return GraphBuilder(graph_cls=cls)
@classmethod
def _promote_fail_branch_nodes(cls, nodes: dict[str, Node]) -> None:
"""
Promote nodes configured with FAIL_BRANCH error strategy to branch execution type.
:param nodes: mapping of node ID to node instance
"""
for node in nodes.values():
if node.error_strategy == ErrorStrategy.FAIL_BRANCH:
node.execution_type = NodeExecutionType.BRANCH
@classmethod
def _mark_inactive_root_branches(
cls,
nodes: dict[str, Node],
edges: dict[str, Edge],
in_edges: dict[str, list[str]],
out_edges: dict[str, list[str]],
active_root_id: str,
) -> None:
"""
Mark nodes and edges from inactive root branches as skipped.
Algorithm:
1. Mark inactive root nodes as skipped
2. For skipped nodes, mark all their outgoing edges as skipped
3. For each edge marked as skipped, check its target node:
- If ALL incoming edges are skipped, mark the node as skipped
- Otherwise, leave the node state unchanged
:param nodes: mapping of node ID to node instance
:param edges: mapping of edge ID to edge instance
:param in_edges: mapping of node ID to incoming edge IDs
:param out_edges: mapping of node ID to outgoing edge IDs
:param active_root_id: ID of the active root node
"""
# Find all top-level root nodes (nodes with ROOT execution type and no incoming edges)
top_level_roots: list[str] = [
node.id for node in nodes.values() if node.execution_type == NodeExecutionType.ROOT
]
# If there's only one root or the active root is not a top-level root, no marking needed
if len(top_level_roots) <= 1 or active_root_id not in top_level_roots:
return
# Mark inactive root nodes as skipped
inactive_roots: list[str] = [root_id for root_id in top_level_roots if root_id != active_root_id]
for root_id in inactive_roots:
if root_id in nodes:
nodes[root_id].state = NodeState.SKIPPED
# Recursively mark downstream nodes and edges
def mark_downstream(node_id: str) -> None:
"""Recursively mark downstream nodes and edges as skipped."""
if nodes[node_id].state != NodeState.SKIPPED:
return
# If this node is skipped, mark all its outgoing edges as skipped
out_edge_ids = out_edges.get(node_id, [])
for edge_id in out_edge_ids:
edge = edges[edge_id]
edge.state = NodeState.SKIPPED
# Check the target node of this edge
target_node = nodes[edge.head]
in_edge_ids = in_edges.get(target_node.id, [])
in_edge_states = [edges[eid].state for eid in in_edge_ids]
# If all incoming edges are skipped, mark the node as skipped
if all(state == NodeState.SKIPPED for state in in_edge_states):
target_node.state = NodeState.SKIPPED
# Recursively process downstream nodes
mark_downstream(target_node.id)
# Process each inactive root and its downstream nodes
for root_id in inactive_roots:
mark_downstream(root_id)
@classmethod
def init(
cls,
*,
graph_config: Mapping[str, object],
node_factory: "NodeFactory",
root_node_id: str | None = None,
) -> "Graph":
"""
Initialize graph
:param graph_config: graph config containing nodes and edges
:param node_factory: factory for creating node instances from config data
:param root_node_id: root node id
:return: graph instance
"""
# Parse configs
edge_configs = graph_config.get("edges", [])
node_configs = graph_config.get("nodes", [])
edge_configs = cast(list[dict[str, object]], edge_configs)
node_configs = cast(list[dict[str, object]], node_configs)
if not node_configs:
raise ValueError("Graph must have at least one node")
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
# Parse node configurations
node_configs_map = cls._parse_node_configs(node_configs)
# Find root node
root_node_id = cls._find_root_node_id(node_configs_map, edge_configs, root_node_id)
# Build edges
edges, in_edges, out_edges = cls._build_edges(edge_configs)
# Create node instances
nodes = cls._create_node_instances(node_configs_map, node_factory)
# Promote fail-branch nodes to branch execution type at graph level
cls._promote_fail_branch_nodes(nodes)
# Get root node instance
root_node = nodes[root_node_id]
# Mark inactive root branches as skipped
cls._mark_inactive_root_branches(nodes, edges, in_edges, out_edges, root_node_id)
# Create and return the graph
graph = cls(
nodes=nodes,
edges=edges,
in_edges=in_edges,
out_edges=out_edges,
root_node=root_node,
)
# Validate the graph structure using built-in validators
get_graph_validator().validate(graph)
return graph
@property
def node_ids(self) -> list[str]:
"""
Get list of node IDs (compatibility property for existing code)
:return: list of node IDs
"""
return list(self.nodes.keys())
def get_outgoing_edges(self, node_id: str) -> list[Edge]:
"""
Get all outgoing edges from a node (V2 method)
:param node_id: node id
:return: list of outgoing edges
"""
edge_ids = self.out_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
def get_incoming_edges(self, node_id: str) -> list[Edge]:
"""
Get all incoming edges to a node (V2 method)
:param node_id: node id
:return: list of incoming edges
"""
edge_ids = self.in_edges.get(node_id, [])
return [self.edges[eid] for eid in edge_ids if eid in self.edges]
@final
class GraphBuilder:
"""Fluent helper for constructing simple graphs, primarily for tests."""
def __init__(self, *, graph_cls: type[Graph]):
self._graph_cls = graph_cls
self._nodes: list[Node] = []
self._nodes_by_id: dict[str, Node] = {}
self._edges: list[Edge] = []
self._edge_counter = 0
def add_root(self, node: Node) -> "GraphBuilder":
"""Register the root node. Must be called exactly once."""
if self._nodes:
raise ValueError("Root node has already been added")
self._register_node(node)
self._nodes.append(node)
return self
def add_node(
self,
node: Node,
*,
from_node_id: str | None = None,
source_handle: str = "source",
) -> "GraphBuilder":
"""Append a node and connect it from the specified predecessor."""
if not self._nodes:
raise ValueError("Root node must be added before adding other nodes")
predecessor_id = from_node_id or self._nodes[-1].id
if predecessor_id not in self._nodes_by_id:
raise ValueError(f"Predecessor node '{predecessor_id}' not found")
predecessor = self._nodes_by_id[predecessor_id]
self._register_node(node)
self._nodes.append(node)
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=predecessor.id, head=node.id, source_handle=source_handle)
self._edges.append(edge)
return self
def connect(self, *, tail: str, head: str, source_handle: str = "source") -> "GraphBuilder":
"""Connect two existing nodes without adding a new node."""
if tail not in self._nodes_by_id:
raise ValueError(f"Tail node '{tail}' not found")
if head not in self._nodes_by_id:
raise ValueError(f"Head node '{head}' not found")
edge_id = f"edge_{self._edge_counter}"
self._edge_counter += 1
edge = Edge(id=edge_id, tail=tail, head=head, source_handle=source_handle)
self._edges.append(edge)
return self
def build(self) -> Graph:
"""Materialize the graph instance from the accumulated nodes and edges."""
if not self._nodes:
raise ValueError("Cannot build an empty graph")
nodes = {node.id: node for node in self._nodes}
edges = {edge.id: edge for edge in self._edges}
in_edges: dict[str, list[str]] = defaultdict(list)
out_edges: dict[str, list[str]] = defaultdict(list)
for edge in self._edges:
out_edges[edge.tail].append(edge.id)
in_edges[edge.head].append(edge.id)
return self._graph_cls(
nodes=nodes,
edges=edges,
in_edges=dict(in_edges),
out_edges=dict(out_edges),
root_node=self._nodes[0],
)
def _register_node(self, node: Node) -> None:
if not node.id:
raise ValueError("Node must have a non-empty id")
if node.id in self._nodes_by_id:
raise ValueError(f"Duplicate node id detected: {node.id}")
self._nodes_by_id[node.id] = node
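
A hedged usage sketch of the builder; `start_node`, `llm_node`, and `answer_node` are assumed to be already-constructed `Node` instances (node construction is not shown in this file).

```python
# Hedged usage sketch of GraphBuilder (node instances assumed to exist).
graph = (
    Graph.new()
    .add_root(start_node)                         # node with id "start"
    .add_node(llm_node)                           # connected from the previous node
    .add_node(answer_node, from_node_id="start")  # branch directly off the root
    .build()
)

assert graph.root_node is start_node
assert len(graph.get_outgoing_edges("start")) == 2
```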

View File

@@ -0,0 +1,20 @@
from typing import Any
from pydantic import BaseModel, Field
class GraphTemplate(BaseModel):
"""
Graph Template for container nodes and subgraph expansion
According to GraphEngine V2 spec, GraphTemplate contains:
- nodes: mapping of node definitions
- edges: mapping of edge definitions
- root_ids: list of root node IDs
- output_selectors: list of output selectors for the template
"""
nodes: dict[str, dict[str, Any]] = Field(default_factory=dict, description="node definitions mapping")
edges: dict[str, dict[str, Any]] = Field(default_factory=dict, description="edge definitions mapping")
root_ids: list[str] = Field(default_factory=list, description="root node IDs")
output_selectors: list[str] = Field(default_factory=list, description="output selectors")

View File

@@ -0,0 +1,161 @@
from __future__ import annotations
from collections.abc import Sequence
from dataclasses import dataclass
from typing import TYPE_CHECKING, Protocol
from core.workflow.enums import NodeExecutionType, NodeType
if TYPE_CHECKING:
from .graph import Graph
@dataclass(frozen=True, slots=True)
class GraphValidationIssue:
"""Immutable value object describing a single validation issue."""
code: str
message: str
node_id: str | None = None
class GraphValidationError(ValueError):
"""Raised when graph validation fails."""
def __init__(self, issues: Sequence[GraphValidationIssue]) -> None:
if not issues:
raise ValueError("GraphValidationError requires at least one issue.")
self.issues: tuple[GraphValidationIssue, ...] = tuple(issues)
message = "; ".join(f"[{issue.code}] {issue.message}" for issue in self.issues)
super().__init__(message)
class GraphValidationRule(Protocol):
"""Protocol that individual validation rules must satisfy."""
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
"""Validate the provided graph and return any discovered issues."""
...
@dataclass(frozen=True, slots=True)
class _EdgeEndpointValidator:
"""Ensures all edges reference existing nodes."""
missing_node_code: str = "MISSING_NODE"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
issues: list[GraphValidationIssue] = []
for edge in graph.edges.values():
if edge.tail not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown source node '{edge.tail}'.",
node_id=edge.tail,
)
)
if edge.head not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.missing_node_code,
message=f"Edge {edge.id} references unknown target node '{edge.head}'.",
node_id=edge.head,
)
)
return issues
@dataclass(frozen=True, slots=True)
class _RootNodeValidator:
"""Validates root node invariants."""
invalid_root_code: str = "INVALID_ROOT"
container_entry_types: tuple[NodeType, ...] = (NodeType.ITERATION_START, NodeType.LOOP_START)
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
root_node = graph.root_node
issues: list[GraphValidationIssue] = []
if root_node.id not in graph.nodes:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' is missing from the node registry.",
node_id=root_node.id,
)
)
return issues
node_type = getattr(root_node, "node_type", None)
if root_node.execution_type != NodeExecutionType.ROOT and node_type not in self.container_entry_types:
issues.append(
GraphValidationIssue(
code=self.invalid_root_code,
message=f"Root node '{root_node.id}' must declare execution type 'root'.",
node_id=root_node.id,
)
)
return issues
@dataclass(frozen=True, slots=True)
class GraphValidator:
"""Coordinates execution of graph validation rules."""
rules: tuple[GraphValidationRule, ...]
def validate(self, graph: Graph) -> None:
"""Validate the graph against all configured rules."""
issues: list[GraphValidationIssue] = []
for rule in self.rules:
issues.extend(rule.validate(graph))
if issues:
raise GraphValidationError(issues)
@dataclass(frozen=True, slots=True)
class _TriggerStartExclusivityValidator:
"""Ensures trigger nodes do not coexist with UserInput (start) nodes."""
conflict_code: str = "TRIGGER_START_NODE_CONFLICT"
def validate(self, graph: Graph) -> Sequence[GraphValidationIssue]:
start_node_id: str | None = None
trigger_node_ids: list[str] = []
for node in graph.nodes.values():
node_type = getattr(node, "node_type", None)
if not isinstance(node_type, NodeType):
continue
if node_type == NodeType.START:
start_node_id = node.id
elif node_type.is_trigger_node:
trigger_node_ids.append(node.id)
if start_node_id and trigger_node_ids:
trigger_list = ", ".join(trigger_node_ids)
return [
GraphValidationIssue(
code=self.conflict_code,
message=(
f"UserInput (start) node '{start_node_id}' cannot coexist with trigger nodes: {trigger_list}."
),
node_id=start_node_id,
)
]
return []
_DEFAULT_RULES: tuple[GraphValidationRule, ...] = (
_EdgeEndpointValidator(),
_RootNodeValidator(),
_TriggerStartExclusivityValidator(),
)
def get_graph_validator() -> GraphValidator:
"""Construct the validator composed of default rules."""
return GraphValidator(_DEFAULT_RULES)

Some files were not shown because too many files have changed in this diff