diff --git a/api/commands.py b/api/commands.py index d40cf58e2a..f7af5a5df2 100644 --- a/api/commands.py +++ b/api/commands.py @@ -3,15 +3,13 @@ import datetime import json import logging import secrets -import threading import time -from typing import TYPE_CHECKING, Any +from typing import Any import click import sqlalchemy as sa from flask import current_app from pydantic import TypeAdapter -from redis.exceptions import LockNotOwnedError, RedisError from sqlalchemy import select from sqlalchemy.exc import SQLAlchemyError from sqlalchemy.orm import sessionmaker @@ -32,6 +30,7 @@ from extensions.ext_redis import redis_client from extensions.ext_storage import storage from extensions.storage.opendal_storage import OpenDALStorage from extensions.storage.storage_type import StorageType +from libs.auto_renew_redis_lock import AutoRenewRedisLock from libs.helper import email as email_validate from libs.password import hash_password, password_pattern, valid_password from libs.rsa import generate_key_pair @@ -56,36 +55,9 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch logger = logging.getLogger(__name__) -if TYPE_CHECKING: - from redis.lock import Lock - DB_UPGRADE_LOCK_TTL_SECONDS = 60 -def _heartbeat_db_upgrade_lock(lock: "Lock", stop_event: threading.Event, ttl_seconds: float) -> None: - """ - Keep the DB upgrade lock alive while migrations are running. - - We intentionally keep the base TTL small (e.g. 60s) so that if the process is killed and can't - release the lock, the lock will naturally expire soon. While the process is alive, this - heartbeat periodically resets the TTL via `lock.reacquire()`. - """ - - interval_seconds = max(0.1, ttl_seconds / 3) - while not stop_event.wait(interval_seconds): - try: - lock.reacquire() - except LockNotOwnedError: - # Another process took over / TTL expired; continuing to retry won't help. - logger.warning("DB migration lock is no longer owned during heartbeat; stop renewing.") - return - except RedisError: - # Best-effort: keep trying while the process is alive. - logger.warning("Failed to renew DB migration lock due to Redis error; will retry.", exc_info=True) - except Exception: - logger.warning("Unexpected error while renewing DB migration lock; will retry.", exc_info=True) - - @click.command("reset-password", help="Reset the account password.") @click.option("--email", prompt=True, help="Account email to reset password for") @click.option("--new-password", prompt=True, help="New password") @@ -758,21 +730,14 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No @click.command("upgrade-db", help="Upgrade the database") def upgrade_db(): click.echo("Preparing database migration...") - # Use a short base TTL + heartbeat renewal, so a crashed process doesn't block migrations for long. - # thread_local=False is required because heartbeat runs in a separate thread. 
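+    # AutoRenewRedisLock renews a short base TTL from a background heartbeat
+    # thread, so a killed process frees the lock quickly while a live migration
+    # may safely run longer than the TTL. See libs/auto_renew_redis_lock.py.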
- lock = redis_client.lock( + lock = AutoRenewRedisLock( + redis_client=redis_client, name="db_upgrade_lock", - timeout=DB_UPGRADE_LOCK_TTL_SECONDS, - thread_local=False, + ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS, + logger=logger, + log_context="db_migration", ) if lock.acquire(blocking=False): - stop_event = threading.Event() - heartbeat_thread = threading.Thread( - target=_heartbeat_db_upgrade_lock, - args=(lock, stop_event, float(DB_UPGRADE_LOCK_TTL_SECONDS)), - daemon=True, - ) - heartbeat_thread.start() migration_succeeded = False try: click.echo(click.style("Starting database migration.", fg="green")) @@ -790,23 +755,8 @@ def upgrade_db(): click.echo(click.style(f"Database migration failed: {e}", fg="red")) raise SystemExit(1) finally: - stop_event.set() - heartbeat_thread.join(timeout=5) - # Lock release errors should never mask the real migration failure. - try: - lock.release() - except LockNotOwnedError: - status = "successful" if migration_succeeded else "failed" - logger.warning( - "DB migration lock not owned on release after %s migration (likely expired); ignoring.", status - ) - except RedisError: - status = "successful" if migration_succeeded else "failed" - logger.warning( - "Failed to release DB migration lock due to Redis error after %s migration; ignoring.", - status, - exc_info=True, - ) + status = "successful" if migration_succeeded else "failed" + lock.release_safely(status=status) else: click.echo("Database migration skipped") diff --git a/api/libs/auto_renew_redis_lock.py b/api/libs/auto_renew_redis_lock.py new file mode 100644 index 0000000000..2d45c6bf26 --- /dev/null +++ b/api/libs/auto_renew_redis_lock.py @@ -0,0 +1,198 @@ +""" +Auto-renewing Redis distributed lock (redis-py Lock). + +Why this exists: +- A fixed, long lock TTL can leave a stale lock for a long time if the process is killed + before releasing it. +- A fixed, short lock TTL can expire during long critical sections (e.g. DB migrations), + allowing another instance to acquire the same lock concurrently. + +This wrapper keeps a short base TTL and renews it in a daemon thread using `Lock.reacquire()` +while the process is alive. If the process is terminated, the renewal stops and the lock +expires soon. +""" + +from __future__ import annotations + +import logging +import threading +from typing import Any + +from redis.exceptions import LockNotOwnedError, RedisError + +logger = logging.getLogger(__name__) + + +class AutoRenewRedisLock: + """ + Redis lock wrapper that automatically renews TTL while held. + + Notes: + - We force `thread_local=False` when creating the underlying redis-py lock, because the + lock token must be accessible from the heartbeat thread for `reacquire()` to work. + - `release_safely()` is best-effort: it never raises, so it won't mask the caller's + primary error/exit code. 
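+
+    A minimal usage sketch (mirrors the `upgrade_db` call site; `do_work()` is
+    illustrative):
+
+        lock = AutoRenewRedisLock(redis_client=redis_client, name="my_lock", ttl_seconds=60)
+        if lock.acquire(blocking=False):
+            try:
+                do_work()  # may safely run longer than ttl_seconds
+            finally:
+                lock.release_safely()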
+ """ + + _redis_client: Any + _name: str + _ttl_seconds: float + _renew_interval_seconds: float + _log_context: str | None + _logger: logging.Logger + + _lock: Any + _stop_event: threading.Event | None + _thread: threading.Thread | None + _acquired: bool + + def __init__( + self, + redis_client: Any, + name: str, + ttl_seconds: float = 60, + renew_interval_seconds: float | None = None, + *, + logger: logging.Logger | None = None, + log_context: str | None = None, + ) -> None: + self._redis_client = redis_client + self._name = name + self._ttl_seconds = float(ttl_seconds) + self._renew_interval_seconds = ( + float(renew_interval_seconds) if renew_interval_seconds is not None else max(0.1, self._ttl_seconds / 3) + ) + self._logger = logger or logging.getLogger(__name__) + self._log_context = log_context + + self._lock = None + self._stop_event = None + self._thread = None + self._acquired = False + + @property + def name(self) -> str: + return self._name + + def acquire(self, *args: Any, **kwargs: Any) -> bool: + """ + Acquire the lock and start auto-renew heartbeat on success. + + Accepts the same args/kwargs as redis-py `Lock.acquire()`. + """ + self._lock = self._redis_client.lock( + name=self._name, + timeout=self._ttl_seconds, + thread_local=False, + ) + acquired = bool(self._lock.acquire(*args, **kwargs)) + self._acquired = acquired + if acquired: + self._start_heartbeat() + return acquired + + def owned(self) -> bool: + if self._lock is None: + return False + try: + return bool(self._lock.owned()) + except Exception: + # Ownership checks are best-effort and must not break callers. + return False + + def _start_heartbeat(self) -> None: + if self._lock is None: + return + if self._stop_event is not None: + return + + self._stop_event = threading.Event() + self._thread = threading.Thread( + target=self._heartbeat_loop, + args=(self._lock, self._stop_event), + daemon=True, + name=f"AutoRenewRedisLock({self._name})", + ) + self._thread.start() + + def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None: + while not stop_event.wait(self._renew_interval_seconds): + try: + lock.reacquire() + except LockNotOwnedError: + self._logger.warning( + "Auto-renew lock is no longer owned during heartbeat%s; stop renewing.", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + return + except RedisError: + self._logger.warning( + "Failed to renew auto-renew lock due to Redis error%s; will retry.", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + except Exception: + self._logger.warning( + "Unexpected error while renewing auto-renew lock%s; will retry.", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + + def release_safely(self, *, status: str | None = None) -> None: + """ + Stop heartbeat and release lock. Never raises. + + Args: + status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs. + """ + lock = self._lock + if lock is None: + return + + self._stop_heartbeat() + + # Lock release errors should never mask the real error/exit code. 
+ try: + lock.release() + except LockNotOwnedError: + self._logger.warning( + "Auto-renew lock not owned on release%s%s; ignoring.", + f" after {status} operation" if status else "", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + except RedisError: + self._logger.warning( + "Failed to release auto-renew lock due to Redis error%s%s; ignoring.", + f" after {status} operation" if status else "", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + except Exception: + self._logger.warning( + "Unexpected error while releasing auto-renew lock%s%s; ignoring.", + f" after {status} operation" if status else "", + f" ({self._log_context})" if self._log_context else "", + exc_info=True, + ) + finally: + self._acquired = False + + def _stop_heartbeat(self) -> None: + if self._stop_event is None: + return + self._stop_event.set() + if self._thread is not None: + # Best-effort join: if Redis calls are blocked, the daemon thread may remain alive. + join_timeout_seconds = max(0.5, min(5.0, self._renew_interval_seconds * 2)) + self._thread.join(timeout=join_timeout_seconds) + if self._thread.is_alive(): + self._logger.warning( + "Auto-renew lock heartbeat thread did not stop within %.2fs%s; ignoring.", + join_timeout_seconds, + f" ({self._log_context})" if self._log_context else "", + ) + self._stop_event = None + self._thread = None + diff --git a/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py b/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py new file mode 100644 index 0000000000..072ba27d73 --- /dev/null +++ b/api/tests/test_containers_integration_tests/libs/test_auto_renew_redis_lock_integration.py @@ -0,0 +1,39 @@ +""" +Integration tests for AutoRenewRedisLock using real Redis via TestContainers. +""" + +import time +import uuid + +import pytest + +from extensions.ext_redis import redis_client +from libs.auto_renew_redis_lock import AutoRenewRedisLock + + +@pytest.mark.usefixtures("flask_app_with_containers") +def test_auto_renew_redis_lock_renews_ttl_and_releases(): + lock_name = f"test:auto_renew_lock:{uuid.uuid4().hex}" + + # Keep base TTL very small, and renew frequently so the test is stable even on slower CI. + lock = AutoRenewRedisLock( + redis_client=redis_client, + name=lock_name, + ttl_seconds=1.0, + renew_interval_seconds=0.2, + log_context="test_auto_renew_redis_lock", + ) + + acquired = lock.acquire(blocking=True, blocking_timeout=5) + assert acquired is True + + # Wait beyond the base TTL; key should still exist due to renewal. + time.sleep(1.5) + ttl = redis_client.ttl(lock_name) + assert ttl > 0 + + lock.release_safely(status="successful") + + # After release, the key should not exist. 
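+    # (redis-py `exists()` returns the number of keys found, so 0 means the
+    # lock key was deleted on release.)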
+ assert redis_client.exists(lock_name) == 0 + diff --git a/api/tests/unit_tests/commands/test_upgrade_db.py b/api/tests/unit_tests/commands/test_upgrade_db.py index d884477143..c4c333f457 100644 --- a/api/tests/unit_tests/commands/test_upgrade_db.py +++ b/api/tests/unit_tests/commands/test_upgrade_db.py @@ -4,8 +4,9 @@ import types from unittest.mock import MagicMock import commands +from libs.auto_renew_redis_lock import LockNotOwnedError, RedisError -HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0 +HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0 def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None: @@ -45,7 +46,7 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys): lock = MagicMock() lock.acquire.return_value = True - lock.release.side_effect = commands.LockNotOwnedError("simulated") + lock.release.side_effect = LockNotOwnedError("simulated") commands.redis_client.lock.return_value = lock def _upgrade(): @@ -69,7 +70,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy lock = MagicMock() lock.acquire.return_value = True - lock.release.side_effect = commands.LockNotOwnedError("simulated") + lock.release.side_effect = LockNotOwnedError("simulated") commands.redis_client.lock.return_value = lock _install_fake_flask_migrate(monkeypatch, lambda: None) @@ -129,7 +130,7 @@ def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys): def _reacquire(): attempted.set() - raise commands.RedisError("simulated") + raise RedisError("simulated") lock.reacquire.side_effect = _reacquire diff --git a/api/tests/unit_tests/services/enterprise/test_enterprise_service.py b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py new file mode 100644 index 0000000000..b4201aa061 --- /dev/null +++ b/api/tests/unit_tests/services/enterprise/test_enterprise_service.py @@ -0,0 +1,125 @@ +"""Unit tests for enterprise service integrations. 
+ +This module covers the enterprise-only default workspace auto-join behavior: +- Enterprise mode disabled: no external calls +- Successful join / skipped join: no errors +- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise +""" + +from unittest.mock import patch + +import pytest + +from services.enterprise.enterprise_service import ( + DefaultWorkspaceJoinResult, + EnterpriseService, + try_join_default_workspace, +) + + +class TestJoinDefaultWorkspace: + def test_join_default_workspace_success(self): + account_id = "11111111-1111-1111-1111-111111111111" + response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"} + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = response + + result = EnterpriseService.join_default_workspace(account_id=account_id) + + assert isinstance(result, DefaultWorkspaceJoinResult) + assert result.workspace_id == response["workspace_id"] + assert result.joined is True + assert result.message == "ok" + + mock_send_request.assert_called_once_with( + "POST", + "/default-workspace/members", + json={"account_id": account_id}, + ) + + def test_join_default_workspace_invalid_response_format_raises(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request: + mock_send_request.return_value = "not-a-dict" + + with pytest.raises(ValueError, match="Invalid response format"): + EnterpriseService.join_default_workspace(account_id=account_id) + + def test_join_default_workspace_invalid_account_id_raises(self): + with pytest.raises(ValueError): + EnterpriseService.join_default_workspace(account_id="not-a-uuid") + + +class TestTryJoinDefaultWorkspace: + def test_try_join_default_workspace_enterprise_disabled_noop(self): + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = False + + try_join_default_workspace("11111111-1111-1111-1111-111111111111") + + mock_join.assert_not_called() + + def test_try_join_default_workspace_successful_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="22222222-2222-2222-2222-222222222222", + joined=True, + message="ok", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_skipped_join_does_not_raise(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.return_value = DefaultWorkspaceJoinResult( + workspace_id="", + joined=False, + message="no default workspace configured", + ) + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + 
+ def test_try_join_default_workspace_api_failure_soft_fails(self): + account_id = "11111111-1111-1111-1111-111111111111" + + with ( + patch("services.enterprise.enterprise_service.dify_config") as mock_config, + patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join, + ): + mock_config.ENTERPRISE_ENABLED = True + mock_join.side_effect = Exception("network failure") + + # Should not raise + try_join_default_workspace(account_id) + + mock_join.assert_called_once_with(account_id=account_id) + + def test_try_join_default_workspace_invalid_account_id_soft_fails(self): + with patch("services.enterprise.enterprise_service.dify_config") as mock_config: + mock_config.ENTERPRISE_ENABLED = True + + # Should not raise even though UUID parsing fails inside join_default_workspace + try_join_default_workspace("not-a-uuid")