refactor(api): replace heartbeat mechanism with AutoRenewRedisLock for database migration

- Removed the manual heartbeat function for renewing the Redis lock during database migrations.
- Integrated AutoRenewRedisLock to handle lock renewal automatically, simplifying the upgrade_db command.
- Updated unit tests to reflect changes in lock handling and error management during migrations.

(cherry picked from commit 8814256eb5fa20b29e554264f3b659b027bc4c9a)
This commit is contained in:
L1nSn0w
2026-02-14 12:03:58 +08:00
parent 8d4bd5636b
commit 94603b5408
5 changed files with 376 additions and 63 deletions

View File

@@ -3,15 +3,13 @@ import datetime
import json
import logging
import secrets
import threading
import time
from typing import TYPE_CHECKING, Any
from typing import Any
import click
import sqlalchemy as sa
from flask import current_app
from pydantic import TypeAdapter
from redis.exceptions import LockNotOwnedError, RedisError
from sqlalchemy import select
from sqlalchemy.exc import SQLAlchemyError
from sqlalchemy.orm import sessionmaker
@@ -32,6 +30,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.auto_renew_redis_lock import AutoRenewRedisLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
@@ -56,36 +55,9 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
if TYPE_CHECKING:
from redis.lock import Lock
DB_UPGRADE_LOCK_TTL_SECONDS = 60
def _heartbeat_db_upgrade_lock(lock: "Lock", stop_event: threading.Event, ttl_seconds: float) -> None:
"""
Keep the DB upgrade lock alive while migrations are running.
We intentionally keep the base TTL small (e.g. 60s) so that if the process is killed and can't
release the lock, the lock will naturally expire soon. While the process is alive, this
heartbeat periodically resets the TTL via `lock.reacquire()`.
"""
interval_seconds = max(0.1, ttl_seconds / 3)
while not stop_event.wait(interval_seconds):
try:
lock.reacquire()
except LockNotOwnedError:
# Another process took over / TTL expired; continuing to retry won't help.
logger.warning("DB migration lock is no longer owned during heartbeat; stop renewing.")
return
except RedisError:
# Best-effort: keep trying while the process is alive.
logger.warning("Failed to renew DB migration lock due to Redis error; will retry.", exc_info=True)
except Exception:
logger.warning("Unexpected error while renewing DB migration lock; will retry.", exc_info=True)
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@click.option("--new-password", prompt=True, help="New password")
@@ -758,21 +730,14 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
# Use a short base TTL + heartbeat renewal, so a crashed process doesn't block migrations for long.
# thread_local=False is required because heartbeat runs in a separate thread.
lock = redis_client.lock(
lock = AutoRenewRedisLock(
redis_client=redis_client,
name="db_upgrade_lock",
timeout=DB_UPGRADE_LOCK_TTL_SECONDS,
thread_local=False,
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
logger=logger,
log_context="db_migration",
)
if lock.acquire(blocking=False):
stop_event = threading.Event()
heartbeat_thread = threading.Thread(
target=_heartbeat_db_upgrade_lock,
args=(lock, stop_event, float(DB_UPGRADE_LOCK_TTL_SECONDS)),
daemon=True,
)
heartbeat_thread.start()
migration_succeeded = False
try:
click.echo(click.style("Starting database migration.", fg="green"))
@@ -790,23 +755,8 @@ def upgrade_db():
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
stop_event.set()
heartbeat_thread.join(timeout=5)
# Lock release errors should never mask the real migration failure.
try:
lock.release()
except LockNotOwnedError:
status = "successful" if migration_succeeded else "failed"
logger.warning(
"DB migration lock not owned on release after %s migration (likely expired); ignoring.", status
)
except RedisError:
status = "successful" if migration_succeeded else "failed"
logger.warning(
"Failed to release DB migration lock due to Redis error after %s migration; ignoring.",
status,
exc_info=True,
)
status = "successful" if migration_succeeded else "failed"
lock.release_safely(status=status)
else:
click.echo("Database migration skipped")

View File

@@ -0,0 +1,198 @@
"""
Auto-renewing Redis distributed lock (redis-py Lock).
Why this exists:
- A fixed, long lock TTL can leave a stale lock for a long time if the process is killed
before releasing it.
- A fixed, short lock TTL can expire during long critical sections (e.g. DB migrations),
allowing another instance to acquire the same lock concurrently.
This wrapper keeps a short base TTL and renews it in a daemon thread using `Lock.reacquire()`
while the process is alive. If the process is terminated, the renewal stops and the lock
expires soon.
"""
from __future__ import annotations
import logging
import threading
from typing import Any
from redis.exceptions import LockNotOwnedError, RedisError
logger = logging.getLogger(__name__)
class AutoRenewRedisLock:
"""
Redis lock wrapper that automatically renews TTL while held.
Notes:
- We force `thread_local=False` when creating the underlying redis-py lock, because the
lock token must be accessible from the heartbeat thread for `reacquire()` to work.
- `release_safely()` is best-effort: it never raises, so it won't mask the caller's
primary error/exit code.
"""
_redis_client: Any
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
*,
logger: logging.Logger | None = None,
log_context: str | None = None,
) -> None:
self._redis_client = redis_client
self._name = name
self._ttl_seconds = float(ttl_seconds)
self._renew_interval_seconds = (
float(renew_interval_seconds) if renew_interval_seconds is not None else max(0.1, self._ttl_seconds / 3)
)
self._logger = logger or logging.getLogger(__name__)
self._log_context = log_context
self._lock = None
self._stop_event = None
self._thread = None
self._acquired = False
@property
def name(self) -> str:
return self._name
def acquire(self, *args: Any, **kwargs: Any) -> bool:
"""
Acquire the lock and start auto-renew heartbeat on success.
Accepts the same args/kwargs as redis-py `Lock.acquire()`.
"""
self._lock = self._redis_client.lock(
name=self._name,
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()
return acquired
def owned(self) -> bool:
if self._lock is None:
return False
try:
return bool(self._lock.owned())
except Exception:
# Ownership checks are best-effort and must not break callers.
return False
def _start_heartbeat(self) -> None:
if self._lock is None:
return
if self._stop_event is not None:
return
self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._heartbeat_loop,
args=(self._lock, self._stop_event),
daemon=True,
name=f"AutoRenewRedisLock({self._name})",
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()
except LockNotOwnedError:
self._logger.warning(
"Auto-renew lock is no longer owned during heartbeat%s; stop renewing.",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
return
except RedisError:
self._logger.warning(
"Failed to renew auto-renew lock due to Redis error%s; will retry.",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while renewing auto-renew lock%s; will retry.",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
def release_safely(self, *, status: str | None = None) -> None:
"""
Stop heartbeat and release lock. Never raises.
Args:
status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
"""
lock = self._lock
if lock is None:
return
self._stop_heartbeat()
# Lock release errors should never mask the real error/exit code.
try:
lock.release()
except LockNotOwnedError:
self._logger.warning(
"Auto-renew lock not owned on release%s%s; ignoring.",
f" after {status} operation" if status else "",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
except RedisError:
self._logger.warning(
"Failed to release auto-renew lock due to Redis error%s%s; ignoring.",
f" after {status} operation" if status else "",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while releasing auto-renew lock%s%s; ignoring.",
f" after {status} operation" if status else "",
f" ({self._log_context})" if self._log_context else "",
exc_info=True,
)
finally:
self._acquired = False
def _stop_heartbeat(self) -> None:
if self._stop_event is None:
return
self._stop_event.set()
if self._thread is not None:
# Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
join_timeout_seconds = max(0.5, min(5.0, self._renew_interval_seconds * 2))
self._thread.join(timeout=join_timeout_seconds)
if self._thread.is_alive():
self._logger.warning(
"Auto-renew lock heartbeat thread did not stop within %.2fs%s; ignoring.",
join_timeout_seconds,
f" ({self._log_context})" if self._log_context else "",
)
self._stop_event = None
self._thread = None

View File

@@ -0,0 +1,39 @@
"""
Integration tests for AutoRenewRedisLock using real Redis via TestContainers.
"""
import time
import uuid
import pytest
from extensions.ext_redis import redis_client
from libs.auto_renew_redis_lock import AutoRenewRedisLock
@pytest.mark.usefixtures("flask_app_with_containers")
def test_auto_renew_redis_lock_renews_ttl_and_releases():
lock_name = f"test:auto_renew_lock:{uuid.uuid4().hex}"
# Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
lock = AutoRenewRedisLock(
redis_client=redis_client,
name=lock_name,
ttl_seconds=1.0,
renew_interval_seconds=0.2,
log_context="test_auto_renew_redis_lock",
)
acquired = lock.acquire(blocking=True, blocking_timeout=5)
assert acquired is True
# Wait beyond the base TTL; key should still exist due to renewal.
time.sleep(1.5)
ttl = redis_client.ttl(lock_name)
assert ttl > 0
lock.release_safely(status="successful")
# After release, the key should not exist.
assert redis_client.exists(lock_name) == 0

View File

@@ -4,8 +4,9 @@ import types
from unittest.mock import MagicMock
import commands
from libs.auto_renew_redis_lock import LockNotOwnedError, RedisError
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
@@ -45,7 +46,7 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
def _upgrade():
@@ -69,7 +70,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
_install_fake_flask_migrate(monkeypatch, lambda: None)
@@ -129,7 +130,7 @@ def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
def _reacquire():
attempted.set()
raise commands.RedisError("simulated")
raise RedisError("simulated")
lock.reacquire.side_effect = _reacquire

View File

@@ -0,0 +1,125 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
result = EnterpriseService.join_default_workspace(account_id=account_id)
assert isinstance(result, DefaultWorkspaceJoinResult)
assert result.workspace_id == response["workspace_id"]
assert result.joined is True
assert result.message == "ok"
mock_send_request.assert_called_once_with(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
)
def test_join_default_workspace_invalid_response_format_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = "not-a-dict"
with pytest.raises(ValueError, match="Invalid response format"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_invalid_account_id_raises(self):
with pytest.raises(ValueError):
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
class TestTryJoinDefaultWorkspace:
def test_try_join_default_workspace_enterprise_disabled_noop(self):
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = False
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
mock_join.assert_not_called()
def test_try_join_default_workspace_successful_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="22222222-2222-2222-2222-222222222222",
joined=True,
message="ok",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="",
joined=False,
message="no default workspace configured",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_api_failure_soft_fails(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.side_effect = Exception("network failure")
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")