refactor(api): replace heartbeat mechanism with AutoRenewRedisLock for database migration

- Removed the manual heartbeat function for renewing the Redis lock during database migrations.
- Integrated AutoRenewRedisLock to handle lock renewal automatically, simplifying the upgrade_db command.
- Updated unit tests to reflect changes in lock handling and error management during migrations.

(cherry picked from commit 8814256eb5fa20b29e554264f3b659b027bc4c9a)
This commit is contained in:
L1nSn0w
2026-02-14 12:03:58 +08:00
parent 8d4bd5636b
commit 94603b5408
5 changed files with 376 additions and 63 deletions

View File

@@ -0,0 +1,39 @@
"""
Integration tests for AutoRenewRedisLock using real Redis via TestContainers.
"""
import time
import uuid
import pytest
from extensions.ext_redis import redis_client
from libs.auto_renew_redis_lock import AutoRenewRedisLock
@pytest.mark.usefixtures("flask_app_with_containers")
def test_auto_renew_redis_lock_renews_ttl_and_releases():
lock_name = f"test:auto_renew_lock:{uuid.uuid4().hex}"
# Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
lock = AutoRenewRedisLock(
redis_client=redis_client,
name=lock_name,
ttl_seconds=1.0,
renew_interval_seconds=0.2,
log_context="test_auto_renew_redis_lock",
)
acquired = lock.acquire(blocking=True, blocking_timeout=5)
assert acquired is True
# Wait beyond the base TTL; key should still exist due to renewal.
time.sleep(1.5)
ttl = redis_client.ttl(lock_name)
assert ttl > 0
lock.release_safely(status="successful")
# After release, the key should not exist.
assert redis_client.exists(lock_name) == 0

View File

@@ -4,8 +4,9 @@ import types
from unittest.mock import MagicMock
import commands
from libs.auto_renew_redis_lock import LockNotOwnedError, RedisError
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 1.0
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
@@ -45,7 +46,7 @@ def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
def _upgrade():
@@ -69,7 +70,7 @@ def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsy
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = commands.LockNotOwnedError("simulated")
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
_install_fake_flask_migrate(monkeypatch, lambda: None)
@@ -129,7 +130,7 @@ def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
def _reacquire():
attempted.set()
raise commands.RedisError("simulated")
raise RedisError("simulated")
lock.reacquire.side_effect = _reacquire

View File

@@ -0,0 +1,125 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
result = EnterpriseService.join_default_workspace(account_id=account_id)
assert isinstance(result, DefaultWorkspaceJoinResult)
assert result.workspace_id == response["workspace_id"]
assert result.joined is True
assert result.message == "ok"
mock_send_request.assert_called_once_with(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
)
def test_join_default_workspace_invalid_response_format_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = "not-a-dict"
with pytest.raises(ValueError, match="Invalid response format"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_invalid_account_id_raises(self):
with pytest.raises(ValueError):
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
class TestTryJoinDefaultWorkspace:
def test_try_join_default_workspace_enterprise_disabled_noop(self):
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = False
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
mock_join.assert_not_called()
def test_try_join_default_workspace_successful_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="22222222-2222-2222-2222-222222222222",
joined=True,
message="ok",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="",
joined=False,
message="no default workspace configured",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_api_failure_soft_fails(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.side_effect = Exception("network failure")
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")