Compare commits

..

2 Commits

Author SHA1 Message Date
L1nSn0w
d6abc7f52c feat(api): enhance account registration process with improved error handling
Implemented better error handling during account addition to the default workspace for enterprise users, ensuring smoother user registration experience even when workspace joining fails.
2026-02-13 17:47:19 +08:00
L1nSn0w
330857dbb2 feat(api): implement best-effort account addition to default workspace for enterprise users
Added functionality to attempt adding accounts to the default workspace during account registration and creation processes. This includes a new method in the enterprise service to handle the workspace joining logic, ensuring it does not block user registration on failure.
2026-02-13 17:47:19 +08:00
5 changed files with 201 additions and 210 deletions

View File

@@ -3,15 +3,13 @@ import datetime
import json
import logging
import secrets
import threading
import time
from typing import TYPE_CHECKING, Any
from typing import Any
import click
import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from redis.exceptions import LockNotOwnedError, RedisError
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@@ -56,35 +54,6 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.lock import Lock
DB_UPGRADE_LOCK_TTL_SECONDS = 60
def _heartbeat_db_upgrade_lock(lock: "Lock", stop_event: threading.Event, ttl_seconds: float) -> None:
"""
Keep the DB upgrade lock alive while migrations are running.
We intentionally keep the base TTL small (e.g. 60s) so that if the process is killed and can't
release the lock, the lock will naturally expire soon. While the process is alive, this
heartbeat periodically resets the TTL via `lock.reacquire()`.
"""
interval_seconds = max(0.1, ttl_seconds / 3)
while not stop_event.wait(interval_seconds):
try:
lock.reacquire()
except LockNotOwnedError:
# Another process took over / TTL expired; continuing to retry won't help.
logger.warning("DB migration lock is no longer owned during heartbeat; stop renewing.")
return
except RedisError:
# Best-effort: keep trying while the process is alive.
logger.warning("Failed to renew DB migration lock due to Redis error; will retry.", exc_info=True)
except Exception:
logger.warning("Unexpected error while renewing DB migration lock; will retry.", exc_info=True)
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@@ -758,22 +727,8 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
# Use a short base TTL + heartbeat renewal, so a crashed process doesn't block migrations for long.
# thread_local=False is required because heartbeat runs in a separate thread.
lock = redis_client.lock(
name="db_upgrade_lock",
timeout=DB_UPGRADE_LOCK_TTL_SECONDS,
thread_local=False,
)
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
if lock.acquire(blocking=False):
stop_event = threading.Event()
heartbeat_thread = threading.Thread(
target=_heartbeat_db_upgrade_lock,
args=(lock, stop_event, float(DB_UPGRADE_LOCK_TTL_SECONDS)),
daemon=True,
)
heartbeat_thread.start()
migration_succeeded = False
try:
click.echo(click.style("Starting database migration.", fg="green"))
@@ -782,7 +737,6 @@ def upgrade_db():
flask_migrate.upgrade()
migration_succeeded = True
click.echo(click.style("Database migration successful!", fg="green"))
except Exception as e:
@@ -790,23 +744,7 @@ def upgrade_db():
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
stop_event.set()
heartbeat_thread.join(timeout=5)
# Lock release errors should never mask the real migration failure.
try:
lock.release()
except LockNotOwnedError:
status = "successful" if migration_succeeded else "failed"
logger.warning(
"DB migration lock not owned on release after %s migration (likely expired); ignoring.", status
)
except RedisError:
status = "successful" if migration_succeeded else "failed"
logger.warning(
"Failed to release DB migration lock due to Redis error after %s migration; ignoring.",
status,
exc_info=True,
)
lock.release()
else:
click.echo("Database migration skipped")

View File

@@ -289,6 +289,11 @@ class AccountService:
TenantService.create_owner_tenant_if_not_exist(account=account)
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
return account
@staticmethod
@@ -1407,6 +1412,11 @@ class RegisterService:
tenant_was_created.send(tenant)
db.session.commit()
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logger.exception("Register failed")

