Compare commits

...

8 Commits

Author SHA1 Message Date
L1nSn0w
5720017e4d refactor(api): replace async workspace join with synchronous method
Updated the account and registration services to use a synchronous method for joining the default workspace during account creation. This change simplifies the implementation by removing the asynchronous wrapper and related executor, while ensuring that the registration process remains efficient. Adjusted unit tests to reflect the updated method usage.
2026-02-14 17:33:50 +08:00
L1nSn0w
c21c6c3815 feat(api): add shutdown hook for workspace join executor
Implemented a shutdown hook to ensure proper cleanup of the module-level executor used for workspace joining. This includes a best-effort cleanup method that cancels queued tasks and waits for currently running tasks to complete, enhancing the reliability of the service during process termination. Updated error handling to log any issues encountered during shutdown.
2026-02-14 17:33:50 +08:00
L1nSn0w
0e55ef7336 feat(api): implement asynchronous workspace joining for enterprise accounts
Refactored the account and registration services to utilize an asynchronous method for joining the default workspace during account creation. This change enhances performance by allowing the registration process to proceed without waiting for the workspace joining operation to complete. Updated unit tests to cover the new asynchronous behavior and ensure proper logging of any exceptions that occur during the process.
2026-02-14 17:33:50 +08:00
L1nSn0w
eb5b747a06 feat(api): improve timeout handling in BaseRequest class
Updated the BaseRequest class to conditionally include the timeout parameter when making requests with httpx. This change preserves the library's default timeout behavior by only passing the timeout argument when it is explicitly set, enhancing request management and flexibility.
2026-02-14 17:33:50 +08:00
L1nSn0w
e643b83460 feat(api): conditionally join default workspace for enterprise accounts
Updated the account and registration services to conditionally attempt joining the default workspace based on the ENTERPRISE_ENABLED configuration. Enhanced the enterprise service to enforce required fields in the response payload, ensuring robust error handling. Added unit tests to verify the behavior for both enabled and disabled enterprise scenarios.
2026-02-14 17:33:50 +08:00
L1nSn0w
95d1913f2c feat(api): add timeout and error handling options to enterprise request
Enhanced the BaseRequest class to include optional timeout and raise_for_status parameters for improved request handling. Updated the EnterpriseService to utilize these new options during account addition to the default workspace, ensuring better control over request behavior. Additionally, modified unit tests to reflect these changes.
2026-02-14 17:33:50 +08:00
L1nSn0w
0318f2ec71 feat(api): enhance account registration process with improved error handling
Implemented better error handling during account addition to the default workspace for enterprise users, ensuring smoother user registration experience even when workspace joining fails.
2026-02-14 17:33:50 +08:00
L1nSn0w
68e3a1c990 feat(api): implement best-effort account addition to default workspace for enterprise users
Added functionality to attempt adding accounts to the default workspace during account registration and creation processes. This includes a new method in the enterprise service to handle the workspace joining logic, ensuring it does not block user registration on failure.
2026-02-14 17:33:50 +08:00
5 changed files with 364 additions and 2 deletions

View File

@@ -289,6 +289,12 @@ class AccountService:
TenantService.create_owner_tenant_if_not_exist(account=account)
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if getattr(dify_config, "ENTERPRISE_ENABLED", False):
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
return account
@staticmethod
@@ -1407,6 +1413,12 @@ class RegisterService:
tenant_was_created.send(tenant)
db.session.commit()
# Enterprise-only: best-effort add the account to the default workspace (does not switch current workspace).
if getattr(dify_config, "ENTERPRISE_ENABLED", False):
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logger.exception("Register failed")

View File

@@ -39,6 +39,9 @@ class BaseRequest:
endpoint: str,
json: Any | None = None,
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@@ -53,7 +56,16 @@ class BaseRequest:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
# IMPORTANT:
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
if timeout is not None:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
return response.json()

View File

