feat: Human Input Node (#32060)

The frontend and backend implementation for the human input node.

Co-authored-by: twwu <twwu@dify.ai>
Co-authored-by: JzoNg <jzongcode@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
QuantumGhost
2026-02-09 14:57:23 +08:00
committed by GitHub
parent 56e3a55023
commit a1fc280102
474 changed files with 32667 additions and 2050 deletions

View File

@@ -0,0 +1,65 @@
from unittest.mock import MagicMock
import services.app_generate_service as app_generate_service_module
from models.model import AppMode
from services.app_generate_service import AppGenerateService
class _DummyRateLimit:
def __init__(self, client_id: str, max_active_requests: int) -> None:
self.client_id = client_id
self.max_active_requests = max_active_requests
@staticmethod
def gen_request_key() -> str:
return "dummy-request-id"
def enter(self, request_id: str | None = None) -> str:
return request_id or "dummy-request-id"
def exit(self, request_id: str) -> None:
return None
def generate(self, generator, request_id: str):
return generator
def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
workflow = MagicMock()
workflow.id = "workflow-id"
workflow.created_by = "owner-id"
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
generator_spy = mocker.patch(
"services.app_generate_service.WorkflowAppGenerator.generate",
return_value={"result": "ok"},
)
app_model = MagicMock()
app_model.mode = AppMode.WORKFLOW
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args={"inputs": {"k": "v"}},
invoke_from=MagicMock(),
streaming=False,
)
assert result == {"result": "ok"}
call_kwargs = generator_spy.call_args.kwargs
pause_state_config = call_kwargs.get("pause_state_config")
assert pause_state_config is not None
assert pause_state_config.state_owner_user_id == "owner-id"

View File

@@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
within conversations.
"""
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_without_first_id(
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
):
"""
Test message pagination without specifying first_id.
@@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method without first_id
result = MessageService.pagination_by_first_id(
@@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
# Verify conversation was looked up with correct parameters
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with first_id specified.
@@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.first.return_value = first_message # First message returned
mock_query.all.return_value = messages # Remaining messages returned
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act - Call the pagination method with first_id
result = MessageService.pagination_by_first_id(
@@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
assert result.data == []
assert result.has_more is False
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test that has_more flag is correctly set when there are more messages.
@@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(
@@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
assert len(result.data) == limit # Extra message should be removed
assert result.has_more is True # Flag should be set
@patch("services.message_service._create_execution_extra_content_repository")
@patch("services.message_service.db.session")
@patch("services.message_service.ConversationService.get_conversation")
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
"""
Test message pagination with ascending order.
@@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
mock_query.all.return_value = messages # Final .all() returns the messages
mock_repository = MagicMock()
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
mock_create_extra_repo.return_value = mock_repository
# Act
result = MessageService.pagination_by_first_id(

View File

@@ -0,0 +1,104 @@
from dataclasses import dataclass
import pytest
from enums.cloud_plan import CloudPlan
from services import feature_service as feature_service_module
from services.feature_service import FeatureModel, FeatureService
@dataclass(frozen=True)
class HumanInputEmailDeliveryCase:
name: str
enterprise_enabled: bool
billing_enabled: bool
tenant_id: str | None
billing_feature_enabled: bool
plan: str
expected: bool
CASES = [
HumanInputEmailDeliveryCase(
name="enterprise_enabled",
enterprise_enabled=True,
billing_enabled=True,
tenant_id=None,
billing_feature_enabled=False,
plan=CloudPlan.SANDBOX,
expected=True,
),
HumanInputEmailDeliveryCase(
name="billing_disabled",
enterprise_enabled=False,
billing_enabled=False,
tenant_id=None,
billing_feature_enabled=False,
plan=CloudPlan.SANDBOX,
expected=True,
),
HumanInputEmailDeliveryCase(
name="billing_enabled_requires_tenant",
enterprise_enabled=False,
billing_enabled=True,
tenant_id=None,
billing_feature_enabled=True,
plan=CloudPlan.PROFESSIONAL,
expected=False,
),
HumanInputEmailDeliveryCase(
name="billing_feature_off",
enterprise_enabled=False,
billing_enabled=True,
tenant_id="tenant-1",
billing_feature_enabled=False,
plan=CloudPlan.PROFESSIONAL,
expected=False,
),
HumanInputEmailDeliveryCase(
name="professional_plan",
enterprise_enabled=False,
billing_enabled=True,
tenant_id="tenant-1",
billing_feature_enabled=True,
plan=CloudPlan.PROFESSIONAL,
expected=True,
),
HumanInputEmailDeliveryCase(
name="team_plan",
enterprise_enabled=False,
billing_enabled=True,
tenant_id="tenant-1",
billing_feature_enabled=True,
plan=CloudPlan.TEAM,
expected=True,
),
HumanInputEmailDeliveryCase(
name="sandbox_plan",
enterprise_enabled=False,
billing_enabled=True,
tenant_id="tenant-1",
billing_feature_enabled=True,
plan=CloudPlan.SANDBOX,
expected=False,
),
]
@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name)
def test_resolve_human_input_email_delivery_enabled_matrix(
monkeypatch: pytest.MonkeyPatch,
case: HumanInputEmailDeliveryCase,
):
monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled)
monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled)
features = FeatureModel()
features.billing.enabled = case.billing_feature_enabled
features.billing.subscription.plan = case.plan
result = FeatureService._resolve_human_input_email_delivery_enabled(
features=features,
tenant_id=case.tenant_id,
)
assert result is case.expected

View File

@@ -0,0 +1,97 @@
from types import SimpleNamespace
import pytest
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
)
from core.workflow.runtime import VariablePool
from services import human_input_delivery_test_service as service_module
from services.human_input_delivery_test_service import (
DeliveryTestContext,
DeliveryTestError,
EmailDeliveryTestHandler,
)
def _make_email_method() -> EmailDeliveryMethod:
return EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=False,
items=[ExternalRecipient(email="tester@example.com")],
),
subject="Test subject",
body="Test body",
)
)
def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
)
handler = EmailDeliveryTestHandler(session_factory=object())
context = DeliveryTestContext(
tenant_id="tenant-1",
app_id="app-1",
node_id="node-1",
node_title="Human Input",
rendered_content="content",
)
method = _make_email_method()
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
handler.send_test(context=context, method=method)
def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
class DummyMail:
def __init__(self):
self.sent: list[dict[str, str]] = []
def is_inited(self) -> bool:
return True
def send(self, *, to: str, subject: str, html: str):
self.sent.append({"to": to, "subject": subject, "html": html})
mail = DummyMail()
monkeypatch.setattr(service_module, "mail", mail)
monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template)
monkeypatch.setattr(
service_module.FeatureService,
"get_features",
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
)
handler = EmailDeliveryTestHandler(session_factory=object())
handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment]
method = EmailDeliveryMethod(
config=EmailDeliveryConfig(
recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]),
subject="Subject",
body="Value {{#node1.value#}}",
)
)
variable_pool = VariablePool()
variable_pool.add(["node1", "value"], "OK")
context = DeliveryTestContext(
tenant_id="tenant-1",
app_id="app-1",
node_id="node-1",
node_title="Human Input",
rendered_content="content",
variable_pool=variable_pool,
)
handler.send_test(context=context, method=method)
assert mail.sent[0]["html"] == "Value OK"

View File

@@ -0,0 +1,290 @@
import dataclasses
from datetime import datetime, timedelta
from unittest.mock import MagicMock
import pytest
import services.human_input_service as human_input_service_module
from core.repositories.human_input_repository import (
HumanInputFormRecord,
HumanInputFormSubmissionRepository,
)
from core.workflow.nodes.human_input.entities import (
FormDefinition,
FormInput,
UserAction,
)
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
from models.human_input import RecipientType
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
@pytest.fixture
def mock_session_factory():
session = MagicMock()
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = None
factory = MagicMock()
factory.return_value = session_cm
return factory, session
@pytest.fixture
def sample_form_record():
return HumanInputFormRecord(
form_id="form-id",
workflow_run_id="workflow-run-id",
node_id="node-id",
tenant_id="tenant-id",
app_id="app-id",
form_kind=HumanInputFormKind.RUNTIME,
definition=FormDefinition(
form_content="hello",
inputs=[],
user_actions=[UserAction(id="submit", title="Submit")],
rendered_content="<p>hello</p>",
expiration_time=datetime.utcnow() + timedelta(hours=1),
),
rendered_content="<p>hello</p>",
created_at=datetime.utcnow(),
expiration_time=datetime.utcnow() + timedelta(hours=1),
status=HumanInputFormStatus.WAITING,
selected_action_id=None,
submitted_data=None,
submitted_at=None,
submission_user_id=None,
submission_end_user_id=None,
completed_by_recipient_id=None,
recipient_id="recipient-id",
recipient_type=RecipientType.STANDALONE_WEB_APP,
access_token="token",
)
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "workflow"
session.execute.return_value.scalar_one_or_none.return_value = app
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service.enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
service = HumanInputService(session_factory)
expired_record = dataclasses.replace(
sample_form_record,
created_at=datetime.utcnow() - timedelta(hours=2),
expiration_time=datetime.utcnow() + timedelta(hours=2),
)
monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
with pytest.raises(FormExpiredError):
service.ensure_form_active(Form(expired_record))
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "advanced-chat"
session.execute.return_value.scalar_one_or_none.return_value = app
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service.enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_called_once()
call_kwargs = resume_task.apply_async.call_args.kwargs
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
session_factory, session = mock_session_factory
service = HumanInputService(session_factory)
workflow_run = MagicMock()
workflow_run.app_id = "app-id"
workflow_run_repo = MagicMock()
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
mocker.patch(
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
return_value=workflow_run_repo,
)
app = MagicMock()
app.mode = "completion"
session.execute.return_value.scalar_one_or_none.return_value = app
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
service.enqueue_resume("workflow-run-id")
resume_task.apply_async.assert_not_called()
def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE)
repo.get_by_token.return_value = console_record
service = HumanInputService(session_factory, form_repository=repo)
form = service.get_form_definition_by_token_for_console("token")
repo.get_by_token.assert_called_once_with("token")
assert form is not None
assert form.get_definition() == console_record.definition
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
repo.get_by_token.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
service.submit_form_by_token(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token",
selected_action_id="submit",
form_data={"field": "value"},
submission_end_user_id="end-user-id",
)
repo.get_by_token.assert_called_once_with("token")
repo.mark_submitted.assert_called_once()
call_kwargs = repo.mark_submitted.call_args.kwargs
assert call_kwargs["form_id"] == sample_form_record.form_id
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
assert call_kwargs["selected_action_id"] == "submit"
assert call_kwargs["form_data"] == {"field": "value"}
assert call_kwargs["submission_end_user_id"] == "end-user-id"
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
test_record = dataclasses.replace(
sample_form_record,
form_kind=HumanInputFormKind.DELIVERY_TEST,
workflow_run_id=None,
)
repo.get_by_token.return_value = test_record
repo.mark_submitted.return_value = test_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
service.submit_form_by_token(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token",
selected_action_id="submit",
form_data={"field": "value"},
)
enqueue_spy.assert_not_called()
def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
repo.get_by_token.return_value = sample_form_record
repo.mark_submitted.return_value = sample_form_record
service = HumanInputService(session_factory, form_repository=repo)
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
service.submit_form_by_token(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token",
selected_action_id="submit",
form_data={"field": "value"},
submission_user_id="account-id",
)
call_kwargs = repo.mark_submitted.call_args.kwargs
assert call_kwargs["submission_user_id"] == "account-id"
assert call_kwargs["submission_end_user_id"] is None
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
repo.get_by_token.return_value = dataclasses.replace(sample_form_record)
service = HumanInputService(session_factory, form_repository=repo)
with pytest.raises(InvalidFormDataError) as exc_info:
service.submit_form_by_token(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token",
selected_action_id="invalid",
form_data={},
)
assert "Invalid action" in str(exc_info.value)
repo.mark_submitted.assert_not_called()
def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory):
session_factory, _ = mock_session_factory
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
definition_with_input = FormDefinition(
form_content="hello",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")],
user_actions=sample_form_record.definition.user_actions,
rendered_content="<p>hello</p>",
expiration_time=sample_form_record.expiration_time,
)
form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input)
repo.get_by_token.return_value = form_with_input
service = HumanInputService(session_factory, form_repository=repo)
with pytest.raises(InvalidFormDataError) as exc_info:
service.submit_form_by_token(
recipient_type=RecipientType.STANDALONE_WEB_APP,
form_token="token",
selected_action_id="submit",
form_data={},
)
assert "Missing required inputs" in str(exc_info.value)
repo.mark_submitted.assert_not_called()

View File

@@ -0,0 +1,61 @@
from __future__ import annotations
import pytest
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
from services import message_service
class _FakeMessage:
def __init__(self, message_id: str):
self.id = message_id
self.extra_contents = None
def set_extra_contents(self, contents):
self.extra_contents = contents
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
repo = type(
"Repo",
(),
{
"get_by_message_ids": lambda _self, message_ids: [
[
HumanInputContent(
workflow_run_id="workflow-run-1",
submitted=True,
form_submission_data=HumanInputFormSubmissionData(
node_id="node-1",
node_title="Approval",
rendered_content="Rendered",
action_id="approve",
action_text="Approve",
),
)
],
[],
]
},
)()
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
message_service.attach_message_extra_contents(messages)
assert messages[0].extra_contents == [
{
"type": "human_input",
"workflow_run_id": "workflow-run-1",
"submitted": True,
"form_submission_data": {
"node_id": "node-1",
"node_title": "Approval",
"rendered_content": "Rendered",
"action_id": "approve",
"action_text": "Approve",
},
}
]
assert messages[1].extra_contents == []

View File

@@ -35,7 +35,6 @@ class TestDataFactory:
app_id: str = "app-789",
workflow_id: str = "workflow-101",
status: str | WorkflowExecutionStatus = "paused",
pause_id: str | None = None,
**kwargs,
) -> MagicMock:
"""Create a mock WorkflowRun object."""
@@ -45,7 +44,6 @@ class TestDataFactory:
mock_run.app_id = app_id
mock_run.workflow_id = workflow_id
mock_run.status = status
mock_run.pause_id = pause_id
for key, value in kwargs.items():
setattr(mock_run, key, value)

View File

@@ -0,0 +1,162 @@
import json
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from models.model import App
from models.tools import WorkflowToolProvider
from services.tools import workflow_tools_manage_service
class DummyWorkflow:
def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
self._graph_dict = graph_dict
self.version = version
@property
def graph_dict(self) -> dict:
return self._graph_dict
class FakeQuery:
def __init__(self, result):
self._result = result
def where(self, *args, **kwargs):
return self
def first(self):
return self._result
class DummySession:
def __init__(self) -> None:
self.added: list[object] = []
def __enter__(self) -> "DummySession":
return self
def __exit__(self, exc_type, exc, tb) -> bool:
return False
def add(self, obj) -> None:
self.added.append(obj)
def begin(self):
return DummyBegin(self)
class DummyBegin:
def __init__(self, session: DummySession) -> None:
self._session = session
def __enter__(self) -> DummySession:
return self._session
def __exit__(self, exc_type, exc, tb) -> bool:
return False
class DummySessionContext:
def __init__(self, session: DummySession) -> None:
self._session = session
def __enter__(self) -> DummySession:
return self._session
def __exit__(self, exc_type, exc, tb) -> bool:
return False
class DummySessionFactory:
def __init__(self, session: DummySession) -> None:
self._session = session
def create_session(self) -> DummySessionContext:
return DummySessionContext(self._session)
def _build_fake_session(app) -> SimpleNamespace:
def query(model):
if model is WorkflowToolProvider:
return FakeQuery(None)
if model is App:
return FakeQuery(app)
return FakeQuery(None)
return SimpleNamespace(query=query)
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
return [
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
]
def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
app = SimpleNamespace(workflow=workflow)
fake_session = _build_fake_session(app)
monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
mock_from_db = MagicMock()
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
mock_invalidate = MagicMock()
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
user_id="user-id",
tenant_id="tenant-id",
workflow_app_id="app-id",
name="tool_name",
label="Tool",
icon={"type": "emoji", "emoji": "tool"},
description="desc",
parameters=_build_parameters(),
)
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
mock_from_db.assert_not_called()
mock_invalidate.assert_not_called()
def test_create_workflow_tool_success(monkeypatch):
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]})
app = SimpleNamespace(workflow=workflow)
fake_db = MagicMock()
fake_session = _build_fake_session(app)
fake_db.session = fake_session
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
dummy_session = DummySession()
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
mock_from_db = MagicMock()
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
icon = {"type": "emoji", "emoji": "tool"}
result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
user_id="user-id",
tenant_id="tenant-id",
workflow_app_id="app-id",
name="tool_name",
label="Tool",
icon=icon,
description="desc",
parameters=_build_parameters(),
)
assert result == {"result": "success"}
assert len(dummy_session.added) == 1
created_provider = dummy_session.added[0]
assert created_provider.name == "tool_name"
assert created_provider.label == "Tool"
assert created_provider.icon == json.dumps(icon)
assert created_provider.version == workflow.version
mock_from_db.assert_called_once()

View File

@@ -0,0 +1,226 @@
from __future__ import annotations
import json
import queue
from collections.abc import Sequence
from dataclasses import dataclass
from datetime import UTC, datetime
from threading import Event
import pytest
from core.app.app_config.entities import WorkflowUIBasedAppConfig
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import CreatorUserRole
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
from repositories.entities.workflow_pause import WorkflowPauseEntity
from services.workflow_event_snapshot_service import (
BufferState,
MessageContext,
_build_snapshot_events,
_resolve_task_id,
)
@dataclass(frozen=True)
class _FakePauseEntity(WorkflowPauseEntity):
pause_id: str
workflow_run_id: str
paused_at_value: datetime
pause_reasons: Sequence[HumanInputRequired]
@property
def id(self) -> str:
return self.pause_id
@property
def workflow_execution_id(self) -> str:
return self.workflow_run_id
def get_state(self) -> bytes:
raise AssertionError("state is not required for snapshot tests")
@property
def resumed_at(self) -> datetime | None:
return None
@property
def paused_at(self) -> datetime:
return self.paused_at_value
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
return self.pause_reasons
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
return WorkflowRun(
id="run-1",
tenant_id="tenant-1",
app_id="app-1",
workflow_id="workflow-1",
type="workflow",
triggered_from="app-run",
version="v1",
graph=None,
inputs=json.dumps({"input": "value"}),
status=status,
outputs=json.dumps({}),
error=None,
elapsed_time=0.0,
total_tokens=0,
total_steps=0,
created_by_role=CreatorUserRole.END_USER,
created_by="user-1",
created_at=datetime(2024, 1, 1, tzinfo=UTC),
)
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
created_at = datetime(2024, 1, 1, tzinfo=UTC)
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
return WorkflowNodeExecutionSnapshot(
execution_id="exec-1",
node_id="node-1",
node_type="human-input",
title="Human Input",
index=1,
status=status.value,
elapsed_time=0.5,
created_at=created_at,
finished_at=finished_at,
iteration_id=None,
loop_id=None,
)
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
app_config = WorkflowUIBasedAppConfig(
tenant_id="tenant-1",
app_id="app-1",
app_mode=AppMode.WORKFLOW,
workflow_id="workflow-1",
)
generate_entity = WorkflowAppGenerateEntity(
task_id=task_id,
app_config=app_config,
inputs={},
files=[],
user_id="user-1",
stream=True,
invoke_from=InvokeFrom.EXPLORE,
call_depth=0,
workflow_execution_id="run-1",
)
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
runtime_state.register_paused_node("node-1")
runtime_state.outputs = {"result": "value"}
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
return WorkflowResumptionContext(
generate_entity=wrapper,
serialized_graph_runtime_state=runtime_state.dumps(),
)
def test_build_snapshot_events_includes_pause_event() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
resumption_context = _build_resumption_context("task-ctx")
pause_entity = _FakePauseEntity(
pause_id="pause-1",
workflow_run_id="run-1",
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
pause_reasons=[
HumanInputRequired(
form_id="form-1",
form_content="content",
node_id="node-1",
node_title="Human Input",
)
],
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-ctx",
message_context=None,
pause_entity=pause_entity,
resumption_context=resumption_context,
)
assert [event["event"] for event in events] == [
"workflow_started",
"node_started",
"node_finished",
"workflow_paused",
]
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
pause_data = events[-1]["data"]
assert pause_data["paused_nodes"] == ["node-1"]
assert pause_data["outputs"] == {"result": "value"}
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
assert pause_data["total_tokens"] == workflow_run.total_tokens
assert pause_data["total_steps"] == workflow_run.total_steps
def test_build_snapshot_events_applies_message_context() -> None:
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
message_context = MessageContext(
conversation_id="conv-1",
message_id="msg-1",
created_at=1700000000,
answer="snapshot message",
)
events = _build_snapshot_events(
workflow_run=workflow_run,
node_snapshots=[snapshot],
task_id="task-1",
message_context=message_context,
pause_entity=None,
resumption_context=None,
)
assert [event["event"] for event in events] == [
"workflow_started",
"message_replace",
"node_started",
"node_finished",
]
assert events[1]["answer"] == "snapshot message"
for event in events:
assert event["conversation_id"] == "conv-1"
assert event["message_id"] == "msg-1"
assert event["created_at"] == 1700000000
@pytest.mark.parametrize(
("context_task_id", "buffered_task_id", "expected"),
[
("task-ctx", "task-buffer", "task-ctx"),
(None, "task-buffer", "task-buffer"),
(None, None, "run-1"),
],
)
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
buffer_state = BufferState(
queue=queue.Queue(),
stop_event=Event(),
done_event=Event(),
task_id_ready=Event(),
task_id_hint=buffered_task_id,
)
if buffered_task_id:
buffer_state.task_id_ready.set()
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
assert task_id == expected

View File

@@ -0,0 +1,184 @@
import uuid
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from sqlalchemy.orm import sessionmaker
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import (
EmailDeliveryConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
HumanInputNodeData,
MemberRecipient,
)
from services import workflow_service as workflow_service_module
from services.workflow_service import WorkflowService
def _make_service() -> WorkflowService:
return WorkflowService(session_maker=sessionmaker())
def _build_node_config(delivery_methods):
node_data = HumanInputNodeData(
title="Human Input",
delivery_methods=delivery_methods,
form_content="Test content",
inputs=[],
user_actions=[],
).model_dump(mode="json")
node_data["type"] = NodeType.HUMAN_INPUT.value
return {"id": "node-1", "data": node_data}
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:
return EmailDeliveryMethod(
id=uuid.uuid4(),
enabled=enabled,
config=EmailDeliveryConfig(
recipients=EmailRecipients(
whole_workspace=False,
items=[ExternalRecipient(email="tester@example.com")],
),
subject="Test subject",
body="Test body",
debug_mode=debug_mode,
),
)
def test_human_input_delivery_requires_draft_workflow():
service = _make_service()
service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign]
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
account = SimpleNamespace(id="account-1")
with pytest.raises(ValueError, match="Workflow not initialized"):
service.test_human_input_delivery(
app_model=app_model,
account=account,
node_id="node-1",
delivery_method_id="delivery-1",
)
def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch):
service = _make_service()
delivery_method = _make_email_method(enabled=False)
node_config = _build_node_config([delivery_method])
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = node_config
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
node_stub = MagicMock()
node_stub._render_form_content_before_submission.return_value = "rendered"
node_stub._resolve_default_values.return_value = {}
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
return_value=("form-1", {})
)
test_service_instance = MagicMock()
monkeypatch.setattr(
workflow_service_module,
"HumanInputDeliveryTestService",
MagicMock(return_value=test_service_instance),
)
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
account = SimpleNamespace(id="account-1")
service.test_human_input_delivery(
app_model=app_model,
account=account,
node_id="node-1",
delivery_method_id=str(delivery_method.id),
)
test_service_instance.send_test.assert_called_once()
def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch):
service = _make_service()
delivery_method = _make_email_method(enabled=True)
node_config = _build_node_config([delivery_method])
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = node_config
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
node_stub = MagicMock()
node_stub._render_form_content_before_submission.return_value = "rendered"
node_stub._resolve_default_values.return_value = {}
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
return_value=("form-1", {})
)
test_service_instance = MagicMock()
monkeypatch.setattr(
workflow_service_module,
"HumanInputDeliveryTestService",
MagicMock(return_value=test_service_instance),
)
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
account = SimpleNamespace(id="account-1")
service.test_human_input_delivery(
app_model=app_model,
account=account,
node_id="node-1",
delivery_method_id=str(delivery_method.id),
inputs={"#node-1.output#": "value"},
)
pool_args = service._build_human_input_variable_pool.call_args.kwargs
assert pool_args["manual_inputs"] == {"#node-1.output#": "value"}
test_service_instance.send_test.assert_called_once()
def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch):
service = _make_service()
delivery_method = _make_email_method(enabled=True, debug_mode=True)
node_config = _build_node_config([delivery_method])
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = node_config
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
node_stub = MagicMock()
node_stub._render_form_content_before_submission.return_value = "rendered"
node_stub._resolve_default_values.return_value = {}
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
return_value=("form-1", {})
)
test_service_instance = MagicMock()
monkeypatch.setattr(
workflow_service_module,
"HumanInputDeliveryTestService",
MagicMock(return_value=test_service_instance),
)
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
account = SimpleNamespace(id="account-1")
service.test_human_input_delivery(
app_model=app_model,
account=account,
node_id="node-1",
delivery_method_id=str(delivery_method.id),
)
test_service_instance.send_test.assert_called_once()
sent_method = test_service_instance.send_test.call_args.kwargs["method"]
assert isinstance(sent_method, EmailDeliveryMethod)
assert sent_method.config.debug_mode is True
assert sent_method.config.recipients.whole_workspace is False
assert len(sent_method.config.recipients.items) == 1
recipient = sent_method.config.recipients.items[0]
assert isinstance(recipient, MemberRecipient)
assert recipient.user_id == account.id

View File

@@ -5,6 +5,7 @@ from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowNodeExecutionStatus
from models.workflow import WorkflowNodeExecutionModel
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
@@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
call_args = mock_session.scalar.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
compiled = call_args.compile()
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
def test_get_node_last_execution_not_found(self, repository):
"""Test getting the last execution for a node when it doesn't exist."""
# Arrange
@@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
assert result is None
mock_session.scalar.assert_called_once()
def test_get_executions_by_workflow_run(self, repository, mock_execution):
"""Test getting all executions for a workflow run."""
# Arrange
mock_session = MagicMock(spec=Session)
repository._session_maker.return_value.__enter__.return_value = mock_session
executions = [mock_execution]
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
# Act
result = repository.get_executions_by_workflow_run(
tenant_id="tenant-123",
app_id="app-456",
workflow_run_id="run-101",
)
# Assert
assert result == executions
mock_session.execute.assert_called_once()
# Verify the query was constructed correctly
call_args = mock_session.execute.call_args[0][0]
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
def test_get_executions_by_workflow_run_empty(self, repository):
"""Test getting executions for a workflow run when none exist."""
# Arrange

View File

@@ -1,9 +1,15 @@
from contextlib import nullcontext
from types import SimpleNamespace
from unittest.mock import MagicMock
import pytest
from core.workflow.enums import NodeType
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
from core.workflow.nodes.human_input.enums import FormInputType
from models.model import App
from models.workflow import Workflow
from services import workflow_service as workflow_service_module
from services.workflow_service import WorkflowService
@@ -161,3 +167,120 @@ class TestWorkflowService:
assert workflows == []
assert has_more is False
mock_session.scalars.assert_called_once()
def test_submit_human_input_form_preview_uses_rendered_content(
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
) -> None:
service = workflow_service
node_data = HumanInputNodeData(
title="Human Input",
form_content="<p>{{#$output.name#}}</p>",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
)
node = MagicMock()
node.node_data = node_data
node.render_form_content_before_submission.return_value = "<p>preview</p>"
node.render_form_content_with_outputs.return_value = "<p>rendered</p>"
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
workflow.get_enclosing_node_type_and_id.return_value = None
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
saved_outputs: dict[str, object] = {}
class DummySession:
def __init__(self, *args, **kwargs):
self.commit = MagicMock()
def __enter__(self):
return self
def __exit__(self, exc_type, exc, tb):
return False
def begin(self):
return nullcontext()
class DummySaver:
def __init__(self, *args, **kwargs):
pass
def save(self, outputs, process_data):
saved_outputs.update(outputs)
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
result = service.submit_human_input_form_preview(
app_model=app_model,
account=account,
node_id="node-1",
form_inputs={"name": "Ada", "extra": "ignored"},
inputs={"#node-0.result#": "LLM output"},
action="approve",
)
service._build_human_input_variable_pool.assert_called_once_with(
app_model=app_model,
workflow=workflow,
node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
manual_inputs={"#node-0.result#": "LLM output"},
)
node.render_form_content_with_outputs.assert_called_once()
called_args = node.render_form_content_with_outputs.call_args.args
assert called_args[0] == "<p>preview</p>"
assert called_args[2] == node_data.outputs_field_names()
rendered_outputs = called_args[1]
assert rendered_outputs["name"] == "Ada"
assert rendered_outputs["extra"] == "ignored"
assert "extra" in saved_outputs
assert "extra" in result
assert saved_outputs["name"] == "Ada"
assert result["name"] == "Ada"
assert result["__action_id"] == "approve"
assert "__rendered_content" in result
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
service = workflow_service
node_data = HumanInputNodeData(
title="Human Input",
form_content="<p>{{#$output.name#}}</p>",
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
user_actions=[UserAction(id="approve", title="Approve")],
)
node = MagicMock()
node.node_data = node_data
node._render_form_content_before_submission.return_value = "<p>preview</p>"
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
workflow = MagicMock()
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
account = SimpleNamespace(id="account-1")
with pytest.raises(ValueError) as exc_info:
service.submit_human_input_form_preview(
app_model=app_model,
account=account,
node_id="node-1",
form_inputs={},
inputs={},
action="approve",
)
assert "Missing required inputs" in str(exc_info.value)