View File

@@ -1,9 +1,14 @@
import logging
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from configs import dify_config
from services.enterprise.base import EnterpriseRequest
logger = logging.getLogger(__name__)
class WebAppSettings(BaseModel):
access_mode: str = Field(
@@ -30,6 +35,47 @@ class WorkspacePermission(BaseModel):
)
class DefaultWorkspaceJoinResult(BaseModel):
"""
Result of ensuring an account is a member of the enterprise default workspace.
- joined=True is idempotent (already a member also returns True)
- joined=False means enterprise default workspace is not configured or invalid/archived
"""
workspace_id: str = ""
joined: bool = False
message: str = ""
def try_join_default_workspace(account_id: str) -> None:
"""
Enterprise-only side-effect: ensure account is a member of the default workspace.
This is a best-effort integration. Failures must not block user registration.
"""
if not dify_config.ENTERPRISE_ENABLED:
return
try:
result = EnterpriseService.join_default_workspace(account_id=account_id)
if result.joined:
logger.info(
"Joined enterprise default workspace for account %s (workspace_id=%s)",
account_id,
result.workspace_id,
)
else:
logger.info(
"Skipped joining enterprise default workspace for account %s (message=%s)",
account_id,
result.message,
)
except Exception:
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -39,6 +85,23 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""
Call enterprise inner API to add an account to the default workspace.
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
so the endpoint here is `/default-workspace/members`.
"""
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
uuid.UUID(account_id)
data = EnterpriseRequest.send_request("POST", "/default-workspace/members", json={"account_id": account_id})
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
return DefaultWorkspaceJoinResult.model_validate(data)
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")

View File

@@ -1,145 +0,0 @@
import sys
import threading
import types
from unittest.mock import MagicMock
import commands
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
module = types.ModuleType("flask_migrate")
module.upgrade = upgrade_impl
monkeypatch.setitem(sys.modules, "flask_migrate", module)
def _invoke_upgrade_db() -> int:
try:
commands.upgrade_db.callback()
except SystemExit as e:
return int(e.code or 0)
return 0
def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)
lock = MagicMock()
lock.acquire.return_value = False
commands.redis_client.lock.return_value = lock
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration skipped" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_not_called()
def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
def _upgrade():
raise RuntimeError("boom")
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 1
assert "Database migration failed: boom" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
_install_fake_flask_migrate(monkeypatch, lambda: None)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration successful!" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
"""
Ensure the lock is renewed while migrations are running, so the base TTL can stay short.
"""
# Use a small TTL so the heartbeat interval triggers quickly.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
renewed = threading.Event()
def _reacquire():
renewed.set()
return True
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert renewed.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1
def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
# Use a small TTL so heartbeat runs during the upgrade call.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
attempted = threading.Event()
def _reacquire():
attempted.set()
raise commands.RedisError("simulated")
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert attempted.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1

View File

@@ -0,0 +1,125 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
result = EnterpriseService.join_default_workspace(account_id=account_id)
assert isinstance(result, DefaultWorkspaceJoinResult)
assert result.workspace_id == response["workspace_id"]
assert result.joined is True
assert result.message == "ok"
mock_send_request.assert_called_once_with(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
)
def test_join_default_workspace_invalid_response_format_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = "not-a-dict"
with pytest.raises(ValueError, match="Invalid response format"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_invalid_account_id_raises(self):
with pytest.raises(ValueError):
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
class TestTryJoinDefaultWorkspace:
def test_try_join_default_workspace_enterprise_disabled_noop(self):
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = False
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
mock_join.assert_not_called()
def test_try_join_default_workspace_successful_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="22222222-2222-2222-2222-222222222222",
joined=True,
message="ok",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="",
joined=False,
message="no default workspace configured",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_api_failure_soft_fails(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.side_effect = Exception("network failure")
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")