mirror of
https://github.com/langgenius/dify.git
synced 2026-02-24 18:05:11 +00:00
refactor(api): replace heartbeat mechanism with AutoRenewRedisLock for database migration
- Removed the manual heartbeat function for renewing the Redis lock during database migrations.
- Integrated AutoRenewRedisLock to handle lock renewal automatically, simplifying the upgrade_db command.
- Updated unit tests to reflect changes in lock handling and error management during migrations.

(cherry picked from commit 8814256eb5fa20b29e554264f3b659b027bc4c9a)
This commit is contained in:
@@ -3,15 +3,13 @@ import datetime
|
||||
import json
|
||||
import logging
|
||||
import secrets
|
||||
import threading
|
||||
import time
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
import sqlalchemy as sa
|
||||
from flask import current_app
|
||||
from pydantic import TypeAdapter
|
||||
from redis.exceptions import LockNotOwnedError, RedisError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.exc import SQLAlchemyError
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
@@ -32,6 +30,7 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.auto_renew_redis_lock import AutoRenewRedisLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@@ -56,36 +55,9 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
|
||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
|
||||
|
||||
def _heartbeat_db_upgrade_lock(lock: "Lock", stop_event: threading.Event, ttl_seconds: float) -> None:
    """
    Periodically renew the DB upgrade lock until told to stop.

    The base TTL is deliberately short (e.g. 60s): if the process dies without releasing
    the lock, the lock expires on its own soon after. While the process is alive, this loop
    resets the TTL through `lock.reacquire()`.

    Args:
        lock: The lock to renew; only `reacquire()` is called on it.
        stop_event: Set this event to make the loop exit.
        ttl_seconds: The lock's TTL; renewal is attempted every third of it.
    """
    # Three renewals per TTL window, but never wake up more often than every 100 ms.
    renew_every = max(0.1, ttl_seconds / 3)

    while True:
        # `wait` is the sleep between renewals and returns True once the event is set.
        if stop_event.wait(renew_every):
            return

        try:
            lock.reacquire()
        except LockNotOwnedError:
            # The lock expired or was taken over; further renewals cannot succeed.
            logger.warning("DB migration lock is no longer owned during heartbeat; stop renewing.")
            return
        except RedisError:
            # Best-effort: keep looping so a later renewal can still succeed.
            logger.warning("Failed to renew DB migration lock due to Redis error; will retry.", exc_info=True)
        except Exception:
            logger.warning("Unexpected error while renewing DB migration lock; will retry.", exc_info=True)
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@click.option("--new-password", prompt=True, help="New password")
|
||||
@@ -758,21 +730,14 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
# Use a short base TTL + heartbeat renewal, so a crashed process doesn't block migrations for long.
|
||||
# thread_local=False is required because heartbeat runs in a separate thread.
|
||||
lock = redis_client.lock(
|
||||
lock = AutoRenewRedisLock(
|
||||
redis_client=redis_client,
|
||||
name="db_upgrade_lock",
|
||||
timeout=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
thread_local=False,
|
||||
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
logger=logger,
|
||||
log_context="db_migration",
|
||||
)
|
||||
if lock.acquire(blocking=False):
|
||||
stop_event = threading.Event()
|
||||
heartbeat_thread = threading.Thread(
|
||||
target=_heartbeat_db_upgrade_lock,
|
||||
args=(lock, stop_event, float(DB_UPGRADE_LOCK_TTL_SECONDS)),
|
||||
daemon=True,
|
||||
)
|
||||
heartbeat_thread.start()
|
||||
migration_succeeded = False
|
||||
try:
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
@@ -790,23 +755,8 @@ def upgrade_db():
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
stop_event.set()
|
||||
heartbeat_thread.join(timeout=5)
|
||||
# Lock release errors should never mask the real migration failure.
|
||||
try:
|
||||
lock.release()
|
||||
except LockNotOwnedError:
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
logger.warning(
|
||||
"DB migration lock not owned on release after %s migration (likely expired); ignoring.", status
|
||||
)
|
||||
except RedisError:
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
logger.warning(
|
||||
"Failed to release DB migration lock due to Redis error after %s migration; ignoring.",
|
||||
status,
|
||||
exc_info=True,
|
||||
)
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
lock.release_safely(status=status)
|
||||
else:
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
198
api/libs/auto_renew_redis_lock.py
Normal file
198
api/libs/auto_renew_redis_lock.py
Normal file
@@ -0,0 +1,198 @@
|
||||
"""
|
||||
Auto-renewing Redis distributed lock (redis-py Lock).
|
||||
|
||||
Why this exists:
|
||||
- A fixed, long lock TTL can leave a stale lock for a long time if the process is killed
|
||||
before releasing it.
|
||||
- A fixed, short lock TTL can expire during long critical sections (e.g. DB migrations),
|
||||
allowing another instance to acquire the same lock concurrently.
|
||||
|
||||
This wrapper keeps a short base TTL and renews it in a daemon thread using `Lock.reacquire()`
|
||||
while the process is alive. If the process is terminated, the renewal stops and the lock
|
||||
expires soon.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import threading
|
||||
from typing import Any
|
||||
|
||||
from redis.exceptions import LockNotOwnedError, RedisError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AutoRenewRedisLock:
    """
    Redis lock wrapper that automatically renews TTL while held.

    The underlying redis-py lock is created with a short TTL (`ttl_seconds`). After a
    successful `acquire()`, a daemon thread calls `reacquire()` on it every
    `renew_interval_seconds` until `release_safely()` is called or ownership is lost.

    Notes:
    - We force `thread_local=False` when creating the underlying redis-py lock, because the
      lock token must be accessible from the heartbeat thread for `reacquire()` to work.
    - `release_safely()` is best-effort: it never raises, so it won't mask the caller's
      primary error/exit code. It does nothing when this wrapper does not hold the lock.
    - A failed `acquire()` never discards a lock this wrapper already holds.
    """

    _redis_client: Any
    _name: str
    _ttl_seconds: float
    _renew_interval_seconds: float
    _log_context: str | None
    _logger: logging.Logger

    _lock: Any
    _stop_event: threading.Event | None
    _thread: threading.Thread | None
    _acquired: bool

    def __init__(
        self,
        redis_client: Any,
        name: str,
        ttl_seconds: float = 60,
        renew_interval_seconds: float | None = None,
        *,
        logger: logging.Logger | None = None,
        log_context: str | None = None,
    ) -> None:
        """
        Configure the lock; nothing is sent to Redis until `acquire()`.

        Args:
            redis_client: Object exposing a redis-py style `lock(name=..., timeout=..., thread_local=...)`.
            name: Redis key of the lock.
            ttl_seconds: Base TTL given to the underlying lock.
            renew_interval_seconds: Delay between renewals. Defaults to a third of the TTL,
                with a floor of 0.1 seconds.
            logger: Logger for heartbeat/release warnings; defaults to this module's logger.
            log_context: Optional label appended to every warning message.
        """
        self._redis_client = redis_client
        self._name = name
        self._ttl_seconds = float(ttl_seconds)
        self._renew_interval_seconds = (
            float(renew_interval_seconds) if renew_interval_seconds is not None else max(0.1, self._ttl_seconds / 3)
        )
        self._logger = logger or logging.getLogger(__name__)
        self._log_context = log_context

        self._lock = None
        self._stop_event = None
        self._thread = None
        self._acquired = False

    @property
    def name(self) -> str:
        """Redis key of the lock."""
        return self._name

    def _context_suffix(self) -> str:
        """Return the ' (<log_context>)' suffix for log messages, or '' when no context is set."""
        return f" ({self._log_context})" if self._log_context else ""

    def acquire(self, *args: Any, **kwargs: Any) -> bool:
        """
        Acquire the lock and start auto-renew heartbeat on success.

        Accepts the same args/kwargs as redis-py `Lock.acquire()`.

        Returns:
            True when the lock was acquired, False otherwise. On False, any lock this
            wrapper already holds (and its heartbeat) is left untouched.
        """
        # Work on a local handle: storing it before knowing the outcome would drop the
        # token of a lock we already hold and make it impossible to release.
        candidate = self._redis_client.lock(
            name=self._name,
            timeout=self._ttl_seconds,
            thread_local=False,
        )
        if not bool(candidate.acquire(*args, **kwargs)):
            return False

        # A previous heartbeat may still be registered (for example, it exited after losing
        # ownership). Clear it, otherwise `_start_heartbeat()` would refuse to start a new one.
        self._stop_heartbeat()

        self._lock = candidate
        self._acquired = True
        self._start_heartbeat()
        return True

    def owned(self) -> bool:
        """Return whether the underlying lock reports ownership; False if unknown or not held."""
        if self._lock is None:
            return False
        try:
            return bool(self._lock.owned())
        except Exception:
            # Ownership checks are best-effort and must not break callers.
            return False

    def _start_heartbeat(self) -> None:
        """Start the renewal daemon thread for the current lock, unless one is already registered."""
        if self._lock is None:
            return
        if self._stop_event is not None:
            return

        self._stop_event = threading.Event()
        self._thread = threading.Thread(
            target=self._heartbeat_loop,
            args=(self._lock, self._stop_event),
            daemon=True,
            name=f"AutoRenewRedisLock({self._name})",
        )
        self._thread.start()

    def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
        """
        Renew `lock` every renew interval until `stop_event` is set.

        Stops early when ownership is lost; any other error is logged and retried on the
        next tick.
        """
        # `wait` is the sleep between renewals and returns True once the event is set.
        while not stop_event.wait(self._renew_interval_seconds):
            try:
                lock.reacquire()
            except LockNotOwnedError:
                self._logger.warning(
                    "Auto-renew lock is no longer owned during heartbeat%s; stop renewing.",
                    self._context_suffix(),
                    exc_info=True,
                )
                return
            except RedisError:
                self._logger.warning(
                    "Failed to renew auto-renew lock due to Redis error%s; will retry.",
                    self._context_suffix(),
                    exc_info=True,
                )
            except Exception:
                self._logger.warning(
                    "Unexpected error while renewing auto-renew lock%s; will retry.",
                    self._context_suffix(),
                    exc_info=True,
                )

    def release_safely(self, *, status: str | None = None) -> None:
        """
        Stop heartbeat and release lock. Never raises.

        Does nothing when the lock is not currently held by this wrapper, so calling it
        twice, or after a failed `acquire()`, is harmless and logs nothing.

        Args:
            status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
        """
        lock = self._lock
        if lock is None or not self._acquired:
            return

        self._stop_heartbeat()

        after_status = f" after {status} operation" if status else ""
        context = self._context_suffix()

        # Lock release errors should never mask the real error/exit code.
        try:
            lock.release()
        except LockNotOwnedError:
            self._logger.warning(
                "Auto-renew lock not owned on release%s%s; ignoring.",
                after_status,
                context,
                exc_info=True,
            )
        except RedisError:
            self._logger.warning(
                "Failed to release auto-renew lock due to Redis error%s%s; ignoring.",
                after_status,
                context,
                exc_info=True,
            )
        except Exception:
            self._logger.warning(
                "Unexpected error while releasing auto-renew lock%s%s; ignoring.",
                after_status,
                context,
                exc_info=True,
            )
        finally:
            # Forget the handle so a repeated call does not try to release it again.
            self._acquired = False
            self._lock = None

    def _stop_heartbeat(self) -> None:
        """Signal the heartbeat thread to stop and wait briefly for it to exit."""
        if self._stop_event is None:
            return
        self._stop_event.set()
        if self._thread is not None:
            # Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
            join_timeout_seconds = max(0.5, min(5.0, self._renew_interval_seconds * 2))
            self._thread.join(timeout=join_timeout_seconds)
            if self._thread.is_alive():
                self._logger.warning(
                    "Auto-renew lock heartbeat thread did not stop within %.2fs%s; ignoring.",
                    join_timeout_seconds,
                    self._context_suffix(),
                )
        self._stop_event = None
        self._thread = None
|
||||
|
||||
@@ -0,0 +1,39 @@
|
||||
"""
|
||||
Integration tests for AutoRenewRedisLock using real Redis via TestContainers.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.auto_renew_redis_lock import AutoRenewRedisLock
|
||||
|
||||
|
||||
@pytest.mark.usefixtures("flask_app_with_containers")
def test_auto_renew_redis_lock_renews_ttl_and_releases():
    """The lock key must outlive its base TTL while held and disappear after release."""
    lock_name = f"test:auto_renew_lock:{uuid.uuid4().hex}"

    # Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
    lock = AutoRenewRedisLock(
        redis_client=redis_client,
        name=lock_name,
        ttl_seconds=1.0,
        renew_interval_seconds=0.2,
        log_context="test_auto_renew_redis_lock",
    )

    acquired = lock.acquire(blocking=True, blocking_timeout=5)
    assert acquired is True

    try:
        # Wait beyond the base TTL; key should still exist due to renewal.
        time.sleep(1.5)
        # Read the remaining TTL in milliseconds: with a 1-second base TTL the remainder is
        # always below one second, which a whole-second reading may report as 0.
        remaining_ms = redis_client.pttl(lock_name)
        assert remaining_ms > 0
    finally:
        # Always release, so a failed assertion does not leave the heartbeat thread and
        # the lock key behind for other tests.
        lock.release_safely(status="successful")

    # After release, the key should not exist.
    assert redis_client.exists(lock_name) == 0
|
||||
|
||||
@@ -4,8 +4,9 @@ import types
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import commands
|
||||
from libs.auto_renew_redis_lock import LockNotOwnedError, RedisError
|
||||
|
||||
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0
|
||||
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
|
||||
|
||||
|
||||
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
|
||||
@@ -45,7 +46,7 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = commands.LockNotOwnedError("simulated")
|
||||
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||
commands.redis_client.lock.return_value = lock
|
||||
|
||||
def _upgrade():
|
||||
@@ -69,7 +70,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
lock.release.side_effect = commands.LockNotOwnedError("simulated")
|
||||
lock.release.side_effect = LockNotOwnedError("simulated")
|
||||
commands.redis_client.lock.return_value = lock
|
||||
|
||||
_install_fake_flask_migrate(monkeypatch, lambda: None)
|
||||
@@ -129,7 +130,7 @@ def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
|
||||
|
||||
def _reacquire():
|
||||
attempted.set()
|
||||
raise commands.RedisError("simulated")
|
||||
raise RedisError("simulated")
|
||||
|
||||
lock.reacquire.side_effect = _reacquire
|
||||
|
||||
|
||||
@@ -0,0 +1,125 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
)
|
||||
|
||||
|
||||
class TestJoinDefaultWorkspace:
    """Calls `EnterpriseService.join_default_workspace` directly with the request layer patched."""

    _SEND_REQUEST_TARGET = "services.enterprise.enterprise_service.EnterpriseRequest.send_request"
    _ACCOUNT_ID = "11111111-1111-1111-1111-111111111111"

    def test_join_default_workspace_success(self):
        """A well-formed response is turned into a result object and sent with the expected request."""
        payload = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}

        with patch(self._SEND_REQUEST_TARGET, return_value=payload) as send_request:
            outcome = EnterpriseService.join_default_workspace(account_id=self._ACCOUNT_ID)

            assert isinstance(outcome, DefaultWorkspaceJoinResult)
            assert outcome.workspace_id == payload["workspace_id"]
            assert outcome.joined is True
            assert outcome.message == "ok"

            send_request.assert_called_once_with(
                "POST",
                "/default-workspace/members",
                json={"account_id": self._ACCOUNT_ID},
            )

    def test_join_default_workspace_invalid_response_format_raises(self):
        """A non-dict response is rejected with a ValueError."""
        with patch(self._SEND_REQUEST_TARGET, return_value="not-a-dict"):
            with pytest.raises(ValueError, match="Invalid response format"):
                EnterpriseService.join_default_workspace(account_id=self._ACCOUNT_ID)

    def test_join_default_workspace_invalid_account_id_raises(self):
        """An account id that is not a UUID is rejected with a ValueError."""
        malformed_id = "not-a-uuid"
        with pytest.raises(ValueError):
            EnterpriseService.join_default_workspace(account_id=malformed_id)
