mirror of
https://github.com/langgenius/dify.git
synced 2026-01-11 00:41:55 +00:00
Compare commits
2 Commits
detached
...
83b5-renam
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
64ef2ca103 | ||
|
|
07eac2a982 |
@@ -24,7 +24,7 @@ from core.app.layers.conversation_variable_persist_layer import ConversationVari
|
||||
from core.db.session_factory import session_factory
|
||||
from core.moderation.base import ModerationError
|
||||
from core.moderation.input_moderation import InputModeration
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.enums import WorkflowType
|
||||
from core.workflow.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -149,8 +149,8 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
|
||||
@@ -318,7 +318,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
trace_manager=app_generate_entity.trace_manager,
|
||||
)
|
||||
|
||||
def _initialize_conversation_variables(self) -> list[VariableUnion]:
|
||||
def _initialize_conversation_variables(self) -> list[Variable]:
|
||||
"""
|
||||
Initialize conversation variables for the current conversation.
|
||||
|
||||
@@ -343,7 +343,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
conversation_variables = [var.to_variable() for var in existing_variables]
|
||||
|
||||
session.commit()
|
||||
return cast(list[VariableUnion], conversation_variables)
|
||||
return cast(list[Variable], conversation_variables)
|
||||
|
||||
def _load_existing_conversation_variables(self, session: Session) -> list[ConversationVariable]:
|
||||
"""
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import logging
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.conversation_variable_updater import ConversationVariableUpdater
|
||||
from core.workflow.enums import NodeType
|
||||
@@ -44,7 +44,7 @@ class ConversationVariablePersistenceLayer(GraphEngineLayer):
|
||||
if selector[0] != CONVERSATION_VARIABLE_NODE_ID:
|
||||
continue
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
logger.warning(
|
||||
"Conversation variable not found in variable pool. selector=%s",
|
||||
selector,
|
||||
|
||||
@@ -30,6 +30,7 @@ from .variables import (
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableBase,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -62,4 +63,5 @@ __all__ = [
|
||||
"StringSegment",
|
||||
"StringVariable",
|
||||
"Variable",
|
||||
"VariableBase",
|
||||
]
|
||||
|
||||
@@ -232,7 +232,7 @@ def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
# - All variants in `SegmentUnion` must inherit from the `Segment` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
# - `SegmentGroup`, which is not added to the variable pool.
|
||||
# - `Variable` and its subclasses, which are handled by `VariableUnion`.
|
||||
# - `VariableBase` and its subclasses, which are handled by `Variable`.
|
||||
SegmentUnion: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneSegment, Tag(SegmentType.NONE)]
|
||||
|
||||
@@ -27,7 +27,7 @@ from .segments import (
|
||||
from .types import SegmentType
|
||||
|
||||
|
||||
class Variable(Segment):
|
||||
class VariableBase(Segment):
|
||||
"""
|
||||
A variable is a segment that has a name.
|
||||
|
||||
@@ -45,23 +45,23 @@ class Variable(Segment):
|
||||
selector: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
|
||||
class StringVariable(StringSegment, Variable):
|
||||
class StringVariable(StringSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class FloatVariable(FloatSegment, Variable):
|
||||
class FloatVariable(FloatSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class IntegerVariable(IntegerSegment, Variable):
|
||||
class IntegerVariable(IntegerSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ObjectVariable(ObjectSegment, Variable):
|
||||
class ObjectVariable(ObjectSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayVariable(ArraySegment, Variable):
|
||||
class ArrayVariable(ArraySegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
@@ -89,16 +89,16 @@ class SecretVariable(StringVariable):
|
||||
return encrypter.obfuscated_token(self.value)
|
||||
|
||||
|
||||
class NoneVariable(NoneSegment, Variable):
|
||||
class NoneVariable(NoneSegment, VariableBase):
|
||||
value_type: SegmentType = SegmentType.NONE
|
||||
value: None = None
|
||||
|
||||
|
||||
class FileVariable(FileSegment, Variable):
|
||||
class FileVariable(FileSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
class BooleanVariable(BooleanSegment, Variable):
|
||||
class BooleanVariable(BooleanSegment, VariableBase):
|
||||
pass
|
||||
|
||||
|
||||
@@ -139,13 +139,13 @@ class RAGPipelineVariableInput(BaseModel):
|
||||
value: Any
|
||||
|
||||
|
||||
# The `VariableUnion`` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `Variable` for type hinting when serialization is not required.
|
||||
# The `Variable` type is used to enable serialization and deserialization with Pydantic.
|
||||
# Use `VariableBase` for type hinting when serialization is not required.
|
||||
#
|
||||
# Note:
|
||||
# - All variants in `VariableUnion` must inherit from the `Variable` class.
|
||||
# - The union must include all non-abstract subclasses of `Segment`, except:
|
||||
VariableUnion: TypeAlias = Annotated[
|
||||
# - All variants in `Variable` must inherit from the `VariableBase` class.
|
||||
# - The union must include all non-abstract subclasses of `VariableBase`.
|
||||
Variable: TypeAlias = Annotated[
|
||||
(
|
||||
Annotated[NoneVariable, Tag(SegmentType.NONE)]
|
||||
| Annotated[StringVariable, Tag(SegmentType.STRING)]
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import abc
|
||||
from typing import Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
|
||||
|
||||
class ConversationVariableUpdater(Protocol):
|
||||
@@ -20,12 +20,12 @@ class ConversationVariableUpdater(Protocol):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, conversation_id: str, variable: "Variable"):
|
||||
def update(self, conversation_id: str, variable: "VariableBase"):
|
||||
"""
|
||||
Updates the value of the specified conversation variable in the underlying storage.
|
||||
|
||||
:param conversation_id: The ID of the conversation to update. Typically references `ConversationVariable.id`.
|
||||
:param variable: The `Variable` instance containing the updated value.
|
||||
:param variable: The `VariableBase` instance containing the updated value.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
|
||||
|
||||
class CommandType(StrEnum):
|
||||
@@ -46,7 +46,7 @@ class PauseCommand(GraphEngineCommand):
|
||||
class VariableUpdate(BaseModel):
|
||||
"""Represents a single variable update instruction."""
|
||||
|
||||
value: VariableUnion = Field(description="New variable value")
|
||||
value: Variable = Field(description="New variable value")
|
||||
|
||||
|
||||
class UpdateVariablesCommand(GraphEngineCommand):
|
||||
|
||||
@@ -11,7 +11,7 @@ from typing_extensions import TypeIs
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.variables import IntegerVariable, NoneSegment
|
||||
from core.variables.segments import ArrayAnySegment, ArraySegment
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import (
|
||||
NodeExecutionType,
|
||||
@@ -240,7 +240,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
datetime,
|
||||
list[GraphNodeEventBase],
|
||||
object | None,
|
||||
dict[str, VariableUnion],
|
||||
dict[str, Variable],
|
||||
LLMUsage,
|
||||
]
|
||||
],
|
||||
@@ -308,7 +308,7 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
item: object,
|
||||
flask_app: Flask,
|
||||
context_vars: contextvars.Context,
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, VariableUnion], LLMUsage]:
|
||||
) -> tuple[datetime, list[GraphNodeEventBase], object | None, dict[str, Variable], LLMUsage]:
|
||||
"""Execute a single iteration in parallel mode and return results."""
|
||||
with preserve_flask_contexts(flask_app=flask_app, context_vars=context_vars):
|
||||
iter_start_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
@@ -515,11 +515,11 @@ class IterationNode(LLMUsageTrackingMixin, Node[IterationNodeData]):
|
||||
|
||||
return variable_mapping
|
||||
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, VariableUnion]:
|
||||
def _extract_conversation_variable_snapshot(self, *, variable_pool: VariablePool) -> dict[str, Variable]:
|
||||
conversation_variables = variable_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
return {name: variable.model_copy(deep=True) for name, variable in conversation_variables.items()}
|
||||
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, VariableUnion]) -> None:
|
||||
def _sync_conversation_variables_from_snapshot(self, snapshot: dict[str, Variable]) -> None:
|
||||
parent_pool = self.graph_runtime_state.variable_pool
|
||||
parent_conversations = parent_pool.variable_dictionary.get(CONVERSATION_VARIABLE_NODE_ID, {})
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@@ -64,7 +64,7 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
assigned_variable_selector = self.node_data.assigned_variable_selector
|
||||
# Should be String, Number, Object, ArrayString, ArrayNumber, ArrayObject
|
||||
original_variable = self.graph_runtime_state.variable_pool.get(assigned_variable_selector)
|
||||
if not isinstance(original_variable, Variable):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
match self.node_data.write_mode:
|
||||
|
||||
@@ -2,7 +2,7 @@ import json
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.variables import SegmentType, Variable
|
||||
from core.variables import SegmentType, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@@ -118,7 +118,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
# ==================== Validation Part
|
||||
|
||||
# Check if variable exists
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=item.variable_selector)
|
||||
|
||||
# Check if operation is supported
|
||||
@@ -192,7 +192,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
|
||||
for selector in updated_variable_selectors:
|
||||
variable = self.graph_runtime_state.variable_pool.get(selector)
|
||||
if not isinstance(variable, Variable):
|
||||
if not isinstance(variable, VariableBase):
|
||||
raise VariableNotFoundError(variable_selector=selector)
|
||||
process_data[variable.name] = variable.value
|
||||
|
||||
@@ -213,7 +213,7 @@ class VariableAssignerNode(Node[VariableAssignerNodeData]):
|
||||
def _handle_item(
|
||||
self,
|
||||
*,
|
||||
variable: Variable,
|
||||
variable: VariableBase,
|
||||
operation: Operation,
|
||||
value: Any,
|
||||
):
|
||||
|
||||
@@ -9,10 +9,10 @@ from typing import Annotated, Any, Union, cast
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.file import File, FileAttribute, file_manager
|
||||
from core.variables import Segment, SegmentGroup, Variable
|
||||
from core.variables import Segment, SegmentGroup, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import FileSegment, ObjectSegment
|
||||
from core.variables.variables import RAGPipelineVariableInput, VariableUnion
|
||||
from core.variables.variables import RAGPipelineVariableInput, Variable
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
ENVIRONMENT_VARIABLE_NODE_ID,
|
||||
@@ -32,7 +32,7 @@ class VariablePool(BaseModel):
|
||||
# The first element of the selector is the node id, it's the first-level key in the dictionary.
|
||||
# Other elements of the selector are the keys in the second-level dictionary. To get the key, we hash the
|
||||
# elements of the selector except the first one.
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, VariableUnion], Field(default_factory=dict)]] = Field(
|
||||
variable_dictionary: defaultdict[str, Annotated[dict[str, Variable], Field(default_factory=dict)]] = Field(
|
||||
description="Variables mapping",
|
||||
default=defaultdict(dict),
|
||||
)
|
||||
@@ -46,13 +46,13 @@ class VariablePool(BaseModel):
|
||||
description="System variables",
|
||||
default_factory=SystemVariable.empty,
|
||||
)
|
||||
environment_variables: Sequence[VariableUnion] = Field(
|
||||
environment_variables: Sequence[Variable] = Field(
|
||||
description="Environment variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
conversation_variables: Sequence[VariableUnion] = Field(
|
||||
conversation_variables: Sequence[Variable] = Field(
|
||||
description="Conversation variables.",
|
||||
default_factory=list[VariableUnion],
|
||||
default_factory=list[Variable],
|
||||
)
|
||||
rag_pipeline_variables: list[RAGPipelineVariableInput] = Field(
|
||||
description="RAG pipeline variables.",
|
||||
@@ -105,7 +105,7 @@ class VariablePool(BaseModel):
|
||||
f"got {len(selector)} elements"
|
||||
)
|
||||
|
||||
if isinstance(value, Variable):
|
||||
if isinstance(value, VariableBase):
|
||||
variable = value
|
||||
elif isinstance(value, Segment):
|
||||
variable = variable_factory.segment_to_variable(segment=value, selector=selector)
|
||||
@@ -114,9 +114,9 @@ class VariablePool(BaseModel):
|
||||
variable = variable_factory.segment_to_variable(segment=segment, selector=selector)
|
||||
|
||||
node_id, name = self._selector_to_keys(selector)
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(VariableUnion, variable)
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
self.variable_dictionary[node_id][name] = cast(Variable, variable)
|
||||
|
||||
@classmethod
|
||||
def _selector_to_keys(cls, selector: Sequence[str]) -> tuple[str, str]:
|
||||
|
||||
@@ -2,7 +2,7 @@ import abc
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.variables import Variable
|
||||
from core.variables import VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
@@ -26,7 +26,7 @@ class VariableLoader(Protocol):
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
"""Load variables based on the provided selectors. If the selectors are empty,
|
||||
this method should return an empty list.
|
||||
|
||||
@@ -36,7 +36,7 @@ class VariableLoader(Protocol):
|
||||
:param: selectors: a list of string list, each inner list should have at least two elements:
|
||||
- the first element is the node ID,
|
||||
- the second element is the variable name.
|
||||
:return: a list of Variable objects that match the provided selectors.
|
||||
:return: a list of VariableBase objects that match the provided selectors.
|
||||
"""
|
||||
pass
|
||||
|
||||
@@ -46,7 +46,7 @@ class _DummyVariableLoader(VariableLoader):
|
||||
Serves as a placeholder when no variable loading is needed.
|
||||
"""
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
return []
|
||||
|
||||
|
||||
|
||||
@@ -38,7 +38,7 @@ from core.variables.variables import (
|
||||
ObjectVariable,
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableBase,
|
||||
)
|
||||
from core.workflow.constants import (
|
||||
CONVERSATION_VARIABLE_NODE_ID,
|
||||
@@ -72,25 +72,25 @@ SEGMENT_TO_VARIABLE_MAP = {
|
||||
}
|
||||
|
||||
|
||||
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_conversation_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[CONVERSATION_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_environment_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("name"):
|
||||
raise VariableError("missing name")
|
||||
return _build_variable_from_mapping(mapping=mapping, selector=[ENVIRONMENT_VARIABLE_NODE_ID, mapping["name"]])
|
||||
|
||||
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> Variable:
|
||||
def build_pipeline_variable_from_mapping(mapping: Mapping[str, Any], /) -> VariableBase:
|
||||
if not mapping.get("variable"):
|
||||
raise VariableError("missing variable")
|
||||
return mapping["variable"]
|
||||
|
||||
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> Variable:
|
||||
def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequence[str]) -> VariableBase:
|
||||
"""
|
||||
This factory function is used to create the environment variable or the conversation variable,
|
||||
not support the File type.
|
||||
@@ -100,7 +100,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
||||
if (value := mapping.get("value")) is None:
|
||||
raise VariableError("missing value")
|
||||
|
||||
result: Variable
|
||||
result: VariableBase
|
||||
match value_type:
|
||||
case SegmentType.STRING:
|
||||
result = StringVariable.model_validate(mapping)
|
||||
@@ -134,7 +134,7 @@ def _build_variable_from_mapping(*, mapping: Mapping[str, Any], selector: Sequen
|
||||
raise VariableError(f"variable size {result.size} exceeds limit {dify_config.MAX_VARIABLE_SIZE}")
|
||||
if not result.selector:
|
||||
result = result.model_copy(update={"selector": selector})
|
||||
return cast(Variable, result)
|
||||
return cast(VariableBase, result)
|
||||
|
||||
|
||||
def build_segment(value: Any, /) -> Segment:
|
||||
@@ -285,8 +285,8 @@ def segment_to_variable(
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str = "",
|
||||
) -> Variable:
|
||||
if isinstance(segment, Variable):
|
||||
) -> VariableBase:
|
||||
if isinstance(segment, VariableBase):
|
||||
return segment
|
||||
name = name or selector[-1]
|
||||
id = id or str(uuid4())
|
||||
@@ -297,7 +297,7 @@ def segment_to_variable(
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
Variable,
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from flask_restx import fields
|
||||
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, SegmentType, Variable
|
||||
from core.variables import SecretVariable, SegmentType, VariableBase
|
||||
from fields.member_fields import simple_account_fields
|
||||
from libs.helper import TimestampField
|
||||
|
||||
@@ -21,7 +21,7 @@ class EnvironmentVariableField(fields.Raw):
|
||||
"value_type": value.value_type.value,
|
||||
"description": value.description,
|
||||
}
|
||||
if isinstance(value, Variable):
|
||||
if isinstance(value, VariableBase):
|
||||
return {
|
||||
"id": value.id,
|
||||
"name": value.name,
|
||||
|
||||
@@ -1,11 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import StrEnum
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
from typing import TYPE_CHECKING, Any, Optional, Union, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import sqlalchemy as sa
|
||||
@@ -46,7 +44,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from constants import DEFAULT_FILE_NUMBER_LIMITS, HIDDEN_VALUE
|
||||
from core.helper import encrypter
|
||||
from core.variables import SecretVariable, Segment, SegmentType, Variable
|
||||
from core.variables import SecretVariable, Segment, SegmentType, VariableBase
|
||||
from factories import variable_factory
|
||||
from libs import helper
|
||||
|
||||
@@ -69,7 +67,7 @@ class WorkflowType(StrEnum):
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> WorkflowType:
|
||||
def value_of(cls, value: str) -> "WorkflowType":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@@ -82,7 +80,7 @@ class WorkflowType(StrEnum):
|
||||
raise ValueError(f"invalid workflow type value {value}")
|
||||
|
||||
@classmethod
|
||||
def from_app_mode(cls, app_mode: Union[str, AppMode]) -> WorkflowType:
|
||||
def from_app_mode(cls, app_mode: Union[str, "AppMode"]) -> "WorkflowType":
|
||||
"""
|
||||
Get workflow type from app mode.
|
||||
|
||||
@@ -178,12 +176,12 @@ class Workflow(Base): # bug
|
||||
graph: str,
|
||||
features: str,
|
||||
created_by: str,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
rag_pipeline_variables: list[dict],
|
||||
marked_name: str = "",
|
||||
marked_comment: str = "",
|
||||
) -> Workflow:
|
||||
) -> "Workflow":
|
||||
workflow = Workflow()
|
||||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = tenant_id
|
||||
@@ -447,7 +445,7 @@ class Workflow(Base): # bug
|
||||
|
||||
# decrypt secret variables value
|
||||
def decrypt_func(
|
||||
var: Variable,
|
||||
var: VariableBase,
|
||||
) -> StringVariable | IntegerVariable | FloatVariable | SecretVariable:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.decrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
@@ -463,7 +461,7 @@ class Workflow(Base): # bug
|
||||
return decrypted_results
|
||||
|
||||
@environment_variables.setter
|
||||
def environment_variables(self, value: Sequence[Variable]):
|
||||
def environment_variables(self, value: Sequence[VariableBase]):
|
||||
if not value:
|
||||
self._environment_variables = "{}"
|
||||
return
|
||||
@@ -487,7 +485,7 @@ class Workflow(Base): # bug
|
||||
value[i] = origin_variables_dictionary[variable.id].model_copy(update={"name": variable.name})
|
||||
|
||||
# encrypt secret variables value
|
||||
def encrypt_func(var: Variable) -> Variable:
|
||||
def encrypt_func(var: VariableBase) -> VariableBase:
|
||||
if isinstance(var, SecretVariable):
|
||||
return var.model_copy(update={"value": encrypter.encrypt_token(tenant_id=tenant_id, token=var.value)})
|
||||
else:
|
||||
@@ -517,7 +515,7 @@ class Workflow(Base): # bug
|
||||
return result
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[Variable]:
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
@@ -527,7 +525,7 @@ class Workflow(Base): # bug
|
||||
return results
|
||||
|
||||
@conversation_variables.setter
|
||||
def conversation_variables(self, value: Sequence[Variable]):
|
||||
def conversation_variables(self, value: Sequence[VariableBase]):
|
||||
self._conversation_variables = json.dumps(
|
||||
{var.name: var.model_dump() for var in value},
|
||||
ensure_ascii=False,
|
||||
@@ -621,7 +619,7 @@ class WorkflowRun(Base):
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
exceptions_count: Mapped[int] = mapped_column(sa.Integer, server_default=sa.text("0"), nullable=True)
|
||||
|
||||
pause: Mapped[WorkflowPause | None] = orm.relationship(
|
||||
pause: Mapped[Optional["WorkflowPause"]] = orm.relationship(
|
||||
"WorkflowPause",
|
||||
primaryjoin="WorkflowRun.id == foreign(WorkflowPause.workflow_run_id)",
|
||||
uselist=False,
|
||||
@@ -691,7 +689,7 @@ class WorkflowRun(Base):
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> WorkflowRun:
|
||||
def from_dict(cls, data: dict[str, Any]) -> "WorkflowRun":
|
||||
return cls(
|
||||
id=data.get("id"),
|
||||
tenant_id=data.get("tenant_id"),
|
||||
@@ -843,7 +841,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
created_by: Mapped[str] = mapped_column(StringUUID)
|
||||
finished_at: Mapped[datetime | None] = mapped_column(DateTime)
|
||||
|
||||
offload_data: Mapped[list[WorkflowNodeExecutionOffload]] = orm.relationship(
|
||||
offload_data: Mapped[list["WorkflowNodeExecutionOffload"]] = orm.relationship(
|
||||
"WorkflowNodeExecutionOffload",
|
||||
primaryjoin="WorkflowNodeExecutionModel.id == foreign(WorkflowNodeExecutionOffload.node_execution_id)",
|
||||
uselist=True,
|
||||
@@ -853,13 +851,13 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
@staticmethod
|
||||
def preload_offload_data(
|
||||
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
||||
):
|
||||
return query.options(orm.selectinload(WorkflowNodeExecutionModel.offload_data))
|
||||
|
||||
@staticmethod
|
||||
def preload_offload_data_and_files(
|
||||
query: Select[tuple[WorkflowNodeExecutionModel]] | orm.Query[WorkflowNodeExecutionModel],
|
||||
query: Select[tuple["WorkflowNodeExecutionModel"]] | orm.Query["WorkflowNodeExecutionModel"],
|
||||
):
|
||||
return query.options(
|
||||
orm.selectinload(WorkflowNodeExecutionModel.offload_data).options(
|
||||
@@ -934,7 +932,7 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
)
|
||||
return extras
|
||||
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> WorkflowNodeExecutionOffload | None:
|
||||
def _get_offload_by_type(self, type_: ExecutionOffLoadType) -> Optional["WorkflowNodeExecutionOffload"]:
|
||||
return next(iter([i for i in self.offload_data if i.type_ == type_]), None)
|
||||
|
||||
@property
|
||||
@@ -1048,7 +1046,7 @@ class WorkflowNodeExecutionOffload(Base):
|
||||
back_populates="offload_data",
|
||||
)
|
||||
|
||||
file: Mapped[UploadFile | None] = orm.relationship(
|
||||
file: Mapped[Optional["UploadFile"]] = orm.relationship(
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
@@ -1066,7 +1064,7 @@ class WorkflowAppLogCreatedFrom(StrEnum):
|
||||
INSTALLED_APP = "installed-app"
|
||||
|
||||
@classmethod
|
||||
def value_of(cls, value: str) -> WorkflowAppLogCreatedFrom:
|
||||
def value_of(cls, value: str) -> "WorkflowAppLogCreatedFrom":
|
||||
"""
|
||||
Get value of given mode.
|
||||
|
||||
@@ -1183,7 +1181,7 @@ class ConversationVariable(TypeBase):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: Variable) -> ConversationVariable:
|
||||
def from_variable(cls, *, app_id: str, conversation_id: str, variable: VariableBase) -> "ConversationVariable":
|
||||
obj = cls(
|
||||
id=variable.id,
|
||||
app_id=app_id,
|
||||
@@ -1192,7 +1190,7 @@ class ConversationVariable(TypeBase):
|
||||
)
|
||||
return obj
|
||||
|
||||
def to_variable(self) -> Variable:
|
||||
def to_variable(self) -> VariableBase:
|
||||
mapping = json.loads(self.data)
|
||||
return variable_factory.build_conversation_variable_from_mapping(mapping)
|
||||
|
||||
@@ -1336,7 +1334,7 @@ class WorkflowDraftVariable(Base):
|
||||
)
|
||||
|
||||
# Relationship to WorkflowDraftVariableFile
|
||||
variable_file: Mapped[WorkflowDraftVariableFile | None] = orm.relationship(
|
||||
variable_file: Mapped[Optional["WorkflowDraftVariableFile"]] = orm.relationship(
|
||||
foreign_keys=[file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
@@ -1506,7 +1504,7 @@ class WorkflowDraftVariable(Base):
|
||||
node_execution_id: str | None,
|
||||
description: str = "",
|
||||
file_id: str | None = None,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = WorkflowDraftVariable()
|
||||
variable.id = str(uuid4())
|
||||
variable.created_at = naive_utc_now()
|
||||
@@ -1529,7 +1527,7 @@ class WorkflowDraftVariable(Base):
|
||||
name: str,
|
||||
value: Segment,
|
||||
description: str = "",
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=CONVERSATION_VARIABLE_NODE_ID,
|
||||
@@ -1550,7 +1548,7 @@ class WorkflowDraftVariable(Base):
|
||||
value: Segment,
|
||||
node_execution_id: str,
|
||||
editable: bool = False,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=SYSTEM_VARIABLE_NODE_ID,
|
||||
@@ -1573,7 +1571,7 @@ class WorkflowDraftVariable(Base):
|
||||
visible: bool = True,
|
||||
editable: bool = True,
|
||||
file_id: str | None = None,
|
||||
) -> WorkflowDraftVariable:
|
||||
) -> "WorkflowDraftVariable":
|
||||
variable = cls._new(
|
||||
app_id=app_id,
|
||||
node_id=node_id,
|
||||
@@ -1669,7 +1667,7 @@ class WorkflowDraftVariableFile(Base):
|
||||
)
|
||||
|
||||
# Relationship to UploadFile
|
||||
upload_file: Mapped[UploadFile] = orm.relationship(
|
||||
upload_file: Mapped["UploadFile"] = orm.relationship(
|
||||
foreign_keys=[upload_file_id],
|
||||
lazy="raise",
|
||||
uselist=False,
|
||||
@@ -1736,7 +1734,7 @@ class WorkflowPause(DefaultFieldsMixin, Base):
|
||||
state_object_key: Mapped[str] = mapped_column(String(length=255), nullable=False)
|
||||
|
||||
# Relationship to WorkflowRun
|
||||
workflow_run: Mapped[WorkflowRun] = orm.relationship(
|
||||
workflow_run: Mapped["WorkflowRun"] = orm.relationship(
|
||||
foreign_keys=[workflow_run_id],
|
||||
# require explicit preloading.
|
||||
lazy="raise",
|
||||
@@ -1792,7 +1790,7 @@ class WorkflowPauseReason(DefaultFieldsMixin, Base):
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_entity(cls, pause_reason: PauseReason) -> WorkflowPauseReason:
|
||||
def from_entity(cls, pause_reason: PauseReason) -> "WorkflowPauseReason":
|
||||
if isinstance(pause_reason, HumanInputRequired):
|
||||
return cls(
|
||||
type_=PauseReasonType.HUMAN_INPUT_REQUIRED, form_id=pause_reason.form_id, node_id=pause_reason.node_id
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
from models import ConversationVariable
|
||||
|
||||
|
||||
@@ -13,7 +13,7 @@ class ConversationVariableUpdater:
|
||||
def __init__(self, session_maker: sessionmaker[Session]) -> None:
|
||||
self._session_maker: sessionmaker[Session] = session_maker
|
||||
|
||||
def update(self, conversation_id: str, variable: Variable) -> None:
|
||||
def update(self, conversation_id: str, variable: VariableBase) -> None:
|
||||
stmt = select(ConversationVariable).where(
|
||||
ConversationVariable.id == variable.id, ConversationVariable.conversation_id == conversation_id
|
||||
)
|
||||
|
||||
@@ -36,7 +36,7 @@ from core.rag.entities.event import (
|
||||
)
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
from core.workflow.entities.workflow_node_execution import (
|
||||
WorkflowNodeExecution,
|
||||
WorkflowNodeExecutionStatus,
|
||||
@@ -270,8 +270,8 @@ class RagPipelineService:
|
||||
graph: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
rag_pipeline_variables: list,
|
||||
) -> Workflow:
|
||||
"""
|
||||
|
||||
@@ -15,7 +15,7 @@ from sqlalchemy.sql.expression import and_, or_
|
||||
from configs import dify_config
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from core.variables import Segment, StringSegment, Variable
|
||||
from core.variables import Segment, StringSegment, VariableBase
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
@@ -77,14 +77,14 @@ class DraftVarLoader(VariableLoader):
|
||||
# Application ID for which variables are being loaded.
|
||||
_app_id: str
|
||||
_tenant_id: str
|
||||
_fallback_variables: Sequence[Variable]
|
||||
_fallback_variables: Sequence[VariableBase]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
engine: Engine,
|
||||
app_id: str,
|
||||
tenant_id: str,
|
||||
fallback_variables: Sequence[Variable] | None = None,
|
||||
fallback_variables: Sequence[VariableBase] | None = None,
|
||||
):
|
||||
self._engine = engine
|
||||
self._app_id = app_id
|
||||
@@ -94,12 +94,12 @@ class DraftVarLoader(VariableLoader):
|
||||
def _selector_to_tuple(self, selector: Sequence[str]) -> tuple[str, str]:
|
||||
return (selector[0], selector[1])
|
||||
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[Variable]:
|
||||
def load_variables(self, selectors: list[list[str]]) -> list[VariableBase]:
|
||||
if not selectors:
|
||||
return []
|
||||
|
||||
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding Variable instance.
|
||||
variable_by_selector: dict[tuple[str, str], Variable] = {}
|
||||
# Map each selector (as a tuple via `_selector_to_tuple`) to its corresponding variable instance.
|
||||
variable_by_selector: dict[tuple[str, str], VariableBase] = {}
|
||||
|
||||
with Session(bind=self._engine, expire_on_commit=False) as session:
|
||||
srv = WorkflowDraftVariableService(session)
|
||||
@@ -145,7 +145,7 @@ class DraftVarLoader(VariableLoader):
|
||||
|
||||
return list(variable_by_selector.values())
|
||||
|
||||
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], Variable]:
|
||||
def _load_offloaded_variable(self, draft_var: WorkflowDraftVariable) -> tuple[tuple[str, str], VariableBase]:
|
||||
# This logic is closely tied to `WorkflowDraftVaribleService._try_offload_large_variable`
|
||||
# and must remain synchronized with it.
|
||||
# Ideally, these should be co-located for better maintainability.
|
||||
|
||||
@@ -13,8 +13,8 @@ from core.app.apps.advanced_chat.app_config_manager import AdvancedChatAppConfig
|
||||
from core.app.apps.workflow.app_config_manager import WorkflowAppConfigManager
|
||||
from core.file import File
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.variables import Variable
|
||||
from core.variables.variables import VariableUnion
|
||||
from core.variables import VariableBase
|
||||
from core.variables.variables import Variable
|
||||
from core.workflow.entities import WorkflowNodeExecution
|
||||
from core.workflow.enums import ErrorStrategy, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.errors import WorkflowNodeRunFailedError
|
||||
@@ -198,8 +198,8 @@ class WorkflowService:
|
||||
features: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[Variable],
|
||||
conversation_variables: Sequence[Variable],
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow
|
||||
@@ -1044,7 +1044,7 @@ def _setup_variable_pool(
|
||||
workflow: Workflow,
|
||||
node_type: NodeType,
|
||||
conversation_id: str,
|
||||
conversation_variables: list[Variable],
|
||||
conversation_variables: list[VariableBase],
|
||||
):
|
||||
# Only inject system variables for START node type.
|
||||
if node_type == NodeType.START or node_type.is_trigger_node:
|
||||
@@ -1070,9 +1070,9 @@ def _setup_variable_pool(
|
||||
system_variables=system_variable,
|
||||
user_inputs=user_inputs,
|
||||
environment_variables=workflow.environment_variables,
|
||||
# Based on the definition of `VariableUnion`,
|
||||
# `list[Variable]` can be safely used as `list[VariableUnion]` since they are compatible.
|
||||
conversation_variables=cast(list[VariableUnion], conversation_variables), #
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
conversation_variables=cast(list[Variable], conversation_variables), #
|
||||
)
|
||||
|
||||
return variable_pool
|
||||
|
||||
@@ -35,7 +35,6 @@ from core.variables.variables import (
|
||||
SecretVariable,
|
||||
StringVariable,
|
||||
Variable,
|
||||
VariableUnion,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
@@ -96,7 +95,7 @@ class _Segments(BaseModel):
|
||||
|
||||
|
||||
class _Variables(BaseModel):
|
||||
variables: list[VariableUnion]
|
||||
variables: list[Variable]
|
||||
|
||||
|
||||
def create_test_file(
|
||||
@@ -194,7 +193,7 @@ class TestSegmentDumpAndLoad:
|
||||
# Create one instance of each variable type
|
||||
test_file = create_test_file()
|
||||
|
||||
all_variables: list[VariableUnion] = [
|
||||
all_variables: list[Variable] = [
|
||||
NoneVariable(name="none_var"),
|
||||
StringVariable(value="test string", name="string_var"),
|
||||
IntegerVariable(value=42, name="int_var"),
|
||||
|
||||
@@ -11,7 +11,7 @@ from core.variables import (
|
||||
SegmentType,
|
||||
StringVariable,
|
||||
)
|
||||
from core.variables.variables import Variable
|
||||
from core.variables.variables import VariableBase
|
||||
|
||||
|
||||
def test_frozen_variables():
|
||||
@@ -76,7 +76,7 @@ def test_object_variable_to_object():
|
||||
|
||||
|
||||
def test_variable_to_object():
|
||||
var: Variable = StringVariable(name="text", value="text")
|
||||
var: VariableBase = StringVariable(name="text", value="text")
|
||||
assert var.to_object() == "text"
|
||||
var = IntegerVariable(name="integer", value=42)
|
||||
assert var.to_object() == 42
|
||||
|
||||
@@ -24,7 +24,7 @@ from core.variables.variables import (
|
||||
IntegerVariable,
|
||||
ObjectVariable,
|
||||
StringVariable,
|
||||
VariableUnion,
|
||||
Variable,
|
||||
)
|
||||
from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, ENVIRONMENT_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.runtime import VariablePool
|
||||
@@ -160,7 +160,7 @@ class TestVariablePoolSerialization:
|
||||
)
|
||||
|
||||
# Create environment variables with all types including ArrayFileVariable
|
||||
env_vars: list[VariableUnion] = [
|
||||
env_vars: list[Variable] = [
|
||||
StringVariable(
|
||||
id="env_string_id",
|
||||
name="env_string",
|
||||
@@ -182,7 +182,7 @@ class TestVariablePoolSerialization:
|
||||
]
|
||||
|
||||
# Create conversation variables with complex data
|
||||
conv_vars: list[VariableUnion] = [
|
||||
conv_vars: list[Variable] = [
|
||||
StringVariable(
|
||||
id="conv_string_id",
|
||||
name="conv_string",
|
||||
|
||||
Reference in New Issue
Block a user