Compare commits

...

2 Commits

Author SHA1 Message Date
-LAN-
64ef2ca103 Update api/core/variables/variables.py
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2026-01-09 16:26:07 +08:00
-LAN-
07eac2a982 refactor(variables): rename Variable to VariableBase and VariableUnion to Variable
Updated core variable definitions and propagated the base/union naming change through workflow runtime, services, and tests.
Tests not run (not requested).

Next steps:
1. `uv run --project api --dev dev/pytest/pytest_unit_tests.sh`
2. `make lint`
3. `make type-check`
2026-01-09 16:26:07 +08:00
22 changed files with 124 additions and 125 deletions

View File

@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
from core.db.session_factory import session_factory
from core.moderation.base import ModerationError
from core.moderation.input_moderation import InputModeration
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.enums import WorkflowType
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
from core.workflow.graph_engine.layers.base import GraphEngineLayer
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
system_variables=system_inputs,
user_inputs=inputs,
environment_variables=self._workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=conversation_variables,
)
@@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
trace_manager=app_generate_entity.trace_manager,
)
def _initialize_conversation_variables(self) -> list[VariableUnion]:
def _initialize_conversation_variables(self) -> list[Variable]:
"""
Initialize conversation variables for the current conversation.
@@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
conversation_variables = [var.to_variable() for var in existing_variables]
session.commit()
return cast(list[VariableUnion], conversation_variables)
return cast(list[Variable], conversation_variables)
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
"""

View File

