Compare commits

...

8 Commits

17 changed files with 426 additions and 626 deletions

View File

@@ -133,7 +133,6 @@ class EducationAutocompleteQuery(BaseModel):
class ChangeEmailSendPayload(BaseModel):
email: EmailStr
language: str | None = None
phase: str | None = None
token: str | None = None
@@ -547,13 +546,17 @@ class ChangeEmailSendEmailApi(Resource):
account = None
user_email = None
email_for_sending = args.email.lower()
if args.phase is not None and args.phase == "new_email":
if args.token is None:
raise InvalidTokenError()
send_phase = AccountService.CHANGE_EMAIL_PHASE_OLD
if args.token is not None:
send_phase = AccountService.CHANGE_EMAIL_PHASE_NEW
reset_data = AccountService.get_change_email_data(args.token)
if reset_data is None:
raise InvalidTokenError()
reset_token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if reset_token_phase != AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED:
raise InvalidTokenError()
user_email = reset_data.get("email", "")
if user_email.lower() != current_user.email.lower():
@@ -573,7 +576,7 @@ class ChangeEmailSendEmailApi(Resource):
email=email_for_sending,
old_email=user_email,
language=language,
phase=args.phase,
phase=send_phase,
)
return {"result": "success", "data": token}
@@ -608,12 +611,26 @@ class ChangeEmailCheckApi(Resource):
AccountService.add_change_email_error_rate_limit(user_email)
raise EmailCodeError()
phase_transitions: dict[str, str] = {
AccountService.CHANGE_EMAIL_PHASE_OLD: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
AccountService.CHANGE_EMAIL_PHASE_NEW: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
token_phase = token_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if not isinstance(token_phase, str):
raise InvalidTokenError()
refreshed_phase = phase_transitions.get(token_phase)
if refreshed_phase is None:
raise InvalidTokenError()
# Verified, revoke the first token
AccountService.revoke_change_email_token(args.token)
# Refresh token data by generating a new token
_, new_token = AccountService.generate_change_email_token(
user_email, code=args.code, old_email=token_data.get("old_email"), additional_data={}
user_email,
code=args.code,
old_email=token_data.get("old_email"),
additional_data={AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: refreshed_phase},
)
AccountService.reset_change_email_error_rate_limit(user_email)
@@ -643,13 +660,22 @@ class ChangeEmailResetApi(Resource):
if not reset_data:
raise InvalidTokenError()
AccountService.revoke_change_email_token(args.token)
token_phase = reset_data.get(AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY)
if token_phase != AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED:
raise InvalidTokenError()
token_email = reset_data.get("email")
normalized_token_email = token_email.lower() if isinstance(token_email, str) else token_email
if normalized_token_email != normalized_new_email:
raise InvalidTokenError()
old_email = reset_data.get("old_email", "")
current_user, _ = current_account_with_tenant()
if current_user.email.lower() != old_email.lower():
raise AccountNotFound()
AccountService.revoke_change_email_token(args.token)
updated_account = AccountService.update_account_email(current_user, email=normalized_new_email)
AccountService.send_change_email_completed_notify_email(

View File

@@ -9,7 +9,6 @@ from __future__ import annotations
import logging
import queue
import threading
from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
@@ -77,13 +76,10 @@ class GraphEngine:
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
self._stop_event = threading.Event()
# Bind runtime state to current workflow context
self._graph = graph
self._graph_runtime_state = graph_runtime_state
self._graph_runtime_state.stop_event = self._stop_event
self._graph_runtime_state.configure(graph=cast("GraphProtocol", graph))
self._command_channel = command_channel
self._config = config
@@ -163,7 +159,6 @@ class GraphEngine:
layers=self._layers,
execution_context=execution_context,
config=self._config,
stop_event=self._stop_event,
)
# === Orchestration ===
@@ -194,7 +189,6 @@ class GraphEngine:
event_handler=self._event_handler_registry,
execution_coordinator=self._execution_coordinator,
event_emitter=self._event_manager,
stop_event=self._stop_event,
)
# === Validation ===
@@ -314,7 +308,6 @@ class GraphEngine:
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
@@ -348,7 +341,6 @@ class GraphEngine:
def _stop_execution(self) -> None:
"""Stop execution subsystems."""
self._stop_event.set()
self._dispatcher.stop()
self._worker_pool.stop()
# Don't mark complete here as the dispatcher already does it

