mirror of
https://github.com/langgenius/dify.git
synced 2026-02-24 18:05:11 +00:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
@@ -10,6 +10,7 @@ tenant_id, app_id, triggered_from, etc., which are not part of the core domain m
|
||||
"""
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Protocol
|
||||
|
||||
@@ -19,6 +20,27 @@ from core.workflow.repositories.workflow_node_execution_repository import Workfl
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkflowNodeExecutionSnapshot:
|
||||
"""
|
||||
Minimal snapshot of workflow node execution for stream recovery.
|
||||
|
||||
Only includes fields required by snapshot events.
|
||||
"""
|
||||
|
||||
execution_id: str # Unique execution identifier (node_execution_id or row id).
|
||||
node_id: str # Workflow graph node id.
|
||||
node_type: str # Workflow graph node type (e.g. "human-input").
|
||||
title: str # Human-friendly node title.
|
||||
index: int # Execution order index within the workflow run.
|
||||
status: str # Execution status (running/succeeded/failed/paused).
|
||||
elapsed_time: float # Execution elapsed time in seconds.
|
||||
created_at: datetime # Execution created timestamp.
|
||||
finished_at: datetime | None # Execution finished timestamp.
|
||||
iteration_id: str | None = None # Iteration id from execution metadata, if any.
|
||||
loop_id: str | None = None # Loop id from execution metadata, if any.
|
||||
|
||||
|
||||
class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Protocol):
|
||||
"""
|
||||
Protocol for service-layer operations on WorkflowNodeExecutionModel.
|
||||
@@ -79,6 +101,8 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_id: The workflow identifier
|
||||
triggered_from: The workflow trigger source
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
@@ -86,6 +110,27 @@ class DifyAPIWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository, Pr
|
||||
"""
|
||||
...
|
||||
|
||||
def get_execution_snapshots_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
triggered_from: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
||||
"""
|
||||
Get minimal snapshots for node executions in a workflow run.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant identifier
|
||||
app_id: The application identifier
|
||||
workflow_run_id: The workflow run identifier
|
||||
|
||||
Returns:
|
||||
A sequence of WorkflowNodeExecutionSnapshot ordered by creation time
|
||||
"""
|
||||
...
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
|
||||
@@ -432,6 +432,13 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
# while creating pause.
|
||||
...
|
||||
|
||||
def get_workflow_pause(self, workflow_run_id: str) -> WorkflowPauseEntity | None:
|
||||
"""Retrieve the current pause for a workflow execution.
|
||||
|
||||
If there is no current pause, this method would return `None`.
|
||||
"""
|
||||
...
|
||||
|
||||
def resume_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -627,3 +634,19 @@ class APIWorkflowRunRepository(WorkflowExecutionRepository, Protocol):
|
||||
[{"date": "2024-01-01", "interactions": 2.5}, ...]
|
||||
"""
|
||||
...
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""
|
||||
Get a specific workflow run by its id and the associated tenant id.
|
||||
|
||||
This function does not apply application isolation. It should only be used when
|
||||
the application identifier is not available.
|
||||
|
||||
Args:
|
||||
tenant_id: Tenant identifier for multi-tenant isolation
|
||||
run_id: Workflow run identifier
|
||||
|
||||
Returns:
|
||||
WorkflowRun object if found, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -63,6 +63,12 @@ class WorkflowPauseEntity(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def paused_at(self) -> datetime:
|
||||
"""`paused_at` returns the creation time of the pause."""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
"""
|
||||
@@ -70,7 +76,5 @@ class WorkflowPauseEntity(ABC):
|
||||
|
||||
Returns a sequence of `PauseReason` objects describing the specific nodes and
|
||||
reasons for which the workflow execution was paused.
|
||||
This information is related to, but distinct from, the `PauseReason` type
|
||||
defined in `api/core/workflow/entities/pause_reason.py`.
|
||||
"""
|
||||
...
|
||||
|
||||
13
api/repositories/execution_extra_content_repository.py
Normal file
13
api/repositories/execution_extra_content_repository.py
Normal file
@@ -0,0 +1,13 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import Protocol
|
||||
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
|
||||
|
||||
class ExecutionExtraContentRepository(Protocol):
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]: ...
|
||||
|
||||
|
||||
__all__ = ["ExecutionExtraContentRepository"]
|
||||
@@ -5,6 +5,7 @@ This module provides a concrete implementation of the service repository protoco
|
||||
using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
@@ -13,11 +14,12 @@ from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from models.workflow import (
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowNodeExecutionOffload,
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel, WorkflowNodeExecutionOffload
|
||||
from repositories.api_workflow_node_execution_repository import (
|
||||
DifyAPIWorkflowNodeExecutionRepository,
|
||||
WorkflowNodeExecutionSnapshot,
|
||||
)
|
||||
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
@@ -79,6 +81,7 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.node_id == node_id,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
.order_by(desc(WorkflowNodeExecutionModel.created_at))
|
||||
.limit(1)
|
||||
@@ -117,6 +120,80 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
with self._session_maker() as session:
|
||||
return session.execute(stmt).scalars().all()
|
||||
|
||||
def get_execution_snapshots_by_workflow_run(
|
||||
self,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
workflow_id: str,
|
||||
triggered_from: str,
|
||||
workflow_run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionSnapshot]:
|
||||
stmt = (
|
||||
select(
|
||||
WorkflowNodeExecutionModel.id,
|
||||
WorkflowNodeExecutionModel.node_execution_id,
|
||||
WorkflowNodeExecutionModel.node_id,
|
||||
WorkflowNodeExecutionModel.node_type,
|
||||
WorkflowNodeExecutionModel.title,
|
||||
WorkflowNodeExecutionModel.index,
|
||||
WorkflowNodeExecutionModel.status,
|
||||
WorkflowNodeExecutionModel.elapsed_time,
|
||||
WorkflowNodeExecutionModel.created_at,
|
||||
WorkflowNodeExecutionModel.finished_at,
|
||||
WorkflowNodeExecutionModel.execution_metadata,
|
||||
)
|
||||
.where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.app_id == app_id,
|
||||
WorkflowNodeExecutionModel.workflow_id == workflow_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
.order_by(
|
||||
asc(WorkflowNodeExecutionModel.created_at),
|
||||
asc(WorkflowNodeExecutionModel.index),
|
||||
)
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
try:
|
||||
metadata = json.loads(execution_metadata)
|
||||
except json.JSONDecodeError:
|
||||
metadata = {}
|
||||
iteration_id = metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID.value)
|
||||
loop_id = metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID.value)
|
||||
execution_id = getattr(row, "node_execution_id", None) or row.id
|
||||
elapsed_time = getattr(row, "elapsed_time", None)
|
||||
created_at = row.created_at
|
||||
finished_at = getattr(row, "finished_at", None)
|
||||
if elapsed_time is None:
|
||||
if finished_at is not None and created_at is not None:
|
||||
elapsed_time = (finished_at - created_at).total_seconds()
|
||||
else:
|
||||
elapsed_time = 0.0
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id=str(execution_id),
|
||||
node_id=row.node_id,
|
||||
node_type=row.node_type,
|
||||
title=row.title,
|
||||
index=row.index,
|
||||
status=row.status,
|
||||
elapsed_time=float(elapsed_time),
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=str(iteration_id) if iteration_id else None,
|
||||
loop_id=str(loop_id) if loop_id else None,
|
||||
)
|
||||
|
||||
def get_execution_by_id(
|
||||
self,
|
||||
execution_id: str,
|
||||
|
||||
@@ -19,6 +19,7 @@ Implementation Notes:
|
||||
- Maintains data consistency with proper transaction handling
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import uuid
|
||||
from collections.abc import Callable, Sequence
|
||||
@@ -27,12 +28,14 @@ from decimal import Decimal
|
||||
from typing import Any, cast
|
||||
|
||||
import sqlalchemy as sa
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import and_, delete, func, null, or_, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, SchedulingPause
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired, PauseReason, PauseReasonType, SchedulingPause
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowType
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from extensions.ext_storage import storage
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import convert_datetime_to_date
|
||||
@@ -40,6 +43,7 @@ from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.time_parser import get_time_threshold
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.human_input import HumanInputForm, HumanInputFormRecipient, RecipientType
|
||||
from models.workflow import WorkflowAppLog, WorkflowArchiveLog, WorkflowPause, WorkflowPauseReason, WorkflowRun
|
||||
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
@@ -57,6 +61,67 @@ class _WorkflowRunError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
def _select_recipient_token(
|
||||
recipients: Sequence[HumanInputFormRecipient],
|
||||
recipient_type: RecipientType,
|
||||
) -> str | None:
|
||||
for recipient in recipients:
|
||||
if recipient.recipient_type == recipient_type and recipient.access_token:
|
||||
return recipient.access_token
|
||||
return None
|
||||
|
||||
|
||||
def _build_human_input_required_reason(
|
||||
reason_model: WorkflowPauseReason,
|
||||
form_model: HumanInputForm | None,
|
||||
recipients: Sequence[HumanInputFormRecipient],
|
||||
) -> HumanInputRequired:
|
||||
form_content = ""
|
||||
inputs = []
|
||||
actions = []
|
||||
display_in_ui = False
|
||||
resolved_default_values: dict[str, Any] = {}
|
||||
node_title = "Human Input"
|
||||
form_id = reason_model.form_id
|
||||
node_id = reason_model.node_id
|
||||
if form_model is not None:
|
||||
form_id = form_model.id
|
||||
node_id = form_model.node_id or node_id
|
||||
try:
|
||||
definition_payload = json.loads(form_model.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form_model.expiration_time
|
||||
definition = FormDefinition.model_validate(definition_payload)
|
||||
except ValidationError:
|
||||
definition = None
|
||||
|
||||
if definition is not None:
|
||||
form_content = definition.form_content
|
||||
inputs = list(definition.inputs)
|
||||
actions = list(definition.user_actions)
|
||||
display_in_ui = bool(definition.display_in_ui)
|
||||
resolved_default_values = dict(definition.default_values)
|
||||
node_title = definition.node_title or node_title
|
||||
|
||||
form_token = (
|
||||
_select_recipient_token(recipients, RecipientType.BACKSTAGE)
|
||||
or _select_recipient_token(recipients, RecipientType.CONSOLE)
|
||||
or _select_recipient_token(recipients, RecipientType.STANDALONE_WEB_APP)
|
||||
)
|
||||
|
||||
return HumanInputRequired(
|
||||
form_id=form_id,
|
||||
form_content=form_content,
|
||||
inputs=inputs,
|
||||
actions=actions,
|
||||
display_in_ui=display_in_ui,
|
||||
node_id=node_id,
|
||||
node_title=node_title,
|
||||
form_token=form_token,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of APIWorkflowRunRepository.
|
||||
@@ -676,9 +741,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
raise ValueError(f"WorkflowRun not found: {workflow_run_id}")
|
||||
|
||||
# Check if workflow is in RUNNING status
|
||||
if workflow_run.status != WorkflowExecutionStatus.RUNNING:
|
||||
# TODO(QuantumGhost): It seems that the persistence of `WorkflowRun.status`
|
||||
# happens before the execution of GraphLayer
|
||||
if workflow_run.status not in {WorkflowExecutionStatus.RUNNING, WorkflowExecutionStatus.PAUSED}:
|
||||
raise _WorkflowRunError(
|
||||
f"Only WorkflowRun with RUNNING status can be paused, "
|
||||
f"Only WorkflowRun with RUNNING or PAUSED status can be paused, "
|
||||
f"workflow_run_id={workflow_run_id}, current_status={workflow_run.status}"
|
||||
)
|
||||
#
|
||||
@@ -729,13 +796,48 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
logger.info("Created workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reason_models)
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reason_models,
|
||||
pause_reasons=pause_reasons,
|
||||
)
|
||||
|
||||
def _get_reasons_by_pause_id(self, session: Session, pause_id: str):
|
||||
reason_stmt = select(WorkflowPauseReason).where(WorkflowPauseReason.pause_id == pause_id)
|
||||
pause_reason_models = session.scalars(reason_stmt).all()
|
||||
return pause_reason_models
|
||||
|
||||
def _hydrate_pause_reasons(
|
||||
self,
|
||||
session: Session,
|
||||
pause_reason_models: Sequence[WorkflowPauseReason],
|
||||
) -> list[PauseReason]:
|
||||
form_ids = [
|
||||
reason.form_id
|
||||
for reason in pause_reason_models
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED and reason.form_id
|
||||
]
|
||||
form_models: dict[str, HumanInputForm] = {}
|
||||
recipient_models_by_form: dict[str, list[HumanInputFormRecipient]] = {}
|
||||
if form_ids:
|
||||
form_stmt = select(HumanInputForm).where(HumanInputForm.id.in_(form_ids))
|
||||
for form in session.scalars(form_stmt).all():
|
||||
form_models[form.id] = form
|
||||
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
for recipient in session.scalars(recipient_stmt).all():
|
||||
recipient_models_by_form.setdefault(recipient.form_id, []).append(recipient)
|
||||
|
||||
pause_reasons: list[PauseReason] = []
|
||||
for reason in pause_reason_models:
|
||||
if reason.type_ == PauseReasonType.HUMAN_INPUT_REQUIRED:
|
||||
form_model = form_models.get(reason.form_id)
|
||||
recipients = recipient_models_by_form.get(reason.form_id, [])
|
||||
pause_reasons.append(_build_human_input_required_reason(reason, form_model, recipients))
|
||||
else:
|
||||
pause_reasons.append(reason.to_entity())
|
||||
return pause_reasons
|
||||
|
||||
def get_workflow_pause(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
@@ -767,14 +869,12 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
if pause_model is None:
|
||||
return None
|
||||
pause_reason_models = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
|
||||
human_input_form: list[Any] = []
|
||||
# TODO(QuantumGhost): query human_input_forms model and rebuild PauseReason
|
||||
pause_reasons = self._hydrate_pause_reasons(session, pause_reason_models)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reason_models,
|
||||
human_input_form=human_input_form,
|
||||
pause_reasons=pause_reasons,
|
||||
)
|
||||
|
||||
def resume_workflow_pause(
|
||||
@@ -828,10 +928,10 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
raise _WorkflowRunError(f"Cannot resume an already resumed pause, pause_id={pause_model.id}")
|
||||
|
||||
pause_reasons = self._get_reasons_by_pause_id(session, pause_model.id)
|
||||
hydrated_pause_reasons = self._hydrate_pause_reasons(session, pause_reasons)
|
||||
|
||||
# Mark as resumed
|
||||
pause_model.resumed_at = naive_utc_now()
|
||||
workflow_run.pause_id = None # type: ignore
|
||||
workflow_run.status = WorkflowExecutionStatus.RUNNING
|
||||
|
||||
session.add(pause_model)
|
||||
@@ -839,7 +939,11 @@ class DifyAPISQLAlchemyWorkflowRunRepository(APIWorkflowRunRepository):
|
||||
|
||||
logger.info("Resumed workflow pause %s for workflow run %s", pause_model.id, workflow_run_id)
|
||||
|
||||
return _PrivateWorkflowPauseEntity(pause_model=pause_model, reason_models=pause_reasons)
|
||||
return _PrivateWorkflowPauseEntity(
|
||||
pause_model=pause_model,
|
||||
reason_models=pause_reasons,
|
||||
pause_reasons=hydrated_pause_reasons,
|
||||
)
|
||||
|
||||
def delete_workflow_pause(
|
||||
self,
|
||||
@@ -1165,6 +1269,15 @@ GROUP BY
|
||||
|
||||
return cast(list[AverageInteractionStats], response_data)
|
||||
|
||||
def get_workflow_run_by_id_and_tenant_id(self, tenant_id: str, run_id: str) -> WorkflowRun | None:
|
||||
"""Get a specific workflow run by its id and the associated tenant id."""
|
||||
with self._session_maker() as session:
|
||||
stmt = select(WorkflowRun).where(
|
||||
WorkflowRun.tenant_id == tenant_id,
|
||||
WorkflowRun.id == run_id,
|
||||
)
|
||||
return session.scalar(stmt)
|
||||
|
||||
|
||||
class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
"""
|
||||
@@ -1179,10 +1292,12 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
*,
|
||||
pause_model: WorkflowPause,
|
||||
reason_models: Sequence[WorkflowPauseReason],
|
||||
pause_reasons: Sequence[PauseReason] | None = None,
|
||||
human_input_form: Sequence = (),
|
||||
) -> None:
|
||||
self._pause_model = pause_model
|
||||
self._reason_models = reason_models
|
||||
self._pause_reasons = pause_reasons
|
||||
self._cached_state: bytes | None = None
|
||||
self._human_input_form = human_input_form
|
||||
|
||||
@@ -1219,4 +1334,10 @@ class _PrivateWorkflowPauseEntity(WorkflowPauseEntity):
|
||||
return self._pause_model.resumed_at
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[PauseReason]:
|
||||
if self._pause_reasons is not None:
|
||||
return list(self._pause_reasons)
|
||||
return [reason.to_entity() for reason in self._reason_models]
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self._pause_model.created_at
|
||||
|
||||
@@ -0,0 +1,200 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections import defaultdict
|
||||
from collections.abc import Sequence
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.entities.execution_extra_content import (
|
||||
ExecutionExtraContentDomainModel,
|
||||
HumanInputFormDefinition,
|
||||
HumanInputFormSubmissionData,
|
||||
)
|
||||
from core.entities.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentDomainModel,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import FormDefinition
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormStatus
|
||||
from core.workflow.nodes.human_input.human_input_node import HumanInputNode
|
||||
from models.execution_extra_content import (
|
||||
ExecutionExtraContent as ExecutionExtraContentModel,
|
||||
)
|
||||
from models.execution_extra_content import (
|
||||
HumanInputContent as HumanInputContentModel,
|
||||
)
|
||||
from models.human_input import HumanInputFormRecipient, RecipientType
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
def _extract_output_field_names(form_content: str) -> list[str]:
|
||||
if not form_content:
|
||||
return []
|
||||
return [match.group("field_name") for match in _OUTPUT_VARIABLE_PATTERN.finditer(form_content)]
|
||||
|
||||
|
||||
class SQLAlchemyExecutionExtraContentRepository(ExecutionExtraContentRepository):
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
self._session_maker = session_maker
|
||||
|
||||
def get_by_message_ids(self, message_ids: Sequence[str]) -> list[list[ExecutionExtraContentDomainModel]]:
|
||||
if not message_ids:
|
||||
return []
|
||||
|
||||
grouped_contents: dict[str, list[ExecutionExtraContentDomainModel]] = {
|
||||
message_id: [] for message_id in message_ids
|
||||
}
|
||||
|
||||
stmt = (
|
||||
select(ExecutionExtraContentModel)
|
||||
.where(ExecutionExtraContentModel.message_id.in_(message_ids))
|
||||
.options(selectinload(HumanInputContentModel.form))
|
||||
.order_by(ExecutionExtraContentModel.created_at.asc())
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
results = session.scalars(stmt).all()
|
||||
|
||||
form_ids = {
|
||||
content.form_id
|
||||
for content in results
|
||||
if isinstance(content, HumanInputContentModel) and content.form_id is not None
|
||||
}
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]] = defaultdict(list)
|
||||
if form_ids:
|
||||
recipient_stmt = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id.in_(form_ids))
|
||||
recipients = session.scalars(recipient_stmt).all()
|
||||
for recipient in recipients:
|
||||
recipients_by_form_id[recipient.form_id].append(recipient)
|
||||
else:
|
||||
recipients_by_form_id = {}
|
||||
|
||||
for content in results:
|
||||
message_id = content.message_id
|
||||
if not message_id or message_id not in grouped_contents:
|
||||
continue
|
||||
|
||||
domain_model = self._map_model_to_domain(content, recipients_by_form_id)
|
||||
if domain_model is None:
|
||||
continue
|
||||
|
||||
grouped_contents[message_id].append(domain_model)
|
||||
|
||||
return [grouped_contents[message_id] for message_id in message_ids]
|
||||
|
||||
def _map_model_to_domain(
|
||||
self,
|
||||
model: ExecutionExtraContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> ExecutionExtraContentDomainModel | None:
|
||||
if isinstance(model, HumanInputContentModel):
|
||||
return self._map_human_input_content(model, recipients_by_form_id)
|
||||
|
||||
logger.debug("Unsupported execution extra content type encountered: %s", model.type)
|
||||
return None
|
||||
|
||||
def _map_human_input_content(
|
||||
self,
|
||||
model: HumanInputContentModel,
|
||||
recipients_by_form_id: dict[str, list[HumanInputFormRecipient]],
|
||||
) -> HumanInputContentDomainModel | None:
|
||||
form = model.form
|
||||
if form is None:
|
||||
logger.warning("HumanInputContent(id=%s) has no associated form loaded", model.id)
|
||||
return None
|
||||
|
||||
try:
|
||||
definition_payload = json.loads(form.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form.expiration_time
|
||||
form_definition = FormDefinition.model_validate(definition_payload)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load form definition for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
node_title = form_definition.node_title or form.node_id
|
||||
display_in_ui = bool(form_definition.display_in_ui)
|
||||
|
||||
submitted = form.submitted_at is not None or form.status == HumanInputFormStatus.SUBMITTED
|
||||
if not submitted:
|
||||
form_token = self._resolve_form_token(recipients_by_form_id.get(form.id, []))
|
||||
return HumanInputContentDomainModel(
|
||||
workflow_run_id=model.workflow_run_id,
|
||||
submitted=False,
|
||||
form_definition=HumanInputFormDefinition(
|
||||
form_id=form.id,
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
form_content=form.rendered_content,
|
||||
inputs=form_definition.inputs,
|
||||
actions=form_definition.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
form_token=form_token,
|
||||
resolved_default_values=form_definition.default_values,
|
||||
expiration_time=int(form.expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if not selected_action_id:
|
||||
logger.warning("HumanInputContent(id=%s) form has no selected action", model.id)
|
||||
return None
|
||||
|
||||
action_text = next(
|
||||
(action.title for action in form_definition.user_actions if action.id == selected_action_id),
|
||||
selected_action_id,
|
||||
)
|
||||
|
||||
submitted_data: dict[str, Any] = {}
|
||||
if form.submitted_data:
|
||||
try:
|
||||
submitted_data = json.loads(form.submitted_data)
|
||||
except ValueError:
|
||||
logger.warning("Failed to load submitted data for HumanInputContent(id=%s)", model.id)
|
||||
return None
|
||||
|
||||
rendered_content = HumanInputNode.render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
submitted_data,
|
||||
_extract_output_field_names(form_definition.form_content),
|
||||
)
|
||||
|
||||
return HumanInputContentDomainModel(
|
||||
workflow_run_id=model.workflow_run_id,
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id=form.node_id,
|
||||
node_title=node_title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
),
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _resolve_form_token(recipients: Sequence[HumanInputFormRecipient]) -> str | None:
|
||||
console_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
if console_recipient and console_recipient.access_token:
|
||||
return console_recipient.access_token
|
||||
|
||||
web_app_recipient = next(
|
||||
(recipient for recipient in recipients if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP),
|
||||
None,
|
||||
)
|
||||
if web_app_recipient and web_app_recipient.access_token:
|
||||
return web_app_recipient.access_token
|
||||
|
||||
return None
|
||||
|
||||
|
||||
__all__ = ["SQLAlchemyExecutionExtraContentRepository"]
|
||||
@@ -92,6 +92,16 @@ class SQLAlchemyWorkflowTriggerLogRepository(WorkflowTriggerLogRepository):
|
||||
|
||||
return list(self.session.scalars(query).all())
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""Get the trigger log associated with a workflow run."""
|
||||
query = (
|
||||
select(WorkflowTriggerLog)
|
||||
.where(WorkflowTriggerLog.workflow_run_id == workflow_run_id)
|
||||
.order_by(WorkflowTriggerLog.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
return self.session.scalar(query)
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs associated with the given workflow run ids.
|
||||
|
||||
@@ -110,6 +110,18 @@ class WorkflowTriggerLogRepository(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def get_by_workflow_run_id(self, workflow_run_id: str) -> WorkflowTriggerLog | None:
|
||||
"""
|
||||
Retrieve a trigger log associated with a specific workflow run.
|
||||
|
||||
Args:
|
||||
workflow_run_id: Identifier of the workflow run
|
||||
|
||||
Returns:
|
||||
The matching WorkflowTriggerLog if present, None otherwise
|
||||
"""
|
||||
...
|
||||
|
||||
def delete_by_run_ids(self, run_ids: Sequence[str]) -> int:
|
||||
"""
|
||||
Delete trigger logs for workflow run IDs.
|
||||
|
||||
Reference in New Issue
Block a user