@@ -1,6 +1,6 @@
import logging
from core.variables import Variable
from core.variables import VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
from core.workflow.enums import NodeType
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
continue
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
logger.warning(
"Conversation variable not found in variable pool. selector=%s",
selector,

View File

@@ -30,6 +30,7 @@ from .variables import (
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
__all__ = [
@@ -62,4 +63,5 @@ __all__ = [
"StringSegment",
"StringVariable",
"Variable",
"VariableBase",
]

View File

@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
# - `SegmentGroup`, which is not added to the variable pool.
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
# - `VariableBase` and its subclasses, which are handled by `Variable`.
SegmentUnion: TypeAlias = Annotated[
(
Annotated[NoneSegment, Tag(SegmentType.NONE)]

View File

@@ -27,7 +27,7 @@ from .segments import (
from .types import SegmentType
class Variable(Segment):
class VariableBase(Segment):
"""
A variable is a segment that has a name.
@@ -45,23 +45,23 @@ class Variable(Segment):
selector: Sequence[str] = Field(default_factory=list)
class StringVariable(StringSegment, Variable):
class StringVariable(StringSegment, VariableBase):
pass
class FloatVariable(FloatSegment, Variable):
class FloatVariable(FloatSegment, VariableBase):
pass
class IntegerVariable(IntegerSegment, Variable):
class IntegerVariable(IntegerSegment, VariableBase):
pass
class ObjectVariable(ObjectSegment, Variable):
class ObjectVariable(ObjectSegment, VariableBase):
pass
class ArrayVariable(ArraySegment, Variable):
class ArrayVariable(ArraySegment, VariableBase):
pass
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
return encrypter.obfuscated_token(self.value)
class NoneVariable(NoneSegment, Variable):
class NoneVariable(NoneSegment, VariableBase):
value_type: SegmentType = SegmentType.NONE
value: None = None
class FileVariable(FileSegment, Variable):
class FileVariable(FileSegment, VariableBase):
pass
class BooleanVariable(BooleanSegment, Variable):
class BooleanVariable(BooleanSegment, VariableBase):
pass
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
value: Any
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
# Use `Variable` for type hinting when serialization is not required.
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
# Use `VariableBase` for type hinting when serialization is not required.
#
# Note:
# - All variants in `VariableUnion` must inherit from the `Variable` class.
# - The union must include all non-abstract subclasses of `Segment`, except:
VariableUnion: TypeAlias = Annotated[
# - All variants in `Variable` must inherit from the `VariableBase` class.
# - The union must include all non-abstract subclasses of `VariableBase`.
Variable: TypeAlias = Annotated[
(
Annotated[NoneVariable, Tag(SegmentType.NONE)]
| Annotated[StringVariable, Tag(SegmentType.STRING)]

View File

@@ -1,7 +1,7 @@
import abc
from typing import Protocol
from core.variables import Variable
from core.variables import VariableBase
class ConversationVariableUpdater(Protocol):
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
"""
@abc.abstractmethod
def update(self, conversation_id: str, variable: "Variable"):
def update(self, conversation_id: str, variable: "VariableBase"):
"""
Updates the value of the specified conversation variable in the underlying storage.
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
:param variable: The `Variable` instance containing the updated value.
:param variable: The `VariableBase` instance containing the updated value.
"""
pass

View File

@@ -11,7 +11,7 @@ from typing import Any
from pydantic import BaseModel, Field
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
class CommandType(StrEnum):
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
class VariableUpdate(BaseModel):
"""Represents a single variable update instruction."""
value: VariableUnion = Field(description="New variable value")
value: Variable = Field(description="New variable value")
class UpdateVariablesCommand(GraphEngineCommand):

View File

@@ -11,7 +11,7 @@ from typing_extensions import TypeIs
from core.model_runtime.entities.llm_entities import LLMUsage
from core.variables import IntegerVariable, NoneSegment
from core.variables.segments import ArrayAnySegment, ArraySegment
from core.variables.variables import VariableUnion
from core.variables.variables import Variable
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import (
NodeExecutionType,
@@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
datetime,
list[GraphNodeEventBase],
object | None,
dict[str, VariableUnion],
dict[str, Variable],
LLMUsage,
]
],
@@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
item: object,
flask_app: Flask,
context_vars: contextvars.Context,
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
"""Execute a single iteration in parallel mode and return results."""
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
@@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
return variable_mapping
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
parent_pool = self.graph_runtime_state.variable_pool
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})

View File

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.entities import GraphInitParams
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -64,7 +64,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
assigned_variable_selector = self.node_data.assigned_variable_selector
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
if not isinstance(original_variable, Variable):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
match self.node_data.write_mode:

View File

@@ -2,7 +2,7 @@ import json
from collections.abc import Mapping, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any
from core.variables import SegmentType, Variable
from core.variables import SegmentType, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
# ==================== Validation Part
# Check if variable exists
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=item.variable_selector)
# Check if operation is supported
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
for selector in updated_variable_selectors:
variable = self.graph_runtime_state.variable_pool.get(selector)
if not isinstance(variable, Variable):
if not isinstance(variable, VariableBase):
raise VariableNotFoundError(variable_selector=selector)
process_data[variable.name] = variable.value
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
def _handle_item(
self,
*,
variable: Variable,
variable: VariableBase,
operation: Operation,
value: Any,
):

View File

@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
from pydantic import BaseModel, Field
from core.file import File, FileAttribute, file_manager
from core.variables import Segment, SegmentGroup, Variable
from core.variables import Segment, SegmentGroup, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import FileSegment, ObjectSegment
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
from core.variables.variables import RAGPipelineVariableInput, Variable
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
ENVIRONMENT_VARIABLE_NODE_ID,
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
# The first element of the selector is the node id, it's the first-level key in the dictionary.
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
# elements of the selector except the first one.
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
description="Variables mapping",
default=defaultdict(dict),
)
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
description="System variables",
default_factory=SystemVariable.empty,
)
environment_variables: Sequence[VariableUnion] = Field(
environment_variables: Sequence[Variable] = Field(
description="Environment variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
conversation_variables: Sequence[VariableUnion] = Field(
conversation_variables: Sequence[Variable] = Field(
description="Conversation variables.",
default_factory=list[VariableUnion],
default_factory=list[Variable],
)
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
description="RAG pipeline variables.",
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
f"got {len(selector)} elements"
)
if isinstance(value, Variable):
if isinstance(value, VariableBase):
variable = value
elif isinstance(value, Segment):
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
node_id, name = self._selector_to_keys(selector)
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
self.variable_dictionary[node_id][name] = cast(Variable, variable)
@classmethod
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:

View File

@@ -2,7 +2,7 @@ import abc
from collections.abc import Mapping, Sequence
from typing import Any, Protocol
from core.variables import Variable
from core.variables import VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.runtime import VariablePool
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
"""
@abc.abstractmethod
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
"""Load variables based on the provided selectors. If the selectors are empty,
this method should return an empty list.
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
:param: selectors: a list of string list, each inner list should have at least two elements:
- the first element is the node ID,
- the second element is the variable name.
:return: a list of Variable objects that match the provided selectors.
:return: a list of VariableBase objects that match the provided selectors.
"""
pass
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
Serves as a placeholder when no variable loading is needed.
"""
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
return []

View File

@@ -38,7 +38,7 @@ from core.variables.variables import (
ObjectVariable,
SecretVariable,
StringVariable,
Variable,
VariableBase,
)
from core.workflow.constants import (
CONVERSATION_VARIABLE_NODE_ID,
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
}
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("name"):
raise VariableError("missing name")
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
if not mapping.get("variable"):
raise VariableError("missing variable")
return mapping["variable"]
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
"""
This factory function is used to create the environment variable or the conversation variable,
not support the File type.
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
if (value := mapping.get("value")) is None:
raise VariableError("missing value")
result: Variable
result: VariableBase
match value_type:
case SegmentType.STRING:
result = StringVariable.model_validate(mapping)
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
if not result.selector:
result = result.model_copy(update={"selector": selector})
return cast(Variable, result)
return cast(VariableBase, result)
def build_segment(value: Any, /) -> Segment:
@@ -285,8 +285,8 @@ def segment_to_variable(
id: str | None = None,
name: str | None = None,
description: str = "",
) -> Variable:
if isinstance(segment, Variable):
) -> VariableBase:
if isinstance(segment, VariableBase):
return segment
name = name or selector[-1]
id = id or str(uuid4())
@@ -297,7 +297,7 @@ def segment_to_variable(
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
return cast(
Variable,
VariableBase,
variable_class(
id=id,
name=name,

View File

@@ -1,7 +1,7 @@
from flask_restx import fields
from core.helper import encrypter
from core.variables import SecretVariable, SegmentType, Variable
from core.variables import SecretVariable, SegmentType, VariableBase
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
"value_type": value.value_type.value,
"description": value.description,
}
if isinstance(value, Variable):
if isinstance(value, VariableBase):
return {
"id": value.id,
"name": value.name,

View File

@@ -1,11 +1,9 @@
from __future__ import annotations
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from datetime import datetime
from enum import StrEnum
from typing import TYPE_CHECKING, Any, Union, cast
from typing import TYPE_CHECKING, Any, Optional, Union, cast
from uuid import uuid4
import sqlalchemy as sa
@@ -46,7 +44,7 @@ if TYPE_CHECKING:
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
from core.helper import encrypter
from core.variables import SecretVariable, Segment, SegmentType, Variable
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
from factories import variable_factory
from libs import helper
@@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
RAG_PIPELINE = "rag-pipeline"
@classmethod
def value_of(cls, value: str) -> WorkflowType:
def value_of(cls, value: str) -> "WorkflowType":
"""
Get value of given mode.
@@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
raise ValueError(f"invalid workflow type value {value}")
@classmethod
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
"""
Get workflow type from app mode.
@@ -178,12 +176,12 @@ class Workflow(Base): # bug
graph: str,
features: str,
created_by: str,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list[dict],
marked_name: str = "",
marked_comment: str = "",
) -> Workflow:
) -> "Workflow":
workflow = Workflow()
workflow.id = str(uuid4())
workflow.tenant_id = tenant_id
@@ -447,7 +445,7 @@ class Workflow(Base): # bug
# decrypt secret variables value
def decrypt_func(
var: Variable,
var: VariableBase,
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
@@ -463,7 +461,7 @@ class Workflow(Base): # bug
return decrypted_results
@environment_variables.setter
def environment_variables(self, value: Sequence[Variable]):
def environment_variables(self, value: Sequence[VariableBase]):
if not value:
self._environment_variables = "{}"
return
@@ -487,7 +485,7 @@ class Workflow(Base): # bug
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
# encrypt secret variables value
def encrypt_func(var: Variable) -> Variable:
def encrypt_func(var: VariableBase) -> VariableBase:
if isinstance(var, SecretVariable):
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
else:
@@ -517,7 +515,7 @@ class Workflow(Base): # bug
return result
@property
def conversation_variables(self) -> Sequence[Variable]:
def conversation_variables(self) -> Sequence[VariableBase]:
# TODO: find some way to init `self._conversation_variables` when instance created.
if self._conversation_variables is None:
self._conversation_variables = "{}"
@@ -527,7 +525,7 @@ class Workflow(Base): # bug
return results
@conversation_variables.setter
def conversation_variables(self, value: Sequence[Variable]):
def conversation_variables(self, value: Sequence[VariableBase]):
self._conversation_variables = json.dumps(
{var.name: var.model_dump() for var in value},
ensure_ascii=False,
@@ -621,7 +619,7 @@ class WorkflowRun(Base):
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
pause: Mapped[WorkflowPause | None] = orm.relationship(
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
"WorkflowPause",
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
uselist=False,
@@ -691,7 +689,7 @@ class WorkflowRun(Base):
}
@classmethod
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
return cls(
id=data.get("id"),
tenant_id=data.get("tenant_id"),
@@ -843,7 +841,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
created_by: Mapped[str] = mapped_column(StringUUID)
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
"WorkflowNodeExecutionOffload",
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
uselist=True,
@@ -853,13 +851,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
@staticmethod
def preload_offload_data(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
@staticmethod
def preload_offload_data_and_files(
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
):
return query.options(
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
@@ -934,7 +932,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
)
return extras
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
@property
@@ -1048,7 +1046,7 @@ class WorkflowNodeExecutionOffload(Base):
back_populates="offload_data",
)
file: Mapped[UploadFile | None] = orm.relationship(
file: Mapped[Optional["UploadFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1066,7 +1064,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
INSTALLED_APP = "installed-app"
@classmethod
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
"""
Get value of given mode.
@@ -1183,7 +1181,7 @@ class ConversationVariable(TypeBase):
)
@classmethod
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
obj = cls(
id=variable.id,
app_id=app_id,
@@ -1192,7 +1190,7 @@ class ConversationVariable(TypeBase):
)
return obj
def to_variable(self) -> Variable:
def to_variable(self) -> VariableBase:
mapping = json.loads(self.data)
return variable_factory.build_conversation_variable_from_mapping(mapping)
@@ -1336,7 +1334,7 @@ class WorkflowDraftVariable(Base):
)
# Relationship to WorkflowDraftVariableFile
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
foreign_keys=[file_id],
lazy="raise",
uselist=False,
@@ -1506,7 +1504,7 @@ class WorkflowDraftVariable(Base):
node_execution_id: str | None,
description: str = "",
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = WorkflowDraftVariable()
variable.id = str(uuid4())
variable.created_at = naive_utc_now()
@@ -1529,7 +1527,7 @@ class WorkflowDraftVariable(Base):
name: str,
value: Segment,
description: str = "",
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=CONVERSATION_VARIABLE_NODE_ID,
@@ -1550,7 +1548,7 @@ class WorkflowDraftVariable(Base):
value: Segment,
node_execution_id: str,
editable: bool = False,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=SYSTEM_VARIABLE_NODE_ID,
@@ -1573,7 +1571,7 @@ class WorkflowDraftVariable(Base):
visible: bool = True,
editable: bool = True,
file_id: str | None = None,
) -> WorkflowDraftVariable:
) -> "WorkflowDraftVariable":
variable = cls._new(
app_id=app_id,
node_id=node_id,
@@ -1669,7 +1667,7 @@ class WorkflowDraftVariableFile(Base):
)
# Relationship to UploadFile
upload_file: Mapped[UploadFile] = orm.relationship(
upload_file: Mapped["UploadFile"] = orm.relationship(
foreign_keys=[upload_file_id],
lazy="raise",
uselist=False,
@@ -1736,7 +1734,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
# Relationship to WorkflowRun
workflow_run: Mapped[WorkflowRun] = orm.relationship(
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
foreign_keys=[workflow_run_id],
# require explicit preloading.
lazy="raise",
@@ -1792,7 +1790,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
)
@classmethod
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
if isinstance(pause_reason, HumanInputRequired):
return cls(
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id

View File

@@ -1,7 +1,7 @@
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from models import ConversationVariable
@@ -13,7 +13,7 @@ class ConversationVariableUpdater:
def __init__(self, session_maker: sessionmaker[Session]) -> None:
self._session_maker: sessionmaker[Session] = session_maker
def update(self, conversation_id: str, variable: Variable) -> None:
def update(self, conversation_id: str, variable: VariableBase) -> None:
stmt = select(ConversationVariable).where(
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
)

View File

@@ -36,7 +36,7 @@ from core.rag.entities.event import (
)
from core.repositories.factory import DifyCoreRepositoryFactory
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from core.variables.variables import Variable
from core.variables.variables import VariableBase
from core.workflow.entities.workflow_node_execution import (
WorkflowNodeExecution,
WorkflowNodeExecutionStatus,
@@ -270,8 +270,8 @@ class RagPipelineService:
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
rag_pipeline_variables: list,
) -> Workflow:
"""

View File

@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
from configs import dify_config
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from core.variables import Segment, StringSegment, Variable
from core.variables import Segment, StringSegment, VariableBase
from core.variables.consts import SELECTORS_LENGTH
from core.variables.segments import (
ArrayFileSegment,
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
# Application ID for which variables are being loaded.
_app_id: str
_tenant_id: str
_fallback_variables: Sequence[Variable]
_fallback_variables: Sequence[VariableBase]
def __init__(
self,
engine: Engine,
app_id: str,
tenant_id: str,
fallback_variables: Sequence[Variable] | None = None,
fallback_variables: Sequence[VariableBase] | None = None,
):
self._engine = engine
self._app_id = app_id
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
return (selector[0], selector[1])
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
if not selectors:
return []
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
variable_by_selector: dict[tuple[str, str], Variable] = {}
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
with Session(bind=self._engine, expire_on_commit=False) as session:
srv = WorkflowDraftVariableService(session)
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
return list(variable_by_selector.values())
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
# and must remain synchronized with it.
# Ideally, these should be co-located for better maintainability.

View File

@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
from core.file import File
from core.repositories import DifyCoreRepositoryFactory
from core.variables import Variable
from core.variables.variables import VariableUnion
from core.variables import VariableBase
from core.variables.variables import Variable
from core.workflow.entities import WorkflowNodeExecution
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.errors import WorkflowNodeRunFailedError
@@ -198,8 +198,8 @@ class WorkflowService:
features: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[Variable],
conversation_variables: Sequence[Variable],
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
) -> Workflow:
"""
Sync draft workflow
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
workflow: Workflow,
node_type: NodeType,
conversation_id: str,
conversation_variables: list[Variable],
conversation_variables: list[VariableBase],
):
# Only inject system variables for START node type.
if node_type == NodeType.START or node_type.is_trigger_node:
@@ -1070,9 +1070,9 @@ def _setup_variable_pool(
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=workflow.environment_variables,
# Based on the definition of `VariableUnion`,
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
conversation_variables=cast(list[VariableUnion], conversation_variables), #
# Based on the definition of `Variable`,
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
conversation_variables=cast(list[Variable], conversation_variables), #
)
return variable_pool

View File

@@ -35,7 +35,6 @@ from core.variables.variables import (
SecretVariable,
StringVariable,
Variable,
VariableUnion,
)
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
@@ -96,7 +95,7 @@ class _Segments(BaseModel):
class _Variables(BaseModel):
variables: list[VariableUnion]
variables: list[Variable]
def create_test_file(
@@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
# Create one instance of each variable type
test_file = create_test_file()
all_variables: list[VariableUnion] = [
all_variables: list[Variable] = [
NoneVariable(name="none_var"),
StringVariable(value="test string", name="string_var"),
IntegerVariable(value=42, name="int_var"),

View File

@@ -11,7 +11,7 @@ from core.variables import (
SegmentType,
StringVariable,
)
from core.variables.variables import Variable
from core.variables.variables import VariableBase
def test_frozen_variables():
@@ -76,7 +76,7 @@ def test_object_variable_to_object():
def test_variable_to_object():
var: Variable = StringVariable(name="text", value="text")
var: VariableBase = StringVariable(name="text", value="text")
assert var.to_object() == "text"
var = IntegerVariable(name="integer", value=42)
assert var.to_object() == 42

View File

@@ -24,7 +24,7 @@ from core.variables.variables import (
IntegerVariable,
ObjectVariable,
StringVariable,
VariableUnion,
Variable,
)
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
from core.workflow.runtime import VariablePool
@@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
)
# Create environment variables with all types including ArrayFileVariable
env_vars: list[VariableUnion] = [
env_vars: list[Variable] = [
StringVariable(
id="env_string_id",
name="env_string",
@@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
]
# Create conversation variables with complex data
conv_vars: list[VariableUnion] = [
conv_vars: list[Variable] = [
StringVariable(
id="conv_string_id",
name="conv_string",