|
||||
|
||||
|
||||
class TestTryJoinDefaultWorkspace:
    """Exercises the soft-fail wrapper `try_join_default_workspace` with config and service patched."""

    _CONFIG_TARGET = "services.enterprise.enterprise_service.dify_config"
    _JOIN_TARGET = "services.enterprise.enterprise_service.EnterpriseService.join_default_workspace"
    _ACCOUNT_ID = "11111111-1111-1111-1111-111111111111"

    def _enterprise_mode(self, enabled):
        """Return a patcher replacing `dify_config` with a mock whose ENTERPRISE_ENABLED is `enabled`."""
        return patch(self._CONFIG_TARGET, ENTERPRISE_ENABLED=enabled)

    def test_try_join_default_workspace_enterprise_disabled_noop(self):
        """With enterprise mode off, the join service is never called."""
        with self._enterprise_mode(False), patch(self._JOIN_TARGET) as join_mock:
            try_join_default_workspace("11111111-1111-1111-1111-111111111111")

            join_mock.assert_not_called()

    def test_try_join_default_workspace_successful_join_does_not_raise(self):
        """A successful join result passes through without raising."""
        joined_result = dict(
            workspace_id="22222222-2222-2222-2222-222222222222",
            joined=True,
            message="ok",
        )

        with self._enterprise_mode(True), patch(self._JOIN_TARGET) as join_mock:
            join_mock.return_value = DefaultWorkspaceJoinResult(**joined_result)

            # Should not raise
            try_join_default_workspace(self._ACCOUNT_ID)

            join_mock.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_skipped_join_does_not_raise(self):
        """A result reporting that no join happened is not treated as an error."""
        skipped_result = dict(
            workspace_id="",
            joined=False,
            message="no default workspace configured",
        )

        with self._enterprise_mode(True), patch(self._JOIN_TARGET) as join_mock:
            join_mock.return_value = DefaultWorkspaceJoinResult(**skipped_result)

            # Should not raise
            try_join_default_workspace(self._ACCOUNT_ID)

            join_mock.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_api_failure_soft_fails(self):
        """An exception raised by the join service is swallowed by the wrapper."""
        failure = Exception("network failure")

        with self._enterprise_mode(True), patch(self._JOIN_TARGET, side_effect=failure) as join_mock:
            # Should not raise
            try_join_default_workspace(self._ACCOUNT_ID)

            join_mock.assert_called_once_with(account_id=self._ACCOUNT_ID)

    def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
        """A malformed account id reaches the real join service, and the wrapper still does not raise."""
        with self._enterprise_mode(True):
            # Should not raise even though UUID parsing fails inside join_default_workspace
            try_join_default_workspace("not-a-uuid")
|
||||
Reference in New Issue
Block a user