@@ -1,9 +1,16 @@
import logging
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field
from configs import dify_config
from services.enterprise.base import EnterpriseRequest
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
class WebAppSettings(BaseModel):
access_mode: str = Field(
@@ -30,6 +37,52 @@ class WorkspacePermission(BaseModel):
)
class DefaultWorkspaceJoinResult(BaseModel):
"""
Result of ensuring an account is a member of the enterprise default workspace.
- joined=True is idempotent (already a member also returns True)
- joined=False means enterprise default workspace is not configured or invalid/archived
"""
# Only workspace_id can be empty when "no default workspace configured".
workspace_id: str = ""
# These fields are required to avoid silently treating error payloads as "skipped".
joined: bool
message: str
model_config = ConfigDict(extra="forbid")
def try_join_default_workspace(account_id: str) -> None:
"""
Enterprise-only side-effect: ensure account is a member of the default workspace.
This is a best-effort integration. Failures must not block user registration.
"""
if not dify_config.ENTERPRISE_ENABLED:
return
try:
result = EnterpriseService.join_default_workspace(account_id=account_id)
if result.joined:
logger.info(
"Joined enterprise default workspace for account %s (workspace_id=%s)",
account_id,
result.workspace_id,
)
else:
logger.info(
"Skipped joining enterprise default workspace for account %s (message=%s)",
account_id,
result.message,
)
except Exception:
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -39,6 +92,34 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""
Call enterprise inner API to add an account to the default workspace.
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
so the endpoint here is `/default-workspace/members`.
"""
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
try:
uuid.UUID(account_id)
except ValueError as e:
raise ValueError(f"account_id must be a valid UUID: {account_id}") from e
data = EnterpriseRequest.send_request(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
if "joined" not in data or "message" not in data:
raise ValueError("Invalid response payload from enterprise default workspace API")
return DefaultWorkspaceJoinResult.model_validate(data)
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")

View File

@@ -0,0 +1,137 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
"""
from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
)
class TestJoinDefaultWorkspace:
def test_join_default_workspace_success(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "22222222-2222-2222-2222-222222222222", "joined": True, "message": "ok"}
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
result = EnterpriseService.join_default_workspace(account_id=account_id)
assert isinstance(result, DefaultWorkspaceJoinResult)
assert result.workspace_id == response["workspace_id"]
assert result.joined is True
assert result.message == "ok"
mock_send_request.assert_called_once_with(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = "not-a-dict"
with pytest.raises(ValueError, match="Invalid response format"):
EnterpriseService.join_default_workspace(account_id=account_id)
def test_join_default_workspace_invalid_account_id_raises(self):
with pytest.raises(ValueError):
EnterpriseService.join_default_workspace(account_id="not-a-uuid")
def test_join_default_workspace_missing_required_fields_raises(self):
account_id = "11111111-1111-1111-1111-111111111111"
response = {"workspace_id": "", "message": "ok"} # missing "joined"
with patch("services.enterprise.enterprise_service.EnterpriseRequest.send_request") as mock_send_request:
mock_send_request.return_value = response
with pytest.raises(ValueError, match="Invalid response payload"):
EnterpriseService.join_default_workspace(account_id=account_id)
class TestTryJoinDefaultWorkspace:
def test_try_join_default_workspace_enterprise_disabled_noop(self):
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = False
try_join_default_workspace("11111111-1111-1111-1111-111111111111")
mock_join.assert_not_called()
def test_try_join_default_workspace_successful_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="22222222-2222-2222-2222-222222222222",
joined=True,
message="ok",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_skipped_join_does_not_raise(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.return_value = DefaultWorkspaceJoinResult(
workspace_id="",
joined=False,
message="no default workspace configured",
)
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_api_failure_soft_fails(self):
account_id = "11111111-1111-1111-1111-111111111111"
with (
patch("services.enterprise.enterprise_service.dify_config") as mock_config,
patch("services.enterprise.enterprise_service.EnterpriseService.join_default_workspace") as mock_join,
):
mock_config.ENTERPRISE_ENABLED = True
mock_join.side_effect = Exception("network failure")
# Should not raise
try_join_default_workspace(account_id)
mock_join.assert_called_once_with(account_id=account_id)
def test_try_join_default_workspace_invalid_account_id_soft_fails(self):
with patch("services.enterprise.enterprise_service.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = True
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")

View File

@@ -1064,6 +1064,67 @@ class TestRegisterService:
# ==================== Registration Tests ====================
def test_create_account_and_tenant_calls_default_workspace_join_when_enterprise_enabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should be invoked when ENTERPRISE_ENABLED is True."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
result = AccountService.create_account_and_tenant(
email="test@example.com",
name="Test User",
interface_language="en-US",
password=None,
)
assert result == mock_account
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
def test_create_account_and_tenant_does_not_call_default_workspace_join_when_enterprise_disabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.account_service.TenantService.create_owner_tenant_if_not_exist") as mock_create_workspace,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
AccountService.create_account_and_tenant(
email="test@example.com",
name="Test User",
interface_language="en-US",
password=None,
)
mock_create_workspace.assert_called_once_with(account=mock_account)
mock_join_default_workspace.assert_not_called()
def test_register_success(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test successful account registration."""
# Setup mocks
@@ -1115,6 +1176,65 @@ class TestRegisterService:
mock_event.send.assert_called_once_with(mock_tenant)
self._assert_database_operations_called(mock_db_dependencies["db"])
def test_register_calls_default_workspace_join_when_enterprise_enabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should be invoked after successful register commit."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", True, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
result = RegisterService.register(
email="test@example.com",
name="Test User",
password="password123",
language="en-US",
create_workspace_required=False,
)
assert result == mock_account
mock_join_default_workspace.assert_called_once_with(str(mock_account.id))
def test_register_does_not_call_default_workspace_join_when_enterprise_disabled(
self, mock_db_dependencies, mock_external_service_dependencies, monkeypatch
):
"""Enterprise-only side effect should not be invoked when ENTERPRISE_ENABLED is False."""
monkeypatch.setattr(dify_config, "ENTERPRISE_ENABLED", False, raising=False)
mock_external_service_dependencies["feature_service"].get_system_features.return_value.is_allow_register = True
mock_external_service_dependencies["billing_service"].is_email_in_freeze.return_value = False
mock_account = TestAccountAssociatedDataFactory.create_account_mock(
account_id="11111111-1111-1111-1111-111111111111"
)
with (
patch("services.account_service.AccountService.create_account") as mock_create_account,
patch("services.enterprise.enterprise_service.try_join_default_workspace") as mock_join_default_workspace,
):
mock_create_account.return_value = mock_account
RegisterService.register(
email="test@example.com",
name="Test User",
password="password123",
language="en-US",
create_workspace_required=False,
)
mock_join_default_workspace.assert_not_called()
def test_register_with_oauth(self, mock_db_dependencies, mock_external_service_dependencies):
"""Test account registration with OAuth integration."""
# Setup mocks