View File

@@ -44,7 +44,6 @@ class Dispatcher:
event_queue: queue.Queue[GraphNodeEventBase],
event_handler: "EventHandler",
execution_coordinator: ExecutionCoordinator,
stop_event: threading.Event,
event_emitter: EventManager | None = None,
) -> None:
"""
@@ -62,7 +61,7 @@ class Dispatcher:
self._event_emitter = event_emitter
self._thread: threading.Thread | None = None
self._stop_event = stop_event
self._stop_event = threading.Event()
self._start_time: float | None = None
def start(self) -> None:
@@ -70,12 +69,14 @@ class Dispatcher:
if self._thread and self._thread.is_alive():
return
self._stop_event.clear()
self._start_time = time.time()
self._thread = threading.Thread(target=self._dispatcher_loop, name="GraphDispatcher", daemon=True)
self._thread.start()
def stop(self) -> None:
"""Stop the dispatcher thread."""
self._stop_event.set()
if self._thread and self._thread.is_alive():
self._thread.join(timeout=2.0)

View File

@@ -42,7 +42,6 @@ class Worker(threading.Thread):
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: Sequence[GraphEngineLayer],
stop_event: threading.Event,
worker_id: int = 0,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -63,16 +62,13 @@ class Worker(threading.Thread):
self._graph = graph
self._worker_id = worker_id
self._execution_context = execution_context
self._stop_event = stop_event
self._stop_event = threading.Event()
self._layers = layers if layers is not None else []
self._last_task_time = time.time()
def stop(self) -> None:
"""Worker is controlled via shared stop_event from GraphEngine.
This method is a no-op retained for backward compatibility.
"""
pass
"""Signal the worker to stop processing."""
self._stop_event.set()
@property
def is_idle(self) -> bool:

View File

@@ -37,7 +37,6 @@ class WorkerPool:
event_queue: queue.Queue[GraphNodeEventBase],
graph: Graph,
layers: list[GraphEngineLayer],
stop_event: threading.Event,
config: GraphEngineConfig,
execution_context: IExecutionContext | None = None,
) -> None:
@@ -64,7 +63,6 @@ class WorkerPool:
self._worker_counter = 0
self._lock = threading.RLock()
self._running = False
self._stop_event = stop_event
# No longer tracking worker states with callbacks to avoid lock contention
@@ -135,7 +133,6 @@ class WorkerPool:
layers=self._layers,
worker_id=worker_id,
execution_context=self._execution_context,
stop_event=self._stop_event,
)
worker.start()

View File

@@ -302,10 +302,6 @@ class Node(Generic[NodeDataT]):
"""
raise NotImplementedError
def _should_stop(self) -> bool:
"""Check if execution should be stopped."""
return self.graph_runtime_state.stop_event.is_set()
def run(self) -> Generator[GraphNodeEventBase, None, None]:
execution_id = self.ensure_execution_id()
self._start_at = naive_utc_now()
@@ -374,21 +370,6 @@ class Node(Generic[NodeDataT]):
yield event
else:
yield event
if self._should_stop():
error_message = "Execution cancelled"
yield NodeRunFailedEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
start_at=self._start_at,
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.FAILED,
error=error_message,
),
error=error_message,
)
return
except Exception as e:
logger.exception("Node %s failed to run", self._node_id)
result = NodeRunResult(

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import importlib
import json
import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
@@ -219,8 +218,6 @@ class GraphRuntimeState:
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
self.attach_graph(graph)

View File

@@ -4,6 +4,7 @@ import logging
import secrets
import uuid
from datetime import UTC, datetime, timedelta
from enum import StrEnum
from hashlib import sha256
from typing import Any, cast
@@ -80,12 +81,25 @@ class TokenPair(BaseModel):
csrf_token: str
class ChangeEmailPhase(StrEnum):
OLD = "old_email"
OLD_VERIFIED = "old_email_verified"
NEW = "new_email"
NEW_VERIFIED = "new_email_verified"
REFRESH_TOKEN_PREFIX = "refresh_token:"
ACCOUNT_REFRESH_TOKEN_PREFIX = "account_refresh_token:"
REFRESH_TOKEN_EXPIRY = timedelta(days=dify_config.REFRESH_TOKEN_EXPIRE_DAYS)
class AccountService:
CHANGE_EMAIL_TOKEN_PHASE_KEY = "email_change_phase"
CHANGE_EMAIL_PHASE_OLD = ChangeEmailPhase.OLD
CHANGE_EMAIL_PHASE_OLD_VERIFIED = ChangeEmailPhase.OLD_VERIFIED
CHANGE_EMAIL_PHASE_NEW = ChangeEmailPhase.NEW
CHANGE_EMAIL_PHASE_NEW_VERIFIED = ChangeEmailPhase.NEW_VERIFIED
reset_password_rate_limiter = RateLimiter(prefix="reset_password_rate_limit", max_attempts=1, time_window=60 * 1)
email_register_rate_limiter = RateLimiter(prefix="email_register_rate_limit", max_attempts=1, time_window=60 * 1)
email_code_login_rate_limiter = RateLimiter(
@@ -541,13 +555,20 @@ class AccountService:
raise ValueError("Email must be provided.")
if not phase:
raise ValueError("phase must be provided.")
if phase not in (cls.CHANGE_EMAIL_PHASE_OLD, cls.CHANGE_EMAIL_PHASE_NEW):
raise ValueError("phase must be one of old_email or new_email.")
if cls.change_email_rate_limiter.is_rate_limited(account_email):
from controllers.console.auth.error import EmailChangeRateLimitExceededError
raise EmailChangeRateLimitExceededError(int(cls.change_email_rate_limiter.time_window / 60))
code, token = cls.generate_change_email_token(account_email, account, old_email=old_email)
code, token = cls.generate_change_email_token(
account_email,
account,
old_email=old_email,
additional_data={cls.CHANGE_EMAIL_TOKEN_PHASE_KEY: phase},
)
send_change_mail_task.delay(
language=language,

View File

@@ -1,3 +1,4 @@
import json
import logging
import time
@@ -125,7 +126,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
document.data_source_info = json.dumps(data_source_info)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()

View File

@@ -12,8 +12,6 @@ from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from psycopg2.extensions import register_adapter
from psycopg2.extras import Json
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
@@ -21,12 +19,6 @@ from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_sync_task import document_indexing_sync_task
@pytest.fixture(autouse=True)
def _register_dict_adapter_for_psycopg2():
"""Align test DB adapter behavior with dict payloads used in task update flow."""
register_adapter(dict, Json)
class DocumentIndexingSyncTaskTestDataFactory:
"""Create real DB entities for document indexing sync integration tests."""

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from flask import Flask, g
from controllers.console.auth.error import InvalidTokenError
from controllers.console.workspace.account import (
AccountDeleteUpdateFeedbackApi,
ChangeEmailCheckApi,
@@ -52,7 +53,7 @@ class TestChangeEmailSend:
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_normalize_new_email_phase(
def test_should_infer_new_email_phase_from_token(
self,
mock_features,
mock_csrf,
@@ -68,13 +69,16 @@ class TestChangeEmailSend:
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {"email": "current@example.com"}
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED,
}
mock_send_email.return_value = "token-abc"
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "phase": "new_email", "token": "token-123"},
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
@@ -91,6 +95,107 @@ class TestChangeEmailSend:
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.db")
@patch("controllers.console.workspace.account.Session")
@patch("controllers.console.workspace.account.AccountService.get_account_by_email_with_case_fallback")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_ignore_client_phase_and_use_old_phase_when_token_missing(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_account_by_email,
mock_session_cls,
mock_account_db,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_current_account.return_value = (_build_account("current@example.com", "current"), None)
existing_account = _build_account("old@example.com", "acc-old")
mock_get_account_by_email.return_value = existing_account
mock_send_email.return_value = "token-legacy"
mock_session = MagicMock()
mock_session_cm = MagicMock()
mock_session_cm.__enter__.return_value = mock_session
mock_session_cm.__exit__.return_value = None
mock_session_cls.return_value = mock_session_cm
mock_account_db.engine = MagicMock()
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "old@example.com", "language": "en-US", "phase": "new_email"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailSendEmailApi().post()
assert response == {"result": "success", "data": "token-legacy"}
mock_get_account_by_email.assert_called_once_with("old@example.com", session=mock_session)
mock_send_email.assert_called_once_with(
account=existing_account,
email="old@example.com",
old_email="old@example.com",
language="en-US",
phase=AccountService.CHANGE_EMAIL_PHASE_OLD,
)
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.send_change_email_email")
@patch("controllers.console.workspace.account.AccountService.is_email_send_ip_limit", return_value=False)
@patch("controllers.console.workspace.account.extract_remote_ip", return_value="127.0.0.1")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_unverified_old_email_token_for_new_email_phase(
self,
mock_features,
mock_csrf,
mock_extract_ip,
mock_is_ip_limit,
mock_send_email,
mock_get_change_data,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("current@example.com", "acc1")
mock_current_account.return_value = (mock_account, None)
mock_get_change_data.return_value = {
"email": "current@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email",
method="POST",
json={"email": "New@Example.com", "language": "en-US", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailSendEmailApi().post()
mock_send_email.assert_not_called()
mock_extract_ip.assert_called_once()
mock_is_ip_limit.assert_called_once_with("127.0.0.1")
mock_csrf.assert_called_once()
class TestChangeEmailValidity:
@patch("controllers.console.wraps.db")
@@ -122,7 +227,12 @@ class TestChangeEmailValidity:
mock_account = _build_account("user@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {"email": "user@example.com", "code": "1234", "old_email": "old@example.com"}
mock_get_data.return_value = {
"email": "user@example.com",
"code": "1234",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
mock_generate_token.return_value = (None, "new-token")
with app.test_request_context(
@@ -138,11 +248,76 @@ class TestChangeEmailValidity:
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-123")
mock_generate_token.assert_called_once_with(
"user@example.com", code="1234", old_email="old@example.com", additional_data={}
"user@example.com",
code="1234",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("user@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.reset_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.generate_change_email_token")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.add_change_email_error_rate_limit")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.is_change_email_error_rate_limit")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_refresh_new_email_phase_to_verified(
self,
mock_features,
mock_csrf,
mock_is_rate_limit,
mock_get_data,
mock_add_rate,
mock_revoke_token,
mock_generate_token,
mock_reset_rate,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
mock_account = _build_account("old@example.com", "acc2")
mock_current_account.return_value = (mock_account, None)
mock_is_rate_limit.return_value = False
mock_get_data.return_value = {
"email": "new@example.com",
"code": "5678",
"old_email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW,
}
mock_generate_token.return_value = (None, "new-phase-token")
with app.test_request_context(
"/account/change-email/validity",
method="POST",
json={"email": "New@Example.com", "code": "5678", "token": "token-456"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
response = ChangeEmailCheckApi().post()
assert response == {"is_valid": True, "email": "new@example.com", "token": "new-phase-token"}
mock_is_rate_limit.assert_called_once_with("new@example.com")
mock_add_rate.assert_not_called()
mock_revoke_token.assert_called_once_with("token-456")
mock_generate_token.assert_called_once_with(
"new@example.com",
code="5678",
old_email="old@example.com",
additional_data={
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED
},
)
mock_reset_rate.assert_called_once_with("new@example.com")
mock_csrf.assert_called_once()
class TestChangeEmailReset:
@patch("controllers.console.wraps.db")
@@ -175,7 +350,11 @@ class TestChangeEmailReset:
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {"old_email": "OLD@example.com"}
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "new@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
mock_account_after_update = _build_account("new@example.com", "acc3-updated")
mock_update_account.return_value = mock_account_after_update
@@ -194,6 +373,106 @@ class TestChangeEmailReset:
mock_send_notify.assert_called_once_with(email="new@example.com")
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_old_phase_token_for_reset(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "old@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_OLD,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-123"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
@patch("controllers.console.wraps.db")
@patch("controllers.console.workspace.account.current_account_with_tenant")
@patch("controllers.console.workspace.account.AccountService.send_change_email_completed_notify_email")
@patch("controllers.console.workspace.account.AccountService.update_account_email")
@patch("controllers.console.workspace.account.AccountService.revoke_change_email_token")
@patch("controllers.console.workspace.account.AccountService.get_change_email_data")
@patch("controllers.console.workspace.account.AccountService.check_email_unique")
@patch("controllers.console.workspace.account.AccountService.is_account_in_freeze")
@patch("libs.login.check_csrf_token", return_value=None)
@patch("controllers.console.wraps.FeatureService.get_system_features")
def test_should_reject_mismatched_new_email_for_verified_token(
self,
mock_features,
mock_csrf,
mock_is_freeze,
mock_check_unique,
mock_get_data,
mock_revoke_token,
mock_update_account,
mock_send_notify,
mock_current_account,
mock_db,
app,
):
_mock_wraps_db(mock_db)
mock_features.return_value = SimpleNamespace(enable_change_email=True)
current_user = _build_account("old@example.com", "acc3")
mock_current_account.return_value = (current_user, None)
mock_is_freeze.return_value = False
mock_check_unique.return_value = True
mock_get_data.return_value = {
"old_email": "OLD@example.com",
"email": "another@example.com",
AccountService.CHANGE_EMAIL_TOKEN_PHASE_KEY: AccountService.CHANGE_EMAIL_PHASE_NEW_VERIFIED,
}
with app.test_request_context(
"/account/change-email/reset",
method="POST",
json={"new_email": "new@example.com", "token": "token-789"},
):
_set_logged_in_user(_build_account("tester@example.com", "tester"))
with pytest.raises(InvalidTokenError):
ChangeEmailResetApi().post()
mock_revoke_token.assert_not_called()
mock_update_account.assert_not_called()
mock_send_notify.assert_not_called()
mock_csrf.assert_called_once()
class TestAccountDeletionFeedback:
@patch("controllers.console.wraps.db")

View File

@@ -3,7 +3,6 @@
from __future__ import annotations
import queue
import threading
from unittest import mock
from core.workflow.entities.pause_reason import SchedulingPause
@@ -37,7 +36,6 @@ def test_dispatcher_should_consume_remains_events_after_pause():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=execution_coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
assert event_queue.empty()
@@ -98,7 +96,6 @@ def _run_dispatcher_for_event(event) -> int:
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()
@@ -184,7 +181,6 @@ def test_dispatcher_drain_event_queue():
event_queue=event_queue,
event_handler=event_handler,
execution_coordinator=coordinator,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@@ -1,5 +1,4 @@
import queue
import threading
from datetime import datetime
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
@@ -65,7 +64,6 @@ def test_dispatcher_drains_events_when_paused() -> None:
event_handler=handler,
execution_coordinator=coordinator,
event_emitter=None,
stop_event=threading.Event(),
)
dispatcher._dispatcher_loop()

View File

@@ -1,550 +0,0 @@
"""
Unit tests for stop_event functionality in GraphEngine.
Tests the unified stop_event management by GraphEngine and its propagation
to WorkerPool, Worker, Dispatcher, and Nodes.
"""
import threading
import time
from unittest.mock import MagicMock, Mock, patch
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.graph import Graph
from core.workflow.graph_engine import GraphEngine, GraphEngineConfig
from core.workflow.graph_engine.command_channels import InMemoryChannel
from core.workflow.graph_events import (
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunStartedEvent,
)
from core.workflow.nodes.answer.answer_node import AnswerNode
from core.workflow.nodes.start.start_node import StartNode
from core.workflow.runtime import GraphRuntimeState, VariablePool
from models.enums import UserFrom
class TestStopEventPropagation:
"""Test suite for stop_event propagation through GraphEngine components."""
def test_graph_engine_creates_stop_event(self):
"""Test that GraphEngine creates a stop_event on initialization."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify stop_event was created
assert engine._stop_event is not None
assert isinstance(engine._stop_event, threading.Event)
# Verify it was set in graph_runtime_state
assert runtime_state.stop_event is not None
assert runtime_state.stop_event is engine._stop_event
def test_stop_event_cleared_on_start(self):
"""Test that stop_event is cleared when execution starts."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set the stop_event before running
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear the stop_event)
events = list(engine.run())
# After running, stop_event should be set again (by _stop_execution)
# But during start it was cleared
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
def test_stop_event_set_on_stop(self):
"""Test that stop_event is set when execution stops."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Initially not set
assert not engine._stop_event.is_set()
# Run the engine
list(engine.run())
# After execution completes, stop_event should be set
assert engine._stop_event.is_set()
def test_stop_event_passed_to_worker_pool(self):
"""Test that stop_event is passed to WorkerPool."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify WorkerPool has the stop_event
assert engine._worker_pool._stop_event is not None
assert engine._worker_pool._stop_event is engine._stop_event
def test_stop_event_passed_to_dispatcher(self):
"""Test that stop_event is passed to Dispatcher."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Verify Dispatcher has the stop_event
assert engine._dispatcher._stop_event is not None
assert engine._dispatcher._stop_event is engine._stop_event
class TestNodeStopCheck:
"""Test suite for Node._should_stop() functionality."""
def test_node_should_stop_checks_runtime_state(self):
"""Test that Node._should_stop() checks GraphRuntimeState.stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "{{#start.result#}}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Initially stop_event is not set
assert not answer_node._should_stop()
# Set the stop_event
runtime_state.stop_event.set()
# Now _should_stop should return True
assert answer_node._should_stop()
def test_node_run_checks_stop_event_between_yields(self):
"""Test that Node.run() checks stop_event between yielding events."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a simple node
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
# Set stop_event BEFORE running the node
runtime_state.stop_event.set()
# Run the node - should yield start event then detect stop
# The node should check stop_event before processing
assert answer_node._should_stop(), "stop_event should be set"
# Run and collect events
events = list(answer_node.run())
# Since stop_event is set at the start, we should get:
# 1. NodeRunStartedEvent (always yielded first)
# 2. Either NodeRunFailedEvent (if detected early) or NodeRunSucceededEvent (if too fast)
assert len(events) >= 2
assert isinstance(events[0], NodeRunStartedEvent)
# Note: AnswerNode is very simple and might complete before stop check
# The important thing is that _should_stop() returns True when stop_event is set
assert answer_node._should_stop()
class TestStopEventIntegration:
"""Integration tests for stop_event in workflow execution."""
def test_simple_workflow_respects_stop_event(self):
"""Test that a simple workflow respects stop_event."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start"
# Create start and answer nodes
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
answer_node = AnswerNode(
id="answer",
config={"id": "answer", "data": {"title": "answer", "answer": "hello"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.nodes["answer"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Set stop_event before running
runtime_state.stop_event.set()
# Run the engine
events = list(engine.run())
# Should get started event but not succeeded (due to stop)
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
# The workflow should still complete (start node runs quickly)
# but answer node might be cancelled depending on timing
def test_stop_event_with_concurrent_nodes(self):
"""Test stop_event behavior with multiple concurrent nodes."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
# Create multiple nodes
for i in range(3):
answer_node = AnswerNode(
id=f"answer_{i}",
config={"id": f"answer_{i}", "data": {"title": f"answer_{i}", "answer": f"test{i}"}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes[f"answer_{i}"] = answer_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# All nodes should share the same stop_event
for node in mock_graph.nodes.values():
assert node.graph_runtime_state.stop_event is runtime_state.stop_event
assert node.graph_runtime_state.stop_event is engine._stop_event
class TestStopEventTimeoutBehavior:
"""Test stop_event behavior with join timeouts."""
@patch("core.workflow.graph_engine.orchestration.dispatcher.threading.Thread", autospec=True)
def test_dispatcher_uses_shorter_timeout(self, mock_thread_cls: MagicMock):
"""Test that Dispatcher uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
dispatcher = engine._dispatcher
dispatcher.start() # This will create and start the mocked thread
mock_thread_instance = mock_thread_cls.return_value
mock_thread_instance.is_alive.return_value = True
dispatcher.stop()
mock_thread_instance.join.assert_called_once_with(timeout=2.0)
@patch("core.workflow.graph_engine.worker_management.worker_pool.Worker", autospec=True)
def test_worker_pool_uses_shorter_timeout(self, mock_worker_cls: MagicMock):
"""Test that WorkerPool uses 2s timeout instead of 10s."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
worker_pool = engine._worker_pool
worker_pool.start(initial_count=1) # Start with one worker
mock_worker_instance = mock_worker_cls.return_value
mock_worker_instance.is_alive.return_value = True
worker_pool.stop()
mock_worker_instance.join.assert_called_once_with(timeout=2.0)
class TestStopEventResumeBehavior:
"""Test stop_event behavior during workflow resume."""
def test_stop_event_cleared_on_resume(self):
"""Test that stop_event is cleared when resuming a paused workflow."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
mock_graph.root_node.id = "start" # Set proper id
start_node = StartNode(
id="start",
config={"id": "start", "data": {"title": "start", "variables": []}},
graph_init_params=GraphInitParams(
tenant_id="test_tenant",
app_id="test_app",
workflow_id="test_workflow",
graph_config={},
user_id="test_user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=0,
),
graph_runtime_state=runtime_state,
)
mock_graph.nodes["start"] = start_node
mock_graph.get_outgoing_edges = MagicMock(return_value=[])
mock_graph.get_incoming_edges = MagicMock(return_value=[])
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Simulate a previous execution that set stop_event
engine._stop_event.set()
assert engine._stop_event.is_set()
# Run the engine (should clear stop_event in _start_execution)
events = list(engine.run())
# Execution should complete successfully
assert any(isinstance(e, GraphRunStartedEvent) for e in events)
assert any(isinstance(e, GraphRunSucceededEvent) for e in events)
class TestWorkerStopBehavior:
"""Test Worker behavior with shared stop_event."""
def test_worker_uses_shared_stop_event(self):
"""Test that Worker uses shared stop_event from GraphEngine."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
mock_graph = MagicMock(spec=Graph)
mock_graph.nodes = {}
mock_graph.edges = {}
mock_graph.root_node = MagicMock()
engine = GraphEngine(
workflow_id="test_workflow",
graph=mock_graph,
graph_runtime_state=runtime_state,
command_channel=InMemoryChannel(),
config=GraphEngineConfig(),
)
# Get the worker pool and check workers
worker_pool = engine._worker_pool
# Start the worker pool to create workers
worker_pool.start()
# Check that at least one worker was created
assert len(worker_pool._workers) > 0
# Verify workers use the shared stop_event
for worker in worker_pool._workers:
assert worker._stop_event is engine._stop_event
# Clean up
worker_pool.stop()
def test_worker_stop_is_noop(self):
"""Test that Worker.stop() is now a no-op."""
runtime_state = GraphRuntimeState(variable_pool=VariablePool(), start_at=time.perf_counter())
# Create a mock worker
from core.workflow.graph_engine.ready_queue import InMemoryReadyQueue
from core.workflow.graph_engine.worker import Worker
ready_queue = InMemoryReadyQueue()
event_queue = MagicMock()
# Create a proper mock graph with real dict
mock_graph = Mock(spec=Graph)
mock_graph.nodes = {} # Use real dict
stop_event = threading.Event()
worker = Worker(
ready_queue=ready_queue,
event_queue=event_queue,
graph=mock_graph,
layers=[],
stop_event=stop_event,
)
# Calling stop() should do nothing (no-op)
# and should NOT set the stop_event
worker.stop()
assert not stop_event.is_set()

View File

@@ -5,6 +5,7 @@ These tests intentionally stay in unit scope because they validate call argument
for external collaborators rather than SQL-backed state transitions.
"""
import json
import uuid
from unittest.mock import MagicMock, Mock, patch
@@ -196,3 +197,78 @@ class TestDocumentIndexingSyncTaskCollaboratorParams:
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
class TestDataSourceInfoSerialization:
"""Regression test: data_source_info must be written as a JSON string, not a raw dict.
See https://github.com/langgenius/dify/issues/32705
psycopg2 raises ``ProgrammingError: can't adapt type 'dict'`` when a Python
dict is passed directly to a text/LongText column.
"""
def test_data_source_info_serialized_as_json_string(
self,
mock_document,
mock_dataset,
dataset_id,
document_id,
):
"""data_source_info must be serialized with json.dumps before DB write."""
with (
patch("tasks.document_indexing_sync_task.session_factory") as mock_session_factory,
patch("tasks.document_indexing_sync_task.DatasourceProviderService") as mock_service_class,
patch("tasks.document_indexing_sync_task.NotionExtractor") as mock_extractor_class,
patch("tasks.document_indexing_sync_task.IndexProcessorFactory") as mock_ipf,
patch("tasks.document_indexing_sync_task.IndexingRunner") as mock_runner_class,
):
# External collaborators
mock_service = MagicMock()
mock_service.get_datasource_credentials.return_value = {"integration_secret": "token"}
mock_service_class.return_value = mock_service
mock_extractor = MagicMock()
# Return a *different* timestamp so the task enters the sync/update branch
mock_extractor.get_notion_last_edited_time.return_value = "2024-02-01T00:00:00Z"
mock_extractor_class.return_value = mock_extractor
mock_ip = MagicMock()
mock_ipf.return_value.init_index_processor.return_value = mock_ip
mock_runner = MagicMock()
mock_runner_class.return_value = mock_runner
# DB session mock — shared across all ``session_factory.create_session()`` calls
session = MagicMock()
session.scalars.return_value.all.return_value = []
# .where() path: session 1 reads document + dataset, session 2 reads dataset
session.query.return_value.where.return_value.first.side_effect = [
mock_document,
mock_dataset,
mock_dataset,
]
# .filter_by() path: session 3 (update), session 4 (indexing)
session.query.return_value.filter_by.return_value.first.side_effect = [
mock_document,
mock_document,
]
begin_cm = MagicMock()
begin_cm.__enter__.return_value = session
begin_cm.__exit__.return_value = False
session.begin.return_value = begin_cm
session_cm = MagicMock()
session_cm.__enter__.return_value = session
session_cm.__exit__.return_value = False
mock_session_factory.create_session.return_value = session_cm
# Act
document_indexing_sync_task(dataset_id, document_id)
# Assert: data_source_info must be a JSON *string*, not a dict
assert isinstance(mock_document.data_source_info, str), (
f"data_source_info should be a JSON string, got {type(mock_document.data_source_info).__name__}"
)
parsed = json.loads(mock_document.data_source_info)
assert parsed["last_edited_time"] == "2024-02-01T00:00:00Z"

View File

@@ -58,11 +58,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}, 1000)
}
const sendEmail = async (email: string, isOrigin: boolean, token?: string) => {
const sendEmail = async (email: string, token?: string) => {
try {
const res = await sendVerifyCode({
email,
phase: isOrigin ? 'old_email' : 'new_email',
token,
})
startCount()
@@ -106,7 +105,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
const sendCodeToOriginEmail = async () => {
await sendEmail(
email,
true,
)
setStep(STEP.verifyOrigin)
}
@@ -162,7 +160,6 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
}
await sendEmail(
mail,
false,
stepToken,
)
setStep(STEP.verifyNew)

View File

@@ -372,7 +372,7 @@ export const submitDeleteAccountFeedback = (body: { feedback: string, email: str
export const getDocDownloadUrl = (doc_name: string): Promise<{ url: string }> =>
get<{ url: string }>('/compliance/download', { params: { doc_name } }, { silent: true })
export const sendVerifyCode = (body: { email: string, phase: string, token?: string }): Promise<CommonResponse & { data: string }> =>
export const sendVerifyCode = (body: { email: string, token?: string }): Promise<CommonResponse & { data: string }> =>
post<CommonResponse & { data: string }>('/account/change-email', { body })
export const verifyEmail = (body: { email: string, code: string, token: string }): Promise<CommonResponse & { is_valid: boolean, email: string, token: string }> =>