mirror of
https://github.com/langgenius/dify.git
synced 2026-02-25 02:35:12 +00:00
feat: Human Input Node (#32060)
The frontend and backend implementation for the human input node. Co-authored-by: twwu <twwu@dify.ai> Co-authored-by: JzoNg <jzongcode@gmail.com> Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com> Co-authored-by: zhsama <torvalds@linux.do>
This commit is contained in:
65
api/tests/unit_tests/services/test_app_generate_service.py
Normal file
65
api/tests/unit_tests/services/test_app_generate_service.py
Normal file
@@ -0,0 +1,65 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import services.app_generate_service as app_generate_service_module
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
|
||||
class _DummyRateLimit:
|
||||
def __init__(self, client_id: str, max_active_requests: int) -> None:
|
||||
self.client_id = client_id
|
||||
self.max_active_requests = max_active_requests
|
||||
|
||||
@staticmethod
|
||||
def gen_request_key() -> str:
|
||||
return "dummy-request-id"
|
||||
|
||||
def enter(self, request_id: str | None = None) -> str:
|
||||
return request_id or "dummy-request-id"
|
||||
|
||||
def exit(self, request_id: str) -> None:
|
||||
return None
|
||||
|
||||
def generate(self, generator, request_id: str):
|
||||
return generator
|
||||
|
||||
|
||||
def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
workflow.created_by = "owner-id"
|
||||
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
generator_spy = mocker.patch(
|
||||
"services.app_generate_service.WorkflowAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.WORKFLOW
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args={"inputs": {"k": "v"}},
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
assert result == {"result": "ok"}
|
||||
|
||||
call_kwargs = generator_spy.call_args.kwargs
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert pause_state_config is not None
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
@@ -508,9 +508,12 @@ class TestConversationServiceMessageCreation:
|
||||
within conversations.
|
||||
"""
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_without_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_without_first_id(
|
||||
self, mock_get_conversation, mock_db_session, mock_create_extra_repo
|
||||
):
|
||||
"""
|
||||
Test message pagination without specifying first_id.
|
||||
|
||||
@@ -540,6 +543,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method without first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -556,9 +562,10 @@ class TestConversationServiceMessageCreation:
|
||||
# Verify conversation was looked up with correct parameters
|
||||
mock_get_conversation.assert_called_once_with(app_model=app_model, user=user, conversation_id=conversation.id)
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_by_first_id_with_first_id(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with first_id specified.
|
||||
|
||||
@@ -590,6 +597,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.first.return_value = first_message # First message returned
|
||||
mock_query.all.return_value = messages # Remaining messages returned
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act - Call the pagination method with first_id
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -684,9 +694,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert result.data == []
|
||||
assert result.has_more is False
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_has_more_flag(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test that has_more flag is correctly set when there are more messages.
|
||||
|
||||
@@ -716,6 +727,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
@@ -730,9 +744,10 @@ class TestConversationServiceMessageCreation:
|
||||
assert len(result.data) == limit # Extra message should be removed
|
||||
assert result.has_more is True # Flag should be set
|
||||
|
||||
@patch("services.message_service._create_execution_extra_content_repository")
|
||||
@patch("services.message_service.db.session")
|
||||
@patch("services.message_service.ConversationService.get_conversation")
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session):
|
||||
def test_pagination_with_ascending_order(self, mock_get_conversation, mock_db_session, mock_create_extra_repo):
|
||||
"""
|
||||
Test message pagination with ascending order.
|
||||
|
||||
@@ -761,6 +776,9 @@ class TestConversationServiceMessageCreation:
|
||||
mock_query.order_by.return_value = mock_query # ORDER BY returns self for chaining
|
||||
mock_query.limit.return_value = mock_query # LIMIT returns self for chaining
|
||||
mock_query.all.return_value = messages # Final .all() returns the messages
|
||||
mock_repository = MagicMock()
|
||||
mock_repository.get_by_message_ids.return_value = [[] for _ in messages]
|
||||
mock_create_extra_repo.return_value = mock_repository
|
||||
|
||||
# Act
|
||||
result = MessageService.pagination_by_first_id(
|
||||
|
||||
@@ -0,0 +1,104 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services import feature_service as feature_service_module
|
||||
from services.feature_service import FeatureModel, FeatureService
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class HumanInputEmailDeliveryCase:
|
||||
name: str
|
||||
enterprise_enabled: bool
|
||||
billing_enabled: bool
|
||||
tenant_id: str | None
|
||||
billing_feature_enabled: bool
|
||||
plan: str
|
||||
expected: bool
|
||||
|
||||
|
||||
CASES = [
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="enterprise_enabled",
|
||||
enterprise_enabled=True,
|
||||
billing_enabled=True,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_disabled",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=False,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_enabled_requires_tenant",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id=None,
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=False,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="billing_feature_off",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=False,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=False,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="professional_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.PROFESSIONAL,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="team_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.TEAM,
|
||||
expected=True,
|
||||
),
|
||||
HumanInputEmailDeliveryCase(
|
||||
name="sandbox_plan",
|
||||
enterprise_enabled=False,
|
||||
billing_enabled=True,
|
||||
tenant_id="tenant-1",
|
||||
billing_feature_enabled=True,
|
||||
plan=CloudPlan.SANDBOX,
|
||||
expected=False,
|
||||
),
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.parametrize("case", CASES, ids=lambda case: case.name)
|
||||
def test_resolve_human_input_email_delivery_enabled_matrix(
|
||||
monkeypatch: pytest.MonkeyPatch,
|
||||
case: HumanInputEmailDeliveryCase,
|
||||
):
|
||||
monkeypatch.setattr(feature_service_module.dify_config, "ENTERPRISE_ENABLED", case.enterprise_enabled)
|
||||
monkeypatch.setattr(feature_service_module.dify_config, "BILLING_ENABLED", case.billing_enabled)
|
||||
features = FeatureModel()
|
||||
features.billing.enabled = case.billing_feature_enabled
|
||||
features.billing.subscription.plan = case.plan
|
||||
|
||||
result = FeatureService._resolve_human_input_email_delivery_enabled(
|
||||
features=features,
|
||||
tenant_id=case.tenant_id,
|
||||
)
|
||||
|
||||
assert result is case.expected
|
||||
@@ -0,0 +1,97 @@
|
||||
from types import SimpleNamespace
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
)
|
||||
from core.workflow.runtime import VariablePool
|
||||
from services import human_input_delivery_test_service as service_module
|
||||
from services.human_input_delivery_test_service import (
|
||||
DeliveryTestContext,
|
||||
DeliveryTestError,
|
||||
EmailDeliveryTestHandler,
|
||||
)
|
||||
|
||||
|
||||
def _make_email_method() -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="tester@example.com")],
|
||||
),
|
||||
subject="Test subject",
|
||||
body="Test body",
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_rejects_when_feature_disabled(monkeypatch: pytest.MonkeyPatch):
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=False),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
)
|
||||
method = _make_email_method()
|
||||
|
||||
with pytest.raises(DeliveryTestError, match="Email delivery is not available"):
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
|
||||
def test_email_delivery_test_handler_replaces_body_variables(monkeypatch: pytest.MonkeyPatch):
|
||||
class DummyMail:
|
||||
def __init__(self):
|
||||
self.sent: list[dict[str, str]] = []
|
||||
|
||||
def is_inited(self) -> bool:
|
||||
return True
|
||||
|
||||
def send(self, *, to: str, subject: str, html: str):
|
||||
self.sent.append({"to": to, "subject": subject, "html": html})
|
||||
|
||||
mail = DummyMail()
|
||||
monkeypatch.setattr(service_module, "mail", mail)
|
||||
monkeypatch.setattr(service_module, "render_email_template", lambda template, _substitutions: template)
|
||||
monkeypatch.setattr(
|
||||
service_module.FeatureService,
|
||||
"get_features",
|
||||
lambda _tenant_id: SimpleNamespace(human_input_email_delivery_enabled=True),
|
||||
)
|
||||
|
||||
handler = EmailDeliveryTestHandler(session_factory=object())
|
||||
handler._resolve_recipients = lambda **_kwargs: ["tester@example.com"] # type: ignore[assignment]
|
||||
|
||||
method = EmailDeliveryMethod(
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(whole_workspace=False, items=[ExternalRecipient(email="tester@example.com")]),
|
||||
subject="Subject",
|
||||
body="Value {{#node1.value#}}",
|
||||
)
|
||||
)
|
||||
variable_pool = VariablePool()
|
||||
variable_pool.add(["node1", "value"], "OK")
|
||||
context = DeliveryTestContext(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
rendered_content="content",
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
handler.send_test(context=context, method=method)
|
||||
|
||||
assert mail.sent[0]["html"] == "Value OK"
|
||||
290
api/tests/unit_tests/services/test_human_input_service.py
Normal file
290
api/tests/unit_tests/services/test_human_input_service.py
Normal file
@@ -0,0 +1,290 @@
|
||||
import dataclasses
|
||||
from datetime import datetime, timedelta
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
import services.human_input_service as human_input_service_module
|
||||
from core.repositories.human_input_repository import (
|
||||
HumanInputFormRecord,
|
||||
HumanInputFormSubmissionRepository,
|
||||
)
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
FormDefinition,
|
||||
FormInput,
|
||||
UserAction,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import FormInputType, HumanInputFormKind, HumanInputFormStatus
|
||||
from models.human_input import RecipientType
|
||||
from services.human_input_service import Form, FormExpiredError, HumanInputService, InvalidFormDataError
|
||||
from tasks.app_generate.workflow_execute_task import WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_session_factory():
|
||||
session = MagicMock()
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = None
|
||||
|
||||
factory = MagicMock()
|
||||
factory.return_value = session_cm
|
||||
return factory, session
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_form_record():
|
||||
return HumanInputFormRecord(
|
||||
form_id="form-id",
|
||||
workflow_run_id="workflow-run-id",
|
||||
node_id="node-id",
|
||||
tenant_id="tenant-id",
|
||||
app_id="app-id",
|
||||
form_kind=HumanInputFormKind.RUNTIME,
|
||||
definition=FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[],
|
||||
user_actions=[UserAction(id="submit", title="Submit")],
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=1),
|
||||
),
|
||||
rendered_content="<p>hello</p>",
|
||||
created_at=datetime.utcnow(),
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=1),
|
||||
status=HumanInputFormStatus.WAITING,
|
||||
selected_action_id=None,
|
||||
submitted_data=None,
|
||||
submitted_at=None,
|
||||
submission_user_id=None,
|
||||
submission_end_user_id=None,
|
||||
completed_by_recipient_id=None,
|
||||
recipient_id="recipient-id",
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
access_token="token",
|
||||
)
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_workflow(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "workflow"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_ensure_form_active_respects_global_timeout(monkeypatch, sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
expired_record = dataclasses.replace(
|
||||
sample_form_record,
|
||||
created_at=datetime.utcnow() - timedelta(hours=2),
|
||||
expiration_time=datetime.utcnow() + timedelta(hours=2),
|
||||
)
|
||||
monkeypatch.setattr(human_input_service_module.dify_config, "HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS", 3600)
|
||||
|
||||
with pytest.raises(FormExpiredError):
|
||||
service.ensure_form_active(Form(expired_record))
|
||||
|
||||
|
||||
def test_enqueue_resume_dispatches_task_for_advanced_chat(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "advanced-chat"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_called_once()
|
||||
call_kwargs = resume_task.apply_async.call_args.kwargs
|
||||
assert call_kwargs["queue"] == WORKFLOW_BASED_APP_EXECUTION_QUEUE
|
||||
assert call_kwargs["kwargs"]["payload"]["workflow_run_id"] == "workflow-run-id"
|
||||
|
||||
|
||||
def test_enqueue_resume_skips_unsupported_app_mode(mocker, mock_session_factory):
|
||||
session_factory, session = mock_session_factory
|
||||
service = HumanInputService(session_factory)
|
||||
|
||||
workflow_run = MagicMock()
|
||||
workflow_run.app_id = "app-id"
|
||||
|
||||
workflow_run_repo = MagicMock()
|
||||
workflow_run_repo.get_workflow_run_by_id_without_tenant.return_value = workflow_run
|
||||
mocker.patch(
|
||||
"services.human_input_service.DifyAPIRepositoryFactory.create_api_workflow_run_repository",
|
||||
return_value=workflow_run_repo,
|
||||
)
|
||||
|
||||
app = MagicMock()
|
||||
app.mode = "completion"
|
||||
session.execute.return_value.scalar_one_or_none.return_value = app
|
||||
|
||||
resume_task = mocker.patch("services.human_input_service.resume_app_execution")
|
||||
|
||||
service.enqueue_resume("workflow-run-id")
|
||||
|
||||
resume_task.apply_async.assert_not_called()
|
||||
|
||||
|
||||
def test_get_form_definition_by_token_for_console_uses_repository(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
console_record = dataclasses.replace(sample_form_record, recipient_type=RecipientType.CONSOLE)
|
||||
repo.get_by_token.return_value = console_record
|
||||
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
form = service.get_form_definition_by_token_for_console("token")
|
||||
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
assert form is not None
|
||||
assert form.get_definition() == console_record.definition
|
||||
|
||||
|
||||
def test_submit_form_by_token_calls_repository_and_enqueue(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
submission_end_user_id="end-user-id",
|
||||
)
|
||||
|
||||
repo.get_by_token.assert_called_once_with("token")
|
||||
repo.mark_submitted.assert_called_once()
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["form_id"] == sample_form_record.form_id
|
||||
assert call_kwargs["recipient_id"] == sample_form_record.recipient_id
|
||||
assert call_kwargs["selected_action_id"] == "submit"
|
||||
assert call_kwargs["form_data"] == {"field": "value"}
|
||||
assert call_kwargs["submission_end_user_id"] == "end-user-id"
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_token_skips_enqueue_for_delivery_test(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
test_record = dataclasses.replace(
|
||||
sample_form_record,
|
||||
form_kind=HumanInputFormKind.DELIVERY_TEST,
|
||||
workflow_run_id=None,
|
||||
)
|
||||
repo.get_by_token.return_value = test_record
|
||||
repo.mark_submitted.return_value = test_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
)
|
||||
|
||||
enqueue_spy.assert_not_called()
|
||||
|
||||
|
||||
def test_submit_form_by_token_passes_submission_user_id(sample_form_record, mock_session_factory, mocker):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = sample_form_record
|
||||
repo.mark_submitted.return_value = sample_form_record
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
enqueue_spy = mocker.patch.object(service, "enqueue_resume")
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={"field": "value"},
|
||||
submission_user_id="account-id",
|
||||
)
|
||||
|
||||
call_kwargs = repo.mark_submitted.call_args.kwargs
|
||||
assert call_kwargs["submission_user_id"] == "account-id"
|
||||
assert call_kwargs["submission_end_user_id"] is None
|
||||
enqueue_spy.assert_called_once_with(sample_form_record.workflow_run_id)
|
||||
|
||||
|
||||
def test_submit_form_by_token_invalid_action(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
repo.get_by_token.return_value = dataclasses.replace(sample_form_record)
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="invalid",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Invalid action" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
|
||||
|
||||
def test_submit_form_by_token_missing_inputs(sample_form_record, mock_session_factory):
|
||||
session_factory, _ = mock_session_factory
|
||||
repo = MagicMock(spec=HumanInputFormSubmissionRepository)
|
||||
|
||||
definition_with_input = FormDefinition(
|
||||
form_content="hello",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="content")],
|
||||
user_actions=sample_form_record.definition.user_actions,
|
||||
rendered_content="<p>hello</p>",
|
||||
expiration_time=sample_form_record.expiration_time,
|
||||
)
|
||||
form_with_input = dataclasses.replace(sample_form_record, definition=definition_with_input)
|
||||
repo.get_by_token.return_value = form_with_input
|
||||
service = HumanInputService(session_factory, form_repository=repo)
|
||||
|
||||
with pytest.raises(InvalidFormDataError) as exc_info:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
form_token="token",
|
||||
selected_action_id="submit",
|
||||
form_data={},
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
repo.mark_submitted.assert_not_called()
|
||||
@@ -0,0 +1,61 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
|
||||
from core.entities.execution_extra_content import HumanInputContent, HumanInputFormSubmissionData
|
||||
from services import message_service
|
||||
|
||||
|
||||
class _FakeMessage:
|
||||
def __init__(self, message_id: str):
|
||||
self.id = message_id
|
||||
self.extra_contents = None
|
||||
|
||||
def set_extra_contents(self, contents):
|
||||
self.extra_contents = contents
|
||||
|
||||
|
||||
def test_attach_message_extra_contents_assigns_serialized_payload(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
messages = [_FakeMessage("msg-1"), _FakeMessage("msg-2")]
|
||||
repo = type(
|
||||
"Repo",
|
||||
(),
|
||||
{
|
||||
"get_by_message_ids": lambda _self, message_ids: [
|
||||
[
|
||||
HumanInputContent(
|
||||
workflow_run_id="workflow-run-1",
|
||||
submitted=True,
|
||||
form_submission_data=HumanInputFormSubmissionData(
|
||||
node_id="node-1",
|
||||
node_title="Approval",
|
||||
rendered_content="Rendered",
|
||||
action_id="approve",
|
||||
action_text="Approve",
|
||||
),
|
||||
)
|
||||
],
|
||||
[],
|
||||
]
|
||||
},
|
||||
)()
|
||||
|
||||
monkeypatch.setattr(message_service, "_create_execution_extra_content_repository", lambda: repo)
|
||||
|
||||
message_service.attach_message_extra_contents(messages)
|
||||
|
||||
assert messages[0].extra_contents == [
|
||||
{
|
||||
"type": "human_input",
|
||||
"workflow_run_id": "workflow-run-1",
|
||||
"submitted": True,
|
||||
"form_submission_data": {
|
||||
"node_id": "node-1",
|
||||
"node_title": "Approval",
|
||||
"rendered_content": "Rendered",
|
||||
"action_id": "approve",
|
||||
"action_text": "Approve",
|
||||
},
|
||||
}
|
||||
]
|
||||
assert messages[1].extra_contents == []
|
||||
@@ -35,7 +35,6 @@ class TestDataFactory:
|
||||
app_id: str = "app-789",
|
||||
workflow_id: str = "workflow-101",
|
||||
status: str | WorkflowExecutionStatus = "paused",
|
||||
pause_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> MagicMock:
|
||||
"""Create a mock WorkflowRun object."""
|
||||
@@ -45,7 +44,6 @@ class TestDataFactory:
|
||||
mock_run.app_id = app_id
|
||||
mock_run.workflow_id = workflow_id
|
||||
mock_run.status = status
|
||||
mock_run.pause_id = pause_id
|
||||
|
||||
for key, value in kwargs.items():
|
||||
setattr(mock_run, key, value)
|
||||
|
||||
@@ -0,0 +1,162 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.tools.entities.tool_entities import ToolParameter, WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from models.model import App
|
||||
from models.tools import WorkflowToolProvider
|
||||
from services.tools import workflow_tools_manage_service
|
||||
|
||||
|
||||
class DummyWorkflow:
|
||||
def __init__(self, graph_dict: dict, version: str = "1.0.0") -> None:
|
||||
self._graph_dict = graph_dict
|
||||
self.version = version
|
||||
|
||||
@property
|
||||
def graph_dict(self) -> dict:
|
||||
return self._graph_dict
|
||||
|
||||
|
||||
class FakeQuery:
|
||||
def __init__(self, result):
|
||||
self._result = result
|
||||
|
||||
def where(self, *args, **kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return self._result
|
||||
|
||||
|
||||
class DummySession:
|
||||
def __init__(self) -> None:
|
||||
self.added: list[object] = []
|
||||
|
||||
def __enter__(self) -> "DummySession":
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
def add(self, obj) -> None:
|
||||
self.added.append(obj)
|
||||
|
||||
def begin(self):
|
||||
return DummyBegin(self)
|
||||
|
||||
|
||||
class DummyBegin:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> DummySession:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class DummySessionContext:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def __enter__(self) -> DummySession:
|
||||
return self._session
|
||||
|
||||
def __exit__(self, exc_type, exc, tb) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
class DummySessionFactory:
|
||||
def __init__(self, session: DummySession) -> None:
|
||||
self._session = session
|
||||
|
||||
def create_session(self) -> DummySessionContext:
|
||||
return DummySessionContext(self._session)
|
||||
|
||||
|
||||
def _build_fake_session(app) -> SimpleNamespace:
|
||||
def query(model):
|
||||
if model is WorkflowToolProvider:
|
||||
return FakeQuery(None)
|
||||
if model is App:
|
||||
return FakeQuery(app)
|
||||
return FakeQuery(None)
|
||||
|
||||
return SimpleNamespace(query=query)
|
||||
|
||||
|
||||
def _build_parameters() -> list[WorkflowToolParameterConfiguration]:
|
||||
return [
|
||||
WorkflowToolParameterConfiguration(name="input", description="input", form=ToolParameter.ToolParameterForm.LLM),
|
||||
]
|
||||
|
||||
|
||||
def test_create_workflow_tool_rejects_human_input_nodes(monkeypatch):
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "human-input"}}]})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
fake_session = _build_fake_session(app)
|
||||
monkeypatch.setattr(workflow_tools_manage_service.db, "session", fake_session)
|
||||
|
||||
mock_from_db = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
|
||||
mock_invalidate = MagicMock()
|
||||
|
||||
with pytest.raises(WorkflowToolHumanInputNotSupportedError) as exc_info:
|
||||
workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon={"type": "emoji", "emoji": "tool"},
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
assert exc_info.value.error_code == "workflow_tool_human_input_not_supported"
|
||||
mock_from_db.assert_not_called()
|
||||
mock_invalidate.assert_not_called()
|
||||
|
||||
|
||||
def test_create_workflow_tool_success(monkeypatch):
|
||||
workflow = DummyWorkflow(graph_dict={"nodes": [{"id": "node_1", "data": {"type": "start"}}]})
|
||||
app = SimpleNamespace(workflow=workflow)
|
||||
|
||||
fake_db = MagicMock()
|
||||
fake_session = _build_fake_session(app)
|
||||
fake_db.session = fake_session
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "db", fake_db)
|
||||
|
||||
dummy_session = DummySession()
|
||||
monkeypatch.setattr(workflow_tools_manage_service, "Session", lambda *_, **__: dummy_session)
|
||||
|
||||
mock_from_db = MagicMock()
|
||||
monkeypatch.setattr(workflow_tools_manage_service.WorkflowToolProviderController, "from_db", mock_from_db)
|
||||
|
||||
icon = {"type": "emoji", "emoji": "tool"}
|
||||
|
||||
result = workflow_tools_manage_service.WorkflowToolManageService.create_workflow_tool(
|
||||
user_id="user-id",
|
||||
tenant_id="tenant-id",
|
||||
workflow_app_id="app-id",
|
||||
name="tool_name",
|
||||
label="Tool",
|
||||
icon=icon,
|
||||
description="desc",
|
||||
parameters=_build_parameters(),
|
||||
)
|
||||
|
||||
assert result == {"result": "success"}
|
||||
assert len(dummy_session.added) == 1
|
||||
created_provider = dummy_session.added[0]
|
||||
assert created_provider.name == "tool_name"
|
||||
assert created_provider.label == "Tool"
|
||||
assert created_provider.icon == json.dumps(icon)
|
||||
assert created_provider.version == workflow.version
|
||||
mock_from_db.assert_called_once()
|
||||
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import queue
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from threading import Event
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.app_config.entities import WorkflowUIBasedAppConfig
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.layers.pause_state_persist_layer import WorkflowResumptionContext, _WorkflowGenerateEntityWrapper
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionStatus
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.api_workflow_node_execution_repository import WorkflowNodeExecutionSnapshot
|
||||
from repositories.entities.workflow_pause import WorkflowPauseEntity
|
||||
from services.workflow_event_snapshot_service import (
|
||||
BufferState,
|
||||
MessageContext,
|
||||
_build_snapshot_events,
|
||||
_resolve_task_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class _FakePauseEntity(WorkflowPauseEntity):
|
||||
pause_id: str
|
||||
workflow_run_id: str
|
||||
paused_at_value: datetime
|
||||
pause_reasons: Sequence[HumanInputRequired]
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self.pause_id
|
||||
|
||||
@property
|
||||
def workflow_execution_id(self) -> str:
|
||||
return self.workflow_run_id
|
||||
|
||||
def get_state(self) -> bytes:
|
||||
raise AssertionError("state is not required for snapshot tests")
|
||||
|
||||
@property
|
||||
def resumed_at(self) -> datetime | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def paused_at(self) -> datetime:
|
||||
return self.paused_at_value
|
||||
|
||||
def get_pause_reasons(self) -> Sequence[HumanInputRequired]:
|
||||
return self.pause_reasons
|
||||
|
||||
|
||||
def _build_workflow_run(status: WorkflowExecutionStatus) -> WorkflowRun:
|
||||
return WorkflowRun(
|
||||
id="run-1",
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
workflow_id="workflow-1",
|
||||
type="workflow",
|
||||
triggered_from="app-run",
|
||||
version="v1",
|
||||
graph=None,
|
||||
inputs=json.dumps({"input": "value"}),
|
||||
status=status,
|
||||
outputs=json.dumps({}),
|
||||
error=None,
|
||||
elapsed_time=0.0,
|
||||
total_tokens=0,
|
||||
total_steps=0,
|
||||
created_by_role=CreatorUserRole.END_USER,
|
||||
created_by="user-1",
|
||||
created_at=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
)
|
||||
|
||||
|
||||
def _build_snapshot(status: WorkflowNodeExecutionStatus) -> WorkflowNodeExecutionSnapshot:
|
||||
created_at = datetime(2024, 1, 1, tzinfo=UTC)
|
||||
finished_at = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
|
||||
return WorkflowNodeExecutionSnapshot(
|
||||
execution_id="exec-1",
|
||||
node_id="node-1",
|
||||
node_type="human-input",
|
||||
title="Human Input",
|
||||
index=1,
|
||||
status=status.value,
|
||||
elapsed_time=0.5,
|
||||
created_at=created_at,
|
||||
finished_at=finished_at,
|
||||
iteration_id=None,
|
||||
loop_id=None,
|
||||
)
|
||||
|
||||
|
||||
def _build_resumption_context(task_id: str) -> WorkflowResumptionContext:
|
||||
app_config = WorkflowUIBasedAppConfig(
|
||||
tenant_id="tenant-1",
|
||||
app_id="app-1",
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
workflow_id="workflow-1",
|
||||
)
|
||||
generate_entity = WorkflowAppGenerateEntity(
|
||||
task_id=task_id,
|
||||
app_config=app_config,
|
||||
inputs={},
|
||||
files=[],
|
||||
user_id="user-1",
|
||||
stream=True,
|
||||
invoke_from=InvokeFrom.EXPLORE,
|
||||
call_depth=0,
|
||||
workflow_execution_id="run-1",
|
||||
)
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=0.0)
|
||||
runtime_state.register_paused_node("node-1")
|
||||
runtime_state.outputs = {"result": "value"}
|
||||
wrapper = _WorkflowGenerateEntityWrapper(entity=generate_entity)
|
||||
return WorkflowResumptionContext(
|
||||
generate_entity=wrapper,
|
||||
serialized_graph_runtime_state=runtime_state.dumps(),
|
||||
)
|
||||
|
||||
|
||||
def test_build_snapshot_events_includes_pause_event() -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.PAUSED)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.PAUSED)
|
||||
resumption_context = _build_resumption_context("task-ctx")
|
||||
pause_entity = _FakePauseEntity(
|
||||
pause_id="pause-1",
|
||||
workflow_run_id="run-1",
|
||||
paused_at_value=datetime(2024, 1, 1, tzinfo=UTC),
|
||||
pause_reasons=[
|
||||
HumanInputRequired(
|
||||
form_id="form-1",
|
||||
form_content="content",
|
||||
node_id="node-1",
|
||||
node_title="Human Input",
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-ctx",
|
||||
message_context=None,
|
||||
pause_entity=pause_entity,
|
||||
resumption_context=resumption_context,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"workflow_paused",
|
||||
]
|
||||
assert events[2]["data"]["status"] == WorkflowNodeExecutionStatus.PAUSED.value
|
||||
pause_data = events[-1]["data"]
|
||||
assert pause_data["paused_nodes"] == ["node-1"]
|
||||
assert pause_data["outputs"] == {"result": "value"}
|
||||
assert pause_data["status"] == WorkflowExecutionStatus.PAUSED.value
|
||||
assert pause_data["created_at"] == int(workflow_run.created_at.timestamp())
|
||||
assert pause_data["elapsed_time"] == workflow_run.elapsed_time
|
||||
assert pause_data["total_tokens"] == workflow_run.total_tokens
|
||||
assert pause_data["total_steps"] == workflow_run.total_steps
|
||||
|
||||
|
||||
def test_build_snapshot_events_applies_message_context() -> None:
|
||||
workflow_run = _build_workflow_run(WorkflowExecutionStatus.RUNNING)
|
||||
snapshot = _build_snapshot(WorkflowNodeExecutionStatus.SUCCEEDED)
|
||||
message_context = MessageContext(
|
||||
conversation_id="conv-1",
|
||||
message_id="msg-1",
|
||||
created_at=1700000000,
|
||||
answer="snapshot message",
|
||||
)
|
||||
|
||||
events = _build_snapshot_events(
|
||||
workflow_run=workflow_run,
|
||||
node_snapshots=[snapshot],
|
||||
task_id="task-1",
|
||||
message_context=message_context,
|
||||
pause_entity=None,
|
||||
resumption_context=None,
|
||||
)
|
||||
|
||||
assert [event["event"] for event in events] == [
|
||||
"workflow_started",
|
||||
"message_replace",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
]
|
||||
assert events[1]["answer"] == "snapshot message"
|
||||
for event in events:
|
||||
assert event["conversation_id"] == "conv-1"
|
||||
assert event["message_id"] == "msg-1"
|
||||
assert event["created_at"] == 1700000000
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("context_task_id", "buffered_task_id", "expected"),
|
||||
[
|
||||
("task-ctx", "task-buffer", "task-ctx"),
|
||||
(None, "task-buffer", "task-buffer"),
|
||||
(None, None, "run-1"),
|
||||
],
|
||||
)
|
||||
def test_resolve_task_id_priority(context_task_id, buffered_task_id, expected) -> None:
|
||||
resumption_context = _build_resumption_context(context_task_id) if context_task_id else None
|
||||
buffer_state = BufferState(
|
||||
queue=queue.Queue(),
|
||||
stop_event=Event(),
|
||||
done_event=Event(),
|
||||
task_id_ready=Event(),
|
||||
task_id_hint=buffered_task_id,
|
||||
)
|
||||
if buffered_task_id:
|
||||
buffer_state.task_id_ready.set()
|
||||
task_id = _resolve_task_id(resumption_context, buffer_state, "run-1", wait_timeout=0.0)
|
||||
assert task_id == expected
|
||||
@@ -0,0 +1,184 @@
|
||||
import uuid
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
EmailDeliveryConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
)
|
||||
from services import workflow_service as workflow_service_module
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
def _make_service() -> WorkflowService:
|
||||
return WorkflowService(session_maker=sessionmaker())
|
||||
|
||||
|
||||
def _build_node_config(delivery_methods):
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
delivery_methods=delivery_methods,
|
||||
form_content="Test content",
|
||||
inputs=[],
|
||||
user_actions=[],
|
||||
).model_dump(mode="json")
|
||||
node_data["type"] = NodeType.HUMAN_INPUT.value
|
||||
return {"id": "node-1", "data": node_data}
|
||||
|
||||
|
||||
def _make_email_method(enabled: bool = True, debug_mode: bool = False) -> EmailDeliveryMethod:
|
||||
return EmailDeliveryMethod(
|
||||
id=uuid.uuid4(),
|
||||
enabled=enabled,
|
||||
config=EmailDeliveryConfig(
|
||||
recipients=EmailRecipients(
|
||||
whole_workspace=False,
|
||||
items=[ExternalRecipient(email="tester@example.com")],
|
||||
),
|
||||
subject="Test subject",
|
||||
body="Test body",
|
||||
debug_mode=debug_mode,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_requires_draft_workflow():
|
||||
service = _make_service()
|
||||
service.get_draft_workflow = MagicMock(return_value=None) # type: ignore[method-assign]
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not initialized"):
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id="delivery-1",
|
||||
)
|
||||
|
||||
|
||||
def test_human_input_delivery_allows_disabled_method(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=False)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
)
|
||||
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
|
||||
|
||||
def test_human_input_delivery_dispatches_to_test_service(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=True)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
inputs={"#node-1.output#": "value"},
|
||||
)
|
||||
|
||||
pool_args = service._build_human_input_variable_pool.call_args.kwargs
|
||||
assert pool_args["manual_inputs"] == {"#node-1.output#": "value"}
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
|
||||
|
||||
def test_human_input_delivery_debug_mode_overrides_recipients(monkeypatch: pytest.MonkeyPatch):
|
||||
service = _make_service()
|
||||
delivery_method = _make_email_method(enabled=True, debug_mode=True)
|
||||
node_config = _build_node_config([delivery_method])
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = node_config
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[attr-defined]
|
||||
node_stub = MagicMock()
|
||||
node_stub._render_form_content_before_submission.return_value = "rendered"
|
||||
node_stub._resolve_default_values.return_value = {}
|
||||
service._build_human_input_node = MagicMock(return_value=node_stub) # type: ignore[attr-defined]
|
||||
service._create_human_input_delivery_test_form = MagicMock( # type: ignore[attr-defined]
|
||||
return_value=("form-1", {})
|
||||
)
|
||||
|
||||
test_service_instance = MagicMock()
|
||||
monkeypatch.setattr(
|
||||
workflow_service_module,
|
||||
"HumanInputDeliveryTestService",
|
||||
MagicMock(return_value=test_service_instance),
|
||||
)
|
||||
|
||||
app_model = SimpleNamespace(tenant_id="tenant-1", id="app-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
delivery_method_id=str(delivery_method.id),
|
||||
)
|
||||
|
||||
test_service_instance.send_test.assert_called_once()
|
||||
sent_method = test_service_instance.send_test.call_args.kwargs["method"]
|
||||
assert isinstance(sent_method, EmailDeliveryMethod)
|
||||
assert sent_method.config.debug_mode is True
|
||||
assert sent_method.config.recipients.whole_workspace is False
|
||||
assert len(sent_method.config.recipients.items) == 1
|
||||
recipient = sent_method.config.recipients.items[0]
|
||||
assert isinstance(recipient, MemberRecipient)
|
||||
assert recipient.user_id == account.id
|
||||
@@ -5,6 +5,7 @@ from uuid import uuid4
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
from repositories.sqlalchemy_api_workflow_node_execution_repository import (
|
||||
DifyAPISQLAlchemyWorkflowNodeExecutionRepository,
|
||||
@@ -52,6 +53,9 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
call_args = mock_session.scalar.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
compiled = call_args.compile()
|
||||
assert WorkflowNodeExecutionStatus.PAUSED in compiled.params.values()
|
||||
|
||||
def test_get_node_last_execution_not_found(self, repository):
|
||||
"""Test getting the last execution for a node when it doesn't exist."""
|
||||
# Arrange
|
||||
@@ -71,28 +75,6 @@ class TestSQLAlchemyWorkflowNodeExecutionServiceRepository:
|
||||
assert result is None
|
||||
mock_session.scalar.assert_called_once()
|
||||
|
||||
def test_get_executions_by_workflow_run(self, repository, mock_execution):
|
||||
"""Test getting all executions for a workflow run."""
|
||||
# Arrange
|
||||
mock_session = MagicMock(spec=Session)
|
||||
repository._session_maker.return_value.__enter__.return_value = mock_session
|
||||
executions = [mock_execution]
|
||||
mock_session.execute.return_value.scalars.return_value.all.return_value = executions
|
||||
|
||||
# Act
|
||||
result = repository.get_executions_by_workflow_run(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-456",
|
||||
workflow_run_id="run-101",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == executions
|
||||
mock_session.execute.assert_called_once()
|
||||
# Verify the query was constructed correctly
|
||||
call_args = mock_session.execute.call_args[0][0]
|
||||
assert hasattr(call_args, "compile") # It's a SQLAlchemy statement
|
||||
|
||||
def test_get_executions_by_workflow_run_empty(self, repository):
|
||||
"""Test getting executions for a workflow run when none exist."""
|
||||
# Arrange
|
||||
|
||||
@@ -1,9 +1,15 @@
|
||||
from contextlib import nullcontext
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.human_input.entities import FormInput, HumanInputNodeData, UserAction
|
||||
from core.workflow.nodes.human_input.enums import FormInputType
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services import workflow_service as workflow_service_module
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
|
||||
@@ -161,3 +167,120 @@ class TestWorkflowService:
|
||||
assert workflows == []
|
||||
assert has_more is False
|
||||
mock_session.scalars.assert_called_once()
|
||||
|
||||
def test_submit_human_input_form_preview_uses_rendered_content(
|
||||
self, workflow_service: WorkflowService, monkeypatch: pytest.MonkeyPatch
|
||||
) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node.render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node.render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
workflow.get_enclosing_node_type_and_id.return_value = None
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
saved_outputs: dict[str, object] = {}
|
||||
|
||||
class DummySession:
|
||||
def __init__(self, *args, **kwargs):
|
||||
self.commit = MagicMock()
|
||||
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc, tb):
|
||||
return False
|
||||
|
||||
def begin(self):
|
||||
return nullcontext()
|
||||
|
||||
class DummySaver:
|
||||
def __init__(self, *args, **kwargs):
|
||||
pass
|
||||
|
||||
def save(self, outputs, process_data):
|
||||
saved_outputs.update(outputs)
|
||||
|
||||
monkeypatch.setattr(workflow_service_module, "Session", DummySession)
|
||||
monkeypatch.setattr(workflow_service_module, "DraftVariableSaver", DummySaver)
|
||||
monkeypatch.setattr(workflow_service_module, "db", SimpleNamespace(engine=MagicMock()))
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
result = service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={"name": "Ada", "extra": "ignored"},
|
||||
inputs={"#node-0.result#": "LLM output"},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
service._build_human_input_variable_pool.assert_called_once_with(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
node_config={"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}},
|
||||
manual_inputs={"#node-0.result#": "LLM output"},
|
||||
)
|
||||
|
||||
node.render_form_content_with_outputs.assert_called_once()
|
||||
called_args = node.render_form_content_with_outputs.call_args.args
|
||||
assert called_args[0] == "<p>preview</p>"
|
||||
assert called_args[2] == node_data.outputs_field_names()
|
||||
rendered_outputs = called_args[1]
|
||||
assert rendered_outputs["name"] == "Ada"
|
||||
assert rendered_outputs["extra"] == "ignored"
|
||||
assert "extra" in saved_outputs
|
||||
assert "extra" in result
|
||||
assert saved_outputs["name"] == "Ada"
|
||||
assert result["name"] == "Ada"
|
||||
assert result["__action_id"] == "approve"
|
||||
assert "__rendered_content" in result
|
||||
|
||||
def test_submit_human_input_form_preview_missing_inputs_message(self, workflow_service: WorkflowService) -> None:
|
||||
service = workflow_service
|
||||
node_data = HumanInputNodeData(
|
||||
title="Human Input",
|
||||
form_content="<p>{{#$output.name#}}</p>",
|
||||
inputs=[FormInput(type=FormInputType.TEXT_INPUT, output_variable_name="name")],
|
||||
user_actions=[UserAction(id="approve", title="Approve")],
|
||||
)
|
||||
node = MagicMock()
|
||||
node.node_data = node_data
|
||||
node._render_form_content_before_submission.return_value = "<p>preview</p>"
|
||||
node._render_form_content_with_outputs.return_value = "<p>rendered</p>"
|
||||
|
||||
service._build_human_input_variable_pool = MagicMock(return_value=MagicMock()) # type: ignore[method-assign]
|
||||
service._build_human_input_node = MagicMock(return_value=node) # type: ignore[method-assign]
|
||||
|
||||
workflow = MagicMock()
|
||||
workflow.get_node_config_by_id.return_value = {"id": "node-1", "data": {"type": NodeType.HUMAN_INPUT.value}}
|
||||
service.get_draft_workflow = MagicMock(return_value=workflow) # type: ignore[method-assign]
|
||||
|
||||
app_model = SimpleNamespace(id="app-1", tenant_id="tenant-1")
|
||||
account = SimpleNamespace(id="account-1")
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=account,
|
||||
node_id="node-1",
|
||||
form_inputs={},
|
||||
inputs={},
|
||||
action="approve",
|
||||
)
|
||||
|
||||
assert "Missing required inputs" in str(exc_info.value)
|
||||
|
||||
Reference in New Issue
Block a user