mirror of
https://github.com/langgenius/dify.git
synced 2026-03-01 12:55:13 +00:00
Compare commits
8 Commits
copilot/fi
...
verify-ema
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9c5c935ed5 | ||
|
|
559f8263b7 | ||
|
|
59c5638342 | ||
|
|
897ffb6b35 | ||
|
|
d367a6b1e1 | ||
|
|
daa9d38788 | ||
|
|
ffe77fecdf | ||
|
|
b462a96fa0 |
@@ -133,7 +133,6 @@ class EducationAutocompleteQuery(BaseModel):
|
||||
class ChangeEmailSendPayload(BaseModel):
|
||||
email: EmailStr
|
||||
language: str | None = None
|
||||
phase: str | None = None
|
||||
token: str | None = None
|
||||
|
||||
|
||||
@@ -547,13 +546,17 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
account = None
|
||||
user_email = None
|
||||
email_for_sending = args.email.lower()
|
||||
if args.phase is not None and args.phase == "new_email":
|
||||
if args.token is None:
|
||||
raise InvalidTokenError()
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
|
||||
if args.token is not None:
|
||||
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
|
||||
|
||||
reset_data = AccountService.get_change_email_data(args.token)
|
||||
if reset_data is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
user_email = reset_data.get("email", "")
|
||||
|
||||
if user_email.lower() != current_user.email.lower():
|
||||
@@ -573,7 +576,7 @@ class ChangeEmailSendEmailApi(Resource):
|
||||
email=email_for_sending,
|
||||
old_email=user_email,
|
||||
language=language,
|
||||
phase=args.phase,
|
||||
phase=send_phase,
|
||||
)
|
||||
return {"result": "success", "data": token}
|
||||
|
||||
@@ -608,12 +611,26 @@ class ChangeEmailCheckApi(Resource):
|
||||
AccountService.add_change_email_error_rate_limit(user_email)
|
||||
raise EmailCodeError()
|
||||
|
||||
phase_transitions: dict[str, str] = {
|
||||
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if not isinstance(token_phase, str):
|
||||
raise InvalidTokenError()
|
||||
refreshed_phase = phase_transitions.get(token_phase)
|
||||
if refreshed_phase is None:
|
||||
raise InvalidTokenError()
|
||||
|
||||
# Verified, revoke the first token
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
# Refresh token data by generating a new token
|
||||
_, new_token = AccountService.generate_change_email_token(
|
||||
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
|
||||
user_email,
|
||||
code=args.code,
|
||||
old_email=token_data.get("old_email"),
|
||||
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
|
||||
)
|
||||
|
||||
AccountService.reset_change_email_error_rate_limit(user_email)
|
||||
@@ -643,13 +660,22 @@ class ChangeEmailResetApi(Resource):
|
||||
if not reset_data:
|
||||
raise InvalidTokenError()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
|
||||
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
|
||||
raise InvalidTokenError()
|
||||
|
||||
token_email = reset_data.get("email")
|
||||
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
|
||||
if normalized_token_email != normalized_new_email:
|
||||
raise InvalidTokenError()
|
||||
|
||||
old_email = reset_data.get("old_email", "")
|
||||
current_user, _ = current_account_with_tenant()
|
||||
if current_user.email.lower() != old_email.lower():
|
||||
raise AccountNotFound()
|
||||
|
||||
AccountService.revoke_change_email_token(args.token)
|
||||
|
||||
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
|
||||
|
||||
AccountService.send_change_email_completed_notify_email(
|
||||
|
||||
@@ -9,7 +9,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import queue
|
||||
import threading
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
@@ -77,13 +76,10 @@ class GraphEngine:
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
self._stop_event = threading.Event()
|
||||
|
||||
# Bind runtime state to current workflow context
|
||||
self._graph = graph
|
||||
self._graph_runtime_state = graph_runtime_state
|
||||
self._graph_runtime_state.stop_event = self._stop_event
|
||||
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
|
||||
self._command_channel = command_channel
|
||||
self._config = config
|
||||
@@ -163,7 +159,6 @@ class GraphEngine:
|
||||
layers=self._layers,
|
||||
execution_context=execution_context,
|
||||
config=self._config,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Orchestration ===
|
||||
@@ -194,7 +189,6 @@ class GraphEngine:
|
||||
event_handler=self._event_handler_registry,
|
||||
execution_coordinator=self._execution_coordinator,
|
||||
event_emitter=self._event_manager,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
# === Validation ===
|
||||
@@ -314,7 +308,6 @@ class GraphEngine:
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
@@ -348,7 +341,6 @@ class GraphEngine:
|
||||
|
||||
def _stop_execution(self) -> None:
|
||||
"""Stop execution subsystems."""
|
||||
self._stop_event.set()
|
||||
self._dispatcher.stop()
|
||||
self._worker_pool.stop()
|
||||
# Don't mark complete here as the dispatcher already does it
|
||||
|
||||
@@ -44,7 +44,6 @@ class Dispatcher:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
event_handler: "EventHandler",
|
||||
execution_coordinator: ExecutionCoordinator,
|
||||
stop_event: threading.Event,
|
||||
event_emitter: EventManager | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
@@ -62,7 +61,7 @@ class Dispatcher:
|
||||
self._event_emitter = event_emitter
|
||||
|
||||
self._thread: threading.Thread | None = None
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._start_time: float | None = None
|
||||
|
||||
def start(self) -> None:
|
||||
@@ -70,12 +69,14 @@ class Dispatcher:
|
||||
if self._thread and self._thread.is_alive():
|
||||
return
|
||||
|
||||
self._stop_event.clear()
|
||||
self._start_time = time.time()
|
||||
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Stop the dispatcher thread."""
|
||||
self._stop_event.set()
|
||||
if self._thread and self._thread.is_alive():
|
||||
self._thread.join(timeout=2.0)
|
||||
|
||||
|
||||
@@ -42,7 +42,6 @@ class Worker(threading.Thread):
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: Sequence[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
worker_id: int = 0,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@@ -63,16 +62,13 @@ class Worker(threading.Thread):
|
||||
self._graph = graph
|
||||
self._worker_id = worker_id
|
||||
self._execution_context = execution_context
|
||||
self._stop_event = stop_event
|
||||
self._stop_event = threading.Event()
|
||||
self._layers = layers if layers is not None else []
|
||||
self._last_task_time = time.time()
|
||||
|
||||
def stop(self) -> None:
|
||||
"""Worker is controlled via shared stop_event from GraphEngine.
|
||||
|
||||
This method is a no-op retained for backward compatibility.
|
||||
"""
|
||||
pass
|
||||
"""Signal the worker to stop processing."""
|
||||
self._stop_event.set()
|
||||
|
||||
@property
|
||||
def is_idle(self) -> bool:
|
||||
|
||||
@@ -37,7 +37,6 @@ class WorkerPool:
|
||||
event_queue: queue.Queue[GraphNodeEventBase],
|
||||
graph: Graph,
|
||||
layers: list[GraphEngineLayer],
|
||||
stop_event: threading.Event,
|
||||
config: GraphEngineConfig,
|
||||
execution_context: IExecutionContext | None = None,
|
||||
) -> None:
|
||||
@@ -64,7 +63,6 @@ class WorkerPool:
|
||||
self._worker_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._running = False
|
||||
self._stop_event = stop_event
|
||||
|
||||
# No longer tracking worker states with callbacks to avoid lock contention
|
||||
|
||||
@@ -135,7 +133,6 @@ class WorkerPool:
|
||||
layers=self._layers,
|
||||
worker_id=worker_id,
|
||||
execution_context=self._execution_context,
|
||||
stop_event=self._stop_event,
|
||||
)
|
||||
|
||||
worker.start()
|
||||
|
||||
@@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def _should_stop(self) -> bool:
|
||||
"""Check if execution should be stopped."""
|
||||
return self.graph_runtime_state.stop_event.is_set()
|
||||
|
||||
def run(self) -> Generator[GraphNodeEventBase, None, None]:
|
||||
execution_id = self.ensure_execution_id()
|
||||
self._start_at = naive_utc_now()
|
||||
@@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
|
||||
yield event
|
||||
else:
|
||||
yield event
|
||||
|
||||
if self._should_stop():
|
||||
error_message = "Execution cancelled"
|
||||
yield NodeRunFailedEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
start_at=self._start_at,
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=error_message,
|
||||
),
|
||||
error=error_message,
|
||||
)
|
||||
return
|
||||
except Exception as e:
|
||||
logger.exception("Node %s failed to run", self._node_id)
|
||||
result = NodeRunResult(
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
@@ -219,8 +218,6 @@ class GraphRuntimeState:
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
self.attach_graph(graph)
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ import logging
|
||||
import secrets
|
||||
import uuid
|
||||
from datetime import UTC, datetime, timedelta
|
||||
from enum import StrEnum
|
||||
from hashlib import sha256
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -80,12 +81,25 @@ class TokenPair(BaseModel):
|
||||
csrf_token: str
|
||||
|
||||
|
||||
class ChangeEmailPhase(StrEnum):
|
||||
OLD = "old_email"
|
||||
OLD_VERIFIED = "old_email_verified"
|
||||
NEW = "new_email"
|
||||
NEW_VERIFIED = "new_email_verified"
|
||||
|
||||
|
||||
REFRESH_TOKEN_PREFIX = "refresh_token:"
|
||||
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
|
||||
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
|
||||
|
||||
|
||||
class AccountService:
|
||||
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
|
||||
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
|
||||
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
|
||||
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
|
||||
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
|
||||
|
||||
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
|
||||
email_code_login_rate_limiter = RateLimiter(
|
||||
@@ -541,13 +555,20 @@ class AccountService:
|
||||
raise ValueError("Email must be provided.")
|
||||
if not phase:
|
||||
raise ValueError("phase must be provided.")
|
||||
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
|
||||
raise ValueError("phase must be one of old_email or new_email.")
|
||||
|
||||
if cls.change_email_rate_limiter.is_rate_limited(account_email):
|
||||
from controllers.console.auth.error import EmailChangeRateLimitExceededError
|
||||
|
||||
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
|
||||
|
||||
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
|
||||
code, token = cls.generate_change_email_token(
|
||||
account_email,
|
||||
account,
|
||||
old_email=old_email,
|
||||
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
|
||||
)
|
||||
|
||||
send_change_mail_task.delay(
|
||||
language=language,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
|
||||
@@ -125,7 +126,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
|
||||
|
||||
data_source_info = document.data_source_info_dict
|
||||
data_source_info["last_edited_time"] = last_edited_time
|
||||
document.data_source_info = data_source_info
|
||||
document.data_source_info = json.dumps(data_source_info)
|
||||
|
||||
document.indexing_status = "parsing"
|
||||
document.processing_started_at = naive_utc_now()
|
||||
|
||||
@@ -12,8 +12,6 @@ from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
from psycopg2.extensions import register_adapter
|
||||
from psycopg2.extras import Json
|
||||
|
||||
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
@@ -21,12 +19,6 @@ from models.dataset import Dataset, Document, DocumentSegment
|
||||
from tasks.document_indexing_sync_task import document_indexing_sync_task
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _register_dict_adapter_for_psycopg2():
|
||||
"""Align test DB adapter behavior with dict payloads used in task update flow."""
|
||||
register_adapter(dict, Json)
|
||||
|
||||
|
||||
class DocumentIndexingSyncTaskTestDataFactory:
|
||||
"""Create real DB entities for document indexing sync integration tests."""
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
from flask import Flask, g
|
||||
|
||||
from controllers.console.auth.error import InvalidTokenError
|
||||
from controllers.console.workspace.account import (
|
||||
AccountDeleteUpdateFeedbackApi,
|
||||
ChangeEmailCheckApi,
|
||||
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_normalize_new_email_phase(
|
||||
def test_should_infer_new_email_phase_from_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {"email": "current@example.com"}
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
|
||||
}
|
||||
mock_send_email.return_value = "token-abc"
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.db")
|
||||
@patch("controllers.console.workspace.account.Session")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_account_by_email,
|
||||
mock_session_cls,
|
||||
mock_account_db,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
|
||||
existing_account = _build_account("old@example.com", "acc-old")
|
||||
mock_get_account_by_email.return_value = existing_account
|
||||
mock_send_email.return_value = "token-legacy"
|
||||
mock_session = MagicMock()
|
||||
mock_session_cm = MagicMock()
|
||||
mock_session_cm.__enter__.return_value = mock_session
|
||||
mock_session_cm.__exit__.return_value = None
|
||||
mock_session_cls.return_value = mock_session_cm
|
||||
mock_account_db.engine = MagicMock()
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailSendEmailApi().post()
|
||||
|
||||
assert response == {"result": "success", "data": "token-legacy"}
|
||||
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
|
||||
mock_send_email.assert_called_once_with(
|
||||
account=existing_account,
|
||||
email="old@example.com",
|
||||
old_email="old@example.com",
|
||||
language="en-US",
|
||||
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
)
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
|
||||
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_unverified_old_email_token_for_new_email_phase(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_extract_ip,
|
||||
mock_is_ip_limit,
|
||||
mock_send_email,
|
||||
mock_get_change_data,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("current@example.com", "acc1")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_get_change_data.return_value = {
|
||||
"email": "current@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailSendEmailApi().post()
|
||||
|
||||
mock_send_email.assert_not_called()
|
||||
mock_extract_ip.assert_called_once()
|
||||
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailValidity:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
|
||||
mock_account = _build_account("user@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"email": "user@example.com",
|
||||
"code": "1234",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-token")
|
||||
|
||||
with app.test_request_context(
|
||||
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-123")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
|
||||
"user@example.com",
|
||||
code="1234",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("user@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_refresh_new_email_phase_to_verified(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_rate_limit,
|
||||
mock_get_data,
|
||||
mock_add_rate,
|
||||
mock_revoke_token,
|
||||
mock_generate_token,
|
||||
mock_reset_rate,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
mock_account = _build_account("old@example.com", "acc2")
|
||||
mock_current_account.return_value = (mock_account, None)
|
||||
mock_is_rate_limit.return_value = False
|
||||
mock_get_data.return_value = {
|
||||
"email": "new@example.com",
|
||||
"code": "5678",
|
||||
"old_email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
|
||||
}
|
||||
mock_generate_token.return_value = (None, "new-phase-token")
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/validity",
|
||||
method="POST",
|
||||
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
response = ChangeEmailCheckApi().post()
|
||||
|
||||
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
|
||||
mock_is_rate_limit.assert_called_once_with("new@example.com")
|
||||
mock_add_rate.assert_not_called()
|
||||
mock_revoke_token.assert_called_once_with("token-456")
|
||||
mock_generate_token.assert_called_once_with(
|
||||
"new@example.com",
|
||||
code="5678",
|
||||
old_email="old@example.com",
|
||||
additional_data={
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
|
||||
},
|
||||
)
|
||||
mock_reset_rate.assert_called_once_with("new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestChangeEmailReset:
|
||||
@patch("controllers.console.wraps.db")
|
||||
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {"old_email": "OLD@example.com"}
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "new@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
|
||||
mock_update_account.return_value = mock_account_after_update
|
||||
|
||||
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
|
||||
mock_send_notify.assert_called_once_with(email="new@example.com")
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_old_phase_token_for_reset(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "old@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-123"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
@patch("controllers.console.wraps.db")
|
||||
@patch("controllers.console.workspace.account.current_account_with_tenant")
|
||||
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.update_account_email")
|
||||
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
|
||||
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
|
||||
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
|
||||
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
|
||||
@patch("libs.login.check_csrf_token", return_value=None)
|
||||
@patch("controllers.console.wraps.FeatureService.get_system_features")
|
||||
def test_should_reject_mismatched_new_email_for_verified_token(
|
||||
self,
|
||||
mock_features,
|
||||
mock_csrf,
|
||||
mock_is_freeze,
|
||||
mock_check_unique,
|
||||
mock_get_data,
|
||||
mock_revoke_token,
|
||||
mock_update_account,
|
||||
mock_send_notify,
|
||||
mock_current_account,
|
||||
mock_db,
|
||||
app,
|
||||
):
|
||||
_mock_wraps_db(mock_db)
|
||||
mock_features.return_value = SimpleNamespace(enable_change_email=True)
|
||||
current_user = _build_account("old@example.com", "acc3")
|
||||
mock_current_account.return_value = (current_user, None)
|
||||
mock_is_freeze.return_value = False
|
||||
mock_check_unique.return_value = True
|
||||
mock_get_data.return_value = {
|
||||
"old_email": "OLD@example.com",
|
||||
"email": "another@example.com",
|
||||
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
|
||||
}
|
||||
|
||||
with app.test_request_context(
|
||||
"/account/change-email/reset",
|
||||
method="POST",
|
||||
json={"new_email": "new@example.com", "token": "token-789"},
|
||||
):
|
||||
_set_logged_in_user(_build_account("tester@example.com", "tester"))
|
||||
with pytest.raises(InvalidTokenError):
|
||||
ChangeEmailResetApi().post()
|
||||
|
||||
mock_revoke_token.assert_not_called()
|
||||
mock_update_account.assert_not_called()
|
||||
mock_send_notify.assert_not_called()
|
||||
mock_csrf.assert_called_once()
|
||||
|
||||
|
||||
class TestAccountDeletionFeedback:
|
||||
@patch("controllers.console.wraps.db")
|
||||
|
||||
@@ -3,7 +3,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest import mock
|
||||
|
||||
from core.workflow.entities.pause_reason import SchedulingPause
|
||||
@@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=execution_coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
dispatcher._dispatcher_loop()
|
||||
assert event_queue.empty()
|
||||
@@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
@@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
|
||||
event_queue=event_queue,
|
||||
event_handler=event_handler,
|
||||
execution_coordinator=coordinator,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@@ -1,5 +1,4 @@
|
||||
import queue
|
||||
import threading
|
||||
from datetime import datetime
|
||||
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
@@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
|
||||
event_handler=handler,
|
||||
execution_coordinator=coordinator,
|
||||
event_emitter=None,
|
||||
stop_event=threading.Event(),
|
||||
)
|
||||
|
||||
dispatcher._dispatcher_loop()
|
||||
|
||||
@@ -1,550 +0,0 @@
|
||||
"""
|
||||
Unit tests for stop_event functionality in GraphEngine.
|
||||
|
||||
Tests the unified stop_event management by GraphEngine and its propagation
|
||||
to WorkerPool, Worker, Dispatcher, and Nodes.
|
||||
"""
|
||||
|
||||
import threading
|
||||
import time
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
|
||||
from core.workflow.graph_engine.command_channels import InMemoryChannel
|
||||
from core.workflow.graph_events import (
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunStartedEvent,
|
||||
)
|
||||
from core.workflow.nodes.answer.answer_node import AnswerNode
|
||||
from core.workflow.nodes.start.start_node import StartNode
|
||||
from core.workflow.runtime import GraphRuntimeState, VariablePool
|
||||
from models.enums import UserFrom
|
||||
|
||||
|
||||
class TestStopEventPropagation:
|
||||
"""Test suite for stop_event propagation through GraphEngine components."""
|
||||
|
||||
def test_graph_engine_creates_stop_event(self):
|
||||
"""Test that GraphEngine creates a stop_event on initialization."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify stop_event was created
|
||||
assert engine._stop_event is not None
|
||||
assert isinstance(engine._stop_event, threading.Event)
|
||||
|
||||
# Verify it was set in graph_runtime_state
|
||||
assert runtime_state.stop_event is not None
|
||||
assert runtime_state.stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_cleared_on_start(self):
|
||||
"""Test that stop_event is cleared when execution starts."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Set the stop_event before running
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear the stop_event)
|
||||
events = list(engine.run())
|
||||
|
||||
# After running, stop_event should be set again (by _stop_execution)
|
||||
# But during start it was cleared
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
def test_stop_event_set_on_stop(self):
|
||||
"""Test that stop_event is set when execution stops."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Initially not set
|
||||
assert not engine._stop_event.is_set()
|
||||
|
||||
# Run the engine
|
||||
list(engine.run())
|
||||
|
||||
# After execution completes, stop_event should be set
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
def test_stop_event_passed_to_worker_pool(self):
|
||||
"""Test that stop_event is passed to WorkerPool."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify WorkerPool has the stop_event
|
||||
assert engine._worker_pool._stop_event is not None
|
||||
assert engine._worker_pool._stop_event is engine._stop_event
|
||||
|
||||
def test_stop_event_passed_to_dispatcher(self):
|
||||
"""Test that stop_event is passed to Dispatcher."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Verify Dispatcher has the stop_event
|
||||
assert engine._dispatcher._stop_event is not None
|
||||
assert engine._dispatcher._stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestNodeStopCheck:
|
||||
"""Test suite for Node._should_stop() functionality."""
|
||||
|
||||
def test_node_should_stop_checks_runtime_state(self):
|
||||
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Initially stop_event is not set
|
||||
assert not answer_node._should_stop()
|
||||
|
||||
# Set the stop_event
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Now _should_stop should return True
|
||||
assert answer_node._should_stop()
|
||||
|
||||
def test_node_run_checks_stop_event_between_yields(self):
|
||||
"""Test that Node.run() checks stop_event between yielding events."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a simple node
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
# Set stop_event BEFORE running the node
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the node - should yield start event then detect stop
|
||||
# The node should check stop_event before processing
|
||||
assert answer_node._should_stop(), "stop_event should be set"
|
||||
|
||||
# Run and collect events
|
||||
events = list(answer_node.run())
|
||||
|
||||
# Since stop_event is set at the start, we should get:
|
||||
# 1. NodeRunStartedEvent (always yielded first)
|
||||
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
|
||||
assert len(events) >= 2
|
||||
assert isinstance(events[0], NodeRunStartedEvent)
|
||||
|
||||
# Note: AnswerNode is very simple and might complete before stop check
|
||||
# The important thing is that _should_stop() returns True when stop_event is set
|
||||
assert answer_node._should_stop()
|
||||
|
||||
|
||||
class TestStopEventIntegration:
|
||||
"""Integration tests for stop_event in workflow execution."""
|
||||
|
||||
def test_simple_workflow_respects_stop_event(self):
|
||||
"""Test that a simple workflow respects stop_event."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start"
|
||||
|
||||
# Create start and answer nodes
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
answer_node = AnswerNode(
|
||||
id="answer",
|
||||
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.nodes["answer"] = answer_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Set stop_event before running
|
||||
runtime_state.stop_event.set()
|
||||
|
||||
# Run the engine
|
||||
events = list(engine.run())
|
||||
|
||||
# Should get started event but not succeeded (due to stop)
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
# The workflow should still complete (start node runs quickly)
|
||||
# but answer node might be cancelled depending on timing
|
||||
|
||||
def test_stop_event_with_concurrent_nodes(self):
|
||||
"""Test stop_event behavior with multiple concurrent nodes."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
# Create multiple nodes
|
||||
for i in range(3):
|
||||
answer_node = AnswerNode(
|
||||
id=f"answer_{i}",
|
||||
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes[f"answer_{i}"] = answer_node
|
||||
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# All nodes should share the same stop_event
|
||||
for node in mock_graph.nodes.values():
|
||||
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
|
||||
assert node.graph_runtime_state.stop_event is engine._stop_event
|
||||
|
||||
|
||||
class TestStopEventTimeoutBehavior:
|
||||
"""Test stop_event behavior with join timeouts."""
|
||||
|
||||
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
|
||||
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
|
||||
"""Test that Dispatcher uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
dispatcher = engine._dispatcher
|
||||
dispatcher.start() # This will create and start the mocked thread
|
||||
|
||||
mock_thread_instance = mock_thread_cls.return_value
|
||||
mock_thread_instance.is_alive.return_value = True
|
||||
|
||||
dispatcher.stop()
|
||||
|
||||
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
|
||||
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
|
||||
"""Test that WorkerPool uses 2s timeout instead of 10s."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
worker_pool = engine._worker_pool
|
||||
worker_pool.start(initial_count=1) # Start with one worker
|
||||
|
||||
mock_worker_instance = mock_worker_cls.return_value
|
||||
mock_worker_instance.is_alive.return_value = True
|
||||
|
||||
worker_pool.stop()
|
||||
|
||||
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
|
||||
|
||||
|
||||
class TestStopEventResumeBehavior:
|
||||
"""Test stop_event behavior during workflow resume."""
|
||||
|
||||
def test_stop_event_cleared_on_resume(self):
|
||||
"""Test that stop_event is cleared when resuming a paused workflow."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
mock_graph.root_node.id = "start" # Set proper id
|
||||
|
||||
start_node = StartNode(
|
||||
id="start",
|
||||
config={"id": "start", "data": {"title": "start", "variables": []}},
|
||||
graph_init_params=GraphInitParams(
|
||||
tenant_id="test_tenant",
|
||||
app_id="test_app",
|
||||
workflow_id="test_workflow",
|
||||
graph_config={},
|
||||
user_id="test_user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=0,
|
||||
),
|
||||
graph_runtime_state=runtime_state,
|
||||
)
|
||||
mock_graph.nodes["start"] = start_node
|
||||
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
|
||||
mock_graph.get_incoming_edges = MagicMock(return_value=[])
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Simulate a previous execution that set stop_event
|
||||
engine._stop_event.set()
|
||||
assert engine._stop_event.is_set()
|
||||
|
||||
# Run the engine (should clear stop_event in _start_execution)
|
||||
events = list(engine.run())
|
||||
|
||||
# Execution should complete successfully
|
||||
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
|
||||
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
|
||||
|
||||
|
||||
class TestWorkerStopBehavior:
|
||||
"""Test Worker behavior with shared stop_event."""
|
||||
|
||||
def test_worker_uses_shared_stop_event(self):
|
||||
"""Test that Worker uses shared stop_event from GraphEngine."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
mock_graph = MagicMock(spec=Graph)
|
||||
mock_graph.nodes = {}
|
||||
mock_graph.edges = {}
|
||||
mock_graph.root_node = MagicMock()
|
||||
|
||||
engine = GraphEngine(
|
||||
workflow_id="test_workflow",
|
||||
graph=mock_graph,
|
||||
graph_runtime_state=runtime_state,
|
||||
command_channel=InMemoryChannel(),
|
||||
config=GraphEngineConfig(),
|
||||
)
|
||||
|
||||
# Get the worker pool and check workers
|
||||
worker_pool = engine._worker_pool
|
||||
|
||||
# Start the worker pool to create workers
|
||||
worker_pool.start()
|
||||
|
||||
# Check that at least one worker was created
|
||||
assert len(worker_pool._workers) > 0
|
||||
|
||||
# Verify workers use the shared stop_event
|
||||
for worker in worker_pool._workers:
|
||||
assert worker._stop_event is engine._stop_event
|
||||
|
||||
# Clean up
|
||||
worker_pool.stop()
|
||||
|
||||
def test_worker_stop_is_noop(self):
|
||||
"""Test that Worker.stop() is now a no-op."""
|
||||
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
|
||||
|
||||
# Create a mock worker
|
||||
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
|
||||
from core.workflow.graph_engine.worker import Worker
|
||||
|
||||
ready_queue = InMemoryReadyQueue()
|
||||
event_queue = MagicMock()
|
||||
|
||||
# Create a proper mock graph with real dict
|
||||
mock_graph = Mock(spec=Graph)
|
||||
mock_graph.nodes = {} # Use real dict
|
||||
|
||||
stop_event = threading.Event()
|
||||
|
||||
worker = Worker(
|
||||
ready_queue=ready_queue,
|
||||
event_queue=event_queue,
|
||||
graph=mock_graph,
|
||||
layers=[],
|
||||
stop_event=stop_event,
|
||||
)
|
||||
|
||||
# Calling stop() should do nothing (no-op)
|
||||
# and should NOT set the stop_event
|
||||
worker.stop()
|
||||
assert not stop_event.is_set()
|
||||
@@ -5,6 +5,7 @@ These tests intentionally stay in unit scope because they validate call argument
|
||||
for external collaborators rather than SQL-backed state transitions.
|
||||
"""
|
||||
|
||||
import json
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
@@ -196,3 +197,78 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
|
||||
provider="notion_datasource",
|
||||
plugin_id="langgenius/notion_datasource",
|
||||
)
|
||||
|
||||
|
||||
class TestDataSourceInfoSerialization:
|
||||
"""Regression test: data_source_info must be written as a JSON string, not a raw dict.
|
||||
|
||||
See https://github.com/langgenius/dify/issues/32705
|
||||
psycopg2 raises ``ProgrammingError: can't adapt type 'dict'`` when a Python
|
||||
dict is passed directly to a text/LongText column.
|
||||
"""
|
||||
|
||||
def test_data_source_info_serialized_as_json_string(
|
||||
self,
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
dataset_id,
|
||||
document_id,
|
||||
):
|
||||
"""data_source_info must be serialized with json.dumps before DB write."""
|
||||
with (
|
||||
patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory,
|
||||
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class,
|
||||
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class,
|
||||
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_ipf,
|
||||
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class,
|
||||
):
|
||||
# External collaborators
|
||||
mock_service = MagicMock()
|
||||
mock_service.get_datasource_credentials.return_value = {"integration_secret": "token"}
|
||||
mock_service_class.return_value = mock_service
|
||||
|
||||
mock_extractor = MagicMock()
|
||||
# Return a *different* timestamp so the task enters the sync/update branch
|
||||
mock_extractor.get_notion_last_edited_time.return_value = "2024-02-01T00:00:00Z"
|
||||
mock_extractor_class.return_value = mock_extractor
|
||||
|
||||
mock_ip = MagicMock()
|
||||
mock_ipf.return_value.init_index_processor.return_value = mock_ip
|
||||
|
||||
mock_runner = MagicMock()
|
||||
mock_runner_class.return_value = mock_runner
|
||||
|
||||
# DB session mock — shared across all ``session_factory.create_session()`` calls
|
||||
session = MagicMock()
|
||||
session.scalars.return_value.all.return_value = []
|
||||
# .where() path: session 1 reads document + dataset, session 2 reads dataset
|
||||
session.query.return_value.where.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_dataset,
|
||||
mock_dataset,
|
||||
]
|
||||
# .filter_by() path: session 3 (update), session 4 (indexing)
|
||||
session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
mock_document,
|
||||
mock_document,
|
||||
]
|
||||
|
||||
begin_cm = MagicMock()
|
||||
begin_cm.__enter__.return_value = session
|
||||
begin_cm.__exit__.return_value = False
|
||||
session.begin.return_value = begin_cm
|
||||
|
||||
session_cm = MagicMock()
|
||||
session_cm.__enter__.return_value = session
|
||||
session_cm.__exit__.return_value = False
|
||||
mock_session_factory.create_session.return_value = session_cm
|
||||
|
||||
# Act
|
||||
document_indexing_sync_task(dataset_id, document_id)
|
||||
|
||||
# Assert: data_source_info must be a JSON *string*, not a dict
|
||||
assert isinstance(mock_document.data_source_info, str), (
|
||||
f"data_source_info should be a JSON string, got {type(mock_document.data_source_info).__name__}"
|
||||
)
|
||||
parsed = json.loads(mock_document.data_source_info)
|
||||
assert parsed["last_edited_time"] == "2024-02-01T00:00:00Z"
|
||||
|
||||
@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}, 1000)
|
||||
}
|
||||
|
||||
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
|
||||
const sendEmail = async (email: string, token?: string) => {
|
||||
try {
|
||||
const res = await sendVerifyCode({
|
||||
email,
|
||||
phase: isOrigin ? 'old_email' : 'new_email',
|
||||
token,
|
||||
})
|
||||
startCount()
|
||||
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
const sendCodeToOriginEmail = async () => {
|
||||
await sendEmail(
|
||||
email,
|
||||
true,
|
||||
)
|
||||
setStep(STEP.verifyOrigin)
|
||||
}
|
||||
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
}
|
||||
await sendEmail(
|
||||
mail,
|
||||
false,
|
||||
stepToken,
|
||||
)
|
||||
setStep(STEP.verifyNew)
|
||||
|
||||
@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str
|
||||
export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
|
||||
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })
|
||||
|
||||
export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
|
||||
post<CommonResponse & { data: string }>('/account/change-email', { body })
|
||||
|
||||
export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>
|
||||
|
||||
Reference in New Issue
Block a user