Compare commits

..

1 Commits
main ... 34028

Author SHA1 Message Date
Asuka Minato
a84ad7cf0f https://github.com/langgenius/dify/issues/34028#issuecomment-4125972296 2026-03-28 08:24:15 +09:00
32 changed files with 2616 additions and 406 deletions

View File

@@ -50,6 +50,17 @@ jobs:
run: |
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
- name: Check if line counts match
id: line_count_check
run: |
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
if [ "$base_lines" -eq "$pr_lines" ]; then
echo "same=true" >> $GITHUB_OUTPUT
else
echo "same=false" >> $GITHUB_OUTPUT
fi
- name: Save PR number
run: |
echo ${{ github.event.pull_request.number }} > pr_number.txt
@@ -63,7 +74,7 @@ jobs:
pr_number.txt
- name: Comment PR with pyrefly diff
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
with:
github-token: ${{ secrets.GITHUB_TOKEN }}

View File

@@ -49,7 +49,7 @@ jobs:
- name: Run Type Checks
if: steps.changed-files.outputs.any_changed == 'true'
run: make type-check-core
run: make type-check
- name: Dotenv check
if: steps.changed-files.outputs.any_changed == 'true'

View File

@@ -22,8 +22,8 @@ jobs:
strategy:
fail-fast: false
matrix:
shardIndex: [1, 2, 3, 4]
shardTotal: [4]
shardIndex: [1, 2, 3, 4, 5, 6]
shardTotal: [6]
defaults:
run:
shell: bash
@@ -66,6 +66,7 @@ jobs:
- name: Checkout code
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
with:
fetch-depth: 0
persist-credentials: false
- name: Setup web environment

View File

@@ -74,12 +74,6 @@ type-check:
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Type checks complete"
type-check-core:
@echo "📝 Running core type checks (basedpyright + mypy)..."
@./dev/basedpyright-check $(PATH_TO_CHECK)
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
@echo "✅ Core type checks complete"
test:
@echo "🧪 Running backend unit tests..."
@if [ -n "$(TARGET_TESTS)" ]; then \
@@ -139,7 +133,6 @@ help:
@echo " make check - Check code with ruff"
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
@echo ""
@echo "Docker Build Targets:"

View File

@@ -127,8 +127,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
ALIYUN_OSS_REGION=your-region
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Google Storage configuration
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name

View File

@@ -8,7 +8,6 @@ Go admin-api caller.
from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field
from sqlalchemy import select
from sqlalchemy.orm import Session
from controllers.common.schema import register_schema_model
@@ -88,7 +87,7 @@ class EnterpriseAppDSLExport(Resource):
"""Export an app's DSL as YAML."""
include_secret = request.args.get("include_secret", "false").lower() == "true"
app_model = db.session.get(App, app_id)
app_model = db.session.query(App).filter_by(id=app_id).first()
if not app_model:
return {"message": "app not found"}, 404
@@ -105,7 +104,7 @@ def _get_active_account(email: str) -> Account | None:
Workspace membership is already validated by the Go admin-api caller.
"""
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
account = db.session.query(Account).filter_by(email=email).first()
if account is None or account.status != AccountStatus.ACTIVE:
return None
return account

View File

@@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
from graphon.model_runtime.entities.model_entities import ModelFeature
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from sqlalchemy import func, select
from sqlalchemy import select
from core.agent.entities import AgentEntity, AgentToolEntity
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
@@ -104,14 +104,11 @@ class BaseAgentRunner(AppRunner):
)
# get how many agent thoughts have been created
self.agent_thought_count = (
db.session.scalar(
select(func.count())
.select_from(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
)
db.session.query(MessageAgentThought)
.where(
MessageAgentThought.message_id == self.message.id,
)
or 0
.count()
)
db.session.close()

View File

@@ -1,7 +1,7 @@
import logging
from collections.abc import Sequence
from sqlalchemy import select, update
from sqlalchemy import select
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.app_invoke_entities import InvokeFrom
@@ -70,21 +70,23 @@ class DatasetIndexToolCallbackHandler:
)
child_chunk = db.session.scalar(child_chunk_stmt)
if child_chunk:
db.session.execute(
update(DocumentSegment)
_ = (
db.session.query(DocumentSegment)
.where(DocumentSegment.id == child_chunk.segment_id)
.values(hit_count=DocumentSegment.hit_count + 1)
.update(
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
)
)
else:
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
query = db.session.query(DocumentSegment).where(
DocumentSegment.index_node_id == document.metadata["doc_id"]
)
if "dataset_id" in document.metadata:
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
# add hit count to document segment
db.session.execute(
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
)
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
db.session.commit()

View File

@@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
from extensions.ext_database import db
from models.account import Tenant
if not (tenant := db.session.get(Tenant, tenant_id)):
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
raise ValueError(f"Tenant with id {tenant_id} not found")
assert tenant.encrypt_public_key is not None
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)

View File

@@ -10,7 +10,6 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from graphon.model_runtime.entities.model_entities import ModelType
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from sqlalchemy import select
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
@@ -411,8 +410,8 @@ class LLMGenerator:
model_config: ModelConfig,
ideal_output: str | None,
):
last_run: Message | None = db.session.scalar(
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
if not last_run:
return LLMGenerator.__instruction_modify_common(

View File

@@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
get app
"""
try:
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
except Exception:
raise ValueError("app not found")

View File

@@ -1,4 +1,4 @@
from sqlalchemy import delete, select
from sqlalchemy import select
from core.tools.__base.tool_provider import ToolProviderController
from core.tools.builtin_tool.provider import BuiltinToolProviderController
@@ -31,7 +31,7 @@ class ToolLabelManager:
raise ValueError("Unsupported tool type")
# delete old labels
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
# insert new labels
for label in labels:

View File

@@ -255,11 +255,11 @@ class ToolManager:
if builtin_provider is None:
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
else:
builtin_provider = db.session.scalar(
select(BuiltinToolProvider)
builtin_provider = (
db.session.query(BuiltinToolProvider)
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
.limit(1)
.first()
)
if builtin_provider is None:
@@ -818,13 +818,13 @@ class ToolManager:
:return: the provider controller, the credentials
"""
provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.id == provider_id,
ApiToolProvider.tenant_id == tenant_id,
)
.limit(1)
.first()
)
if provider is None:
@@ -872,13 +872,13 @@ class ToolManager:
get api provider
"""
provider_name = provider
provider_obj: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
provider_obj: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.where(
ApiToolProvider.tenant_id == tenant_id,
ApiToolProvider.name == provider,
)
.limit(1)
.first()
)
if provider_obj is None:
@@ -964,10 +964,10 @@ class ToolManager:
@classmethod
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
select(WorkflowToolProvider)
workflow_provider: WorkflowToolProvider | None = (
db.session.query(WorkflowToolProvider)
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
.limit(1)
.first()
)
if workflow_provider is None:
@@ -981,10 +981,10 @@ class ToolManager:
@classmethod
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
try:
api_provider: ApiToolProvider | None = db.session.scalar(
select(ApiToolProvider)
api_provider: ApiToolProvider | None = (
db.session.query(ApiToolProvider)
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
.limit(1)
.first()
)
if api_provider is None:

View File

@@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
context_list: list[RetrievalSourceMetadata] = []
resource_number = 1
for segment in sorted_segments:
dataset = db.session.get(Dataset, segment.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
document_stmt = select(Document).where(
Document.id == segment.document_id,
Document.enabled == True,

View File

@@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
if self.return_resource:
for record in records:
segment = record.segment
dataset = db.session.get(Dataset, segment.dataset_id)
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
dataset_document_stmt = select(DatasetDocument).where(
DatasetDocument.id == segment.document_id,
DatasetDocument.enabled == True,

View File

@@ -35,13 +35,15 @@ class ApiKeyAuthService:
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = db.session.scalar(
select(DataSourceApiKeyAuthBinding).where(
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.category == category,
DataSourceApiKeyAuthBinding.provider == provider,
DataSourceApiKeyAuthBinding.disabled.is_(False),
)
.first()
)
if not data_source_api_key_bindings:
return None
@@ -52,11 +54,10 @@ class ApiKeyAuthService:
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = db.session.scalar(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
DataSourceApiKeyAuthBinding.id == binding_id,
)
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
.first()
)
if data_source_api_key_binding:
db.session.delete(data_source_api_key_binding)

View File

@@ -1,6 +1,4 @@
from __future__ import annotations
from unittest.mock import MagicMock, patch
from unittest.mock import patch
import pytest
from faker import Faker
@@ -536,283 +534,3 @@ class TestWorkspaceService:
# Verify database state
db_session_with_containers.refresh(tenant)
assert tenant.id is not None
def test_get_tenant_info_should_raise_assertion_when_join_missing(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""TenantAccountJoin must exist; missing join should raise AssertionError."""
fake = Faker()
account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active")
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal", plan="basic")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
# No TenantAccountJoin created
with patch("services.workspace_service.current_user", account):
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
WorkspaceService.get_tenant_info(tenant)
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
import json
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
tenant.custom_config = json.dumps({})
db_session_with_containers.commit()
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
with patch("services.workspace_service.current_user", account):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["custom_config"]["replace_webapp_logo"] is None
def test_get_tenant_info_should_use_files_url_for_logo_url(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""The logo URL should use dify_config.FILES_URL as the base."""
import json
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
tenant.custom_config = json.dumps({"replace_webapp_logo": True})
db_session_with_containers.commit()
custom_base = "https://cdn.mycompany.io"
mock_external_service_dependencies["dify_config"].FILES_URL = custom_base
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
with patch("services.workspace_service.current_user", account):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED"
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
with patch("services.workspace_service.current_user", account):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert "next_credit_reset_date" not in result
assert "trial_credits" not in result
assert "trial_credits_used" not in result
def test_get_tenant_info_cloud_credit_reset_date(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""next_credit_reset_date should be present in CLOUD edition."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["next_credit_reset_date"] == "2025-02-01"
def test_get_tenant_info_cloud_paid_pool_not_full(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""trial_credits come from paid pool when plan is not sandbox and pool is not full."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
paid_pool = MagicMock(quota_limit=1000, quota_used=200)
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["trial_credits"] == 1000
assert result["trial_credits_used"] == 200
def test_get_tenant_info_cloud_paid_pool_unlimited(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""quota_limit == -1 means unlimited; service should use paid pool."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
paid_pool = MagicMock(quota_limit=-1, quota_used=999)
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["trial_credits"] == -1
assert result["trial_credits_used"] == 999
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""When paid pool is exhausted, switch to trial pool."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
paid_pool = MagicMock(quota_limit=500, quota_used=500)
trial_pool = MagicMock(quota_limit=100, quota_used=10)
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["trial_credits"] == 100
assert result["trial_credits_used"] == 10
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""When paid_pool is None, fall back to trial pool."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
trial_pool = MagicMock(quota_limit=50, quota_used=5)
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["trial_credits"] == 50
assert result["trial_credits_used"] == 5
def test_get_tenant_info_cloud_sandbox_uses_trial_pool(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""When plan is SANDBOX, skip paid pool and use trial pool."""
from enums.cloud_plan import CloudPlan
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = CloudPlan.SANDBOX
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
paid_pool = MagicMock(quota_limit=1000, quota_used=0)
trial_pool = MagicMock(quota_limit=200, quota_used=20)
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert result["trial_credits"] == 200
assert result["trial_credits_used"] == 20
def test_get_tenant_info_cloud_both_pools_none(
self, db_session_with_containers: Session, mock_external_service_dependencies
):
"""When both paid and trial pools are absent, trial_credits should not be set."""
fake = Faker()
account, tenant = self._create_test_account_and_tenant(
db_session_with_containers, mock_external_service_dependencies
)
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
feature.can_replace_logo = False
feature.next_credit_reset_date = "2025-02-01"
feature.billing.subscription.plan = "professional"
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
with (
patch("services.workspace_service.current_user", account),
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]),
):
result = WorkspaceService.get_tenant_info(tenant)
assert result is not None
assert "trial_credits" not in result
assert "trial_credits_used" not in result

View File

@@ -64,18 +64,18 @@ class TestGetActiveAccount:
def test_returns_active_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "active"
mock_db.session.scalar.return_value = mock_account
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
result = _get_active_account("user@example.com")
assert result is mock_account
mock_db.session.scalar.assert_called_once()
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_inactive_account(self, mock_db):
mock_account = MagicMock()
mock_account.status = "banned"
mock_db.session.scalar.return_value = mock_account
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
result = _get_active_account("banned@example.com")
@@ -83,7 +83,7 @@ class TestGetActiveAccount:
@patch("controllers.inner_api.app.dsl.db")
def test_returns_none_for_nonexistent_email(self, mock_db):
mock_db.session.scalar.return_value = None
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
result = _get_active_account("missing@example.com")
@@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.get.return_value = mock_app
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
unwrapped = inspect.unwrap(api_instance.get)
@@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
mock_app = MagicMock()
mock_db.session.get.return_value = mock_app
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
mock_dsl_cls.export_dsl.return_value = "yaml-data"
unwrapped = inspect.unwrap(api_instance.get)
@@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
@patch("controllers.inner_api.app.dsl.db")
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
mock_db.session.get.return_value = None
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
unwrapped = inspect.unwrap(api_instance.get)
with app.test_request_context("?include_secret=false"):

View File

@@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
class TestBaseAgentRunnerInit:
def test_init_sets_stream_tool_call_and_files(self, mocker):
session = mocker.MagicMock()
session.scalar.return_value = 2
session.query.return_value.where.return_value.count.return_value = 2
mocker.patch.object(module.db, "session", session)
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])

View File

@@ -114,9 +114,13 @@ class TestOnToolEnd:
document = mocker.Mock()
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_db.session.execute.assert_called_once()
mock_query.update.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
@@ -134,9 +138,13 @@ class TestOnToolEnd:
"dataset_id": "dataset-1",
}
mock_query = mocker.Mock()
mock_db.session.query.return_value = mock_query
mock_query.where.return_value = mock_query
handler.on_tool_end([document])
mock_db.session.execute.assert_called_once()
mock_query.update.assert_called_once()
mock_db.session.commit.assert_called_once()
def test_on_tool_end_empty_documents(self, handler):

View File

@@ -38,13 +38,13 @@ class TestObfuscatedToken:
class TestEncryptToken:
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_successful_encryption(self, mock_encrypt, mock_query):
"""Test successful token encryption"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_data"
result = encrypt_token("tenant-123", "test_token")
@@ -52,10 +52,10 @@ class TestEncryptToken:
assert result == base64.b64encode(b"encrypted_data").decode()
mock_encrypt.assert_called_with("test_token", "mock_public_key")
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
def test_tenant_not_found(self, mock_query):
"""Test error when tenant doesn't exist"""
mock_query.return_value = None
mock_query.return_value.where.return_value.first.return_value = None
with pytest.raises(ValueError) as exc_info:
encrypt_token("invalid-tenant", "test_token")
@@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
class TestEncryptDecryptIntegration:
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
@patch("libs.rsa.decrypt")
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
@@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# Setup mock encryption/decryption
original_token = "test_token_123"
@@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
class TestSecurity:
"""Critical security tests for encryption system"""
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
"""Ensure tokens encrypted for one tenant cannot be used by another"""
# Setup mock tenant
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "tenant1_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_for_tenant1"
# Encrypt token for tenant1
@@ -181,12 +181,12 @@ class TestSecurity:
with pytest.raises(Exception, match="Decryption error"):
decrypt_token("tenant-123", tampered)
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_encryption_randomness(self, mock_encrypt, mock_query):
"""Ensure same plaintext produces different ciphertext"""
mock_tenant = MagicMock(encrypt_public_key="key")
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# Different outputs for same input
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
@@ -205,13 +205,13 @@ class TestEdgeCases:
# Test empty string (which is a valid str type)
assert obfuscated_token("") == ""
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
"""Test encryption of empty token"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_empty"
result = encrypt_token("tenant-123", "")
@@ -219,13 +219,13 @@ class TestEdgeCases:
assert result == base64.b64encode(b"encrypted_empty").decode()
mock_encrypt.assert_called_with("", "mock_public_key")
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
"""Test tokens containing special/unicode characters"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
mock_encrypt.return_value = b"encrypted_special"
# Test various special characters
@@ -242,13 +242,13 @@ class TestEdgeCases:
assert result == base64.b64encode(b"encrypted_special").decode()
mock_encrypt.assert_called_with(token, "mock_public_key")
@patch("extensions.ext_database.db.session.get")
@patch("models.engine.db.session.query")
@patch("libs.rsa.encrypt")
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
"""Test behavior when token exceeds RSA encryption limits"""
mock_tenant = MagicMock()
mock_tenant.encrypt_public_key = "mock_public_key"
mock_query.return_value = mock_tenant
mock_query.return_value.where.return_value.first.return_value = mock_tenant
# RSA 2048-bit can only encrypt ~245 bytes
# The actual limit depends on padding scheme

View File

@@ -314,8 +314,8 @@ class TestLLMGenerator:
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
# Mock __instruction_modify_common call via invoke_llm
mock_response = MagicMock()
@@ -328,12 +328,12 @@ class TestLLMGenerator:
assert result == {"modified": "prompt"}
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
with patch("extensions.ext_database.db.session.query") as mock_query:
last_run = MagicMock()
last_run.query = "q"
last_run.answer = "a"
last_run.error = "e"
mock_scalar.return_value = last_run
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
@@ -483,8 +483,8 @@ class TestLLMGenerator:
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
# Testing placeholders replacement via instruction_modify_legacy for convenience
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = '{"ok": true}'
@@ -504,8 +504,8 @@ class TestLLMGenerator:
assert "current_val" in user_msg_dict["instruction"]
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No braces here"
mock_model_instance.invoke_llm.return_value = mock_response
@@ -516,8 +516,8 @@ class TestLLMGenerator:
assert "Could not find a valid JSON object" in result["error"]
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
mock_model_instance.invoke_llm.return_value = mock_response
@@ -556,8 +556,8 @@ class TestLLMGenerator:
)
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
result = LLMGenerator.instruction_modify_legacy(
@@ -566,8 +566,8 @@ class TestLLMGenerator:
assert "Failed to generate code" in result["error"]
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
result = LLMGenerator.instruction_modify_legacy(
@@ -576,8 +576,8 @@ class TestLLMGenerator:
assert "An unexpected error occurred" in result["error"]
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
mock_scalar.return_value = None
with patch("extensions.ext_database.db.session.query") as mock_query:
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
mock_response = MagicMock()
mock_response.message.get_text_content.return_value = "No JSON here"

View File

@@ -332,21 +332,27 @@ class TestPluginAppBackwardsInvocation:
PluginAppBackwardsInvocation._get_user("uid")
def test_get_app_returns_app(self, mocker):
query_chain = MagicMock()
query_chain.where.return_value = query_chain
app_obj = MagicMock(id="app")
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
query_chain.first.return_value = app_obj
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
def test_get_app_raises_when_missing(self, mocker):
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
query_chain = MagicMock()
query_chain.where.return_value = query_chain
query_chain.first.return_value = None
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
with pytest.raises(ValueError, match="app not found"):
PluginAppBackwardsInvocation._get_app("app", "tenant")
def test_get_app_raises_when_query_fails(self, mocker):
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
mocker.patch("core.plugin.backwards_invocation.app.db", db)
with pytest.raises(ValueError, match="app not found"):

View File

@@ -38,9 +38,11 @@ def test_tool_label_manager_filter_tool_labels():
def test_tool_label_manager_update_tool_labels_db():
controller = _api_controller("api-1")
with patch("core.tools.tool_label_manager.db") as mock_db:
delete_query = mock_db.session.query.return_value.where.return_value
delete_query.delete.return_value = None
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
mock_db.session.execute.assert_called_once()
delete_query.delete.assert_called_once()
# only one valid unique label should be inserted.
assert mock_db.session.add.call_count == 1
mock_db.session.commit.assert_called_once()

View File

@@ -220,7 +220,9 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = builtin_provider
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
builtin_provider
)
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key": "secret"}
cache = Mock()
@@ -272,7 +274,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
)
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
mock_db.session.scalar.return_value = builtin_provider
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
encrypter = Mock()
encrypter.decrypt.return_value = {"token": "old"}
encrypter.encrypt.return_value = {"token": "encrypted"}
@@ -696,10 +698,12 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = provider
mock_db.session.query.return_value = db_query
with patch(
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
) as mock_from_db:
@@ -726,10 +730,12 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
privacy_policy="privacy",
custom_disclaimer="disclaimer",
)
db_query = Mock()
db_query.where.return_value.first.return_value = provider
controller = Mock()
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = provider
mock_db.session.query.return_value = db_query
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
encrypter = Mock()
encrypter.decrypt.return_value = {"api_key_value": "secret"}
@@ -744,7 +750,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
def test_get_api_provider_controller_not_found_raises():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = None
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
ToolManager.get_api_provider_controller("tenant-1", "missing")
@@ -803,14 +809,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
with patch("core.tools.tool_manager.db") as mock_db:
mock_db.session.scalar.return_value = None
mock_db.session.query.return_value.where.return_value.first.return_value = None
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"

View File

@@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
)
db_session = Mock()
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
db_session.get.return_value = dataset
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
tool = SingleDatasetRetrieverTool(
tenant_id="tenant-1",
@@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
)
db_session = Mock()
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
db_session.get.side_effect = [
db_session.query.return_value.filter_by.return_value.first.side_effect = [
SimpleNamespace(id="dataset-2", name="Dataset Two"),
SimpleNamespace(id="dataset-1", name="Dataset One"),
]

View File

@@ -0,0 +1,558 @@
from __future__ import annotations
from dataclasses import dataclass
from datetime import UTC, datetime
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
from models.dataset import Dataset
from services.entities.knowledge_entities.knowledge_entities import (
DocumentMetadataOperation,
MetadataArgs,
MetadataDetail,
MetadataOperationData,
)
from services.metadata_service import MetadataService
@dataclass
class _DocumentStub:
    """Lightweight stand-in for a Document row, carrying only the attributes these tests read."""

    id: str  # document primary key
    name: str  # display name (fixtures use "doc-<id>")
    uploader: str  # uploader identifier (an email in these fixtures)
    upload_date: datetime  # when the document was uploaded
    last_update_date: datetime  # last modification timestamp
    data_source_type: str  # e.g. "upload_file"
    doc_metadata: dict[str, object] | None  # custom/built-in metadata; None when unset
@pytest.fixture
def mock_db(mocker: MockerFixture) -> MagicMock:
    """Patch the module-level ``db`` used by MetadataService and return the mock."""
    mocked_db = mocker.patch("services.metadata_service.db")
    # Give each test a fresh session mock so query chains / commit asserts are isolated.
    mocked_db.session = MagicMock()
    return mocked_db
@pytest.fixture
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
    """Patch the redis client MetadataService uses for metadata-operation locks."""
    return mocker.patch("services.metadata_service.redis_client")
@pytest.fixture
def mock_current_account(mocker: MockerFixture) -> MagicMock:
    """Patch ``current_account_with_tenant`` to return a fixed (user, tenant-id) pair."""
    # Only ``id`` is read from the user object by the code under test.
    mock_user = SimpleNamespace(id="user-1")
    return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
    """Build a ``_DocumentStub`` with a fixed timestamp and uploader for deterministic tests."""
    fixed_timestamp = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
    stub = _DocumentStub(
        id=document_id,
        name="doc-" + document_id,
        uploader="qa@example.com",
        upload_date=fixed_timestamp,
        last_update_date=fixed_timestamp,
        data_source_type="upload_file",
        doc_metadata=doc_metadata,
    )
    return stub
def _dataset(**kwargs: Any) -> Dataset:
    """Return a SimpleNamespace masquerading as a ``Dataset`` (attribute access only)."""
    namespace = SimpleNamespace(**kwargs)
    return cast(Dataset, namespace)
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
    """A metadata name longer than 255 characters must be rejected."""
    overlong_args = MetadataArgs(type="string", name="x" * 256)

    with pytest.raises(ValueError, match="cannot exceed 255"):
        MetadataService.create_metadata("dataset-1", overlong_args)
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
    mock_db: MagicMock,
    mock_current_account: MagicMock,
) -> None:
    """Creating metadata whose name is already taken in the dataset must fail."""
    # Arrange
    metadata_args = MetadataArgs(type="string", name="priority")
    # A truthy first() result simulates an existing row with the same name.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    # Act + Assert
    with pytest.raises(ValueError, match="already exists"):
        MetadataService.create_metadata("dataset-1", metadata_args)
    # Assert
    mock_current_account.assert_called_once()
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """A name matching a built-in field is rejected even when no duplicate row exists."""
    # Arrange
    metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
    # No duplicate row, so only the built-in-name check can trigger the error.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Built-in fields"):
        MetadataService.create_metadata("dataset-1", metadata_args)
def test_create_metadata_should_persist_metadata_when_input_is_valid(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """Valid input creates a row bound to the current tenant/user and commits it."""
    # Arrange
    metadata_args = MetadataArgs(type="number", name="score")
    # No existing row with this name, so the duplicate check passes.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act
    result = MetadataService.create_metadata("dataset-1", metadata_args)
    # Assert
    assert result.tenant_id == "tenant-1"
    assert result.dataset_id == "dataset-1"
    assert result.type == "number"
    assert result.name == "score"
    assert result.created_by == "user-1"
    mock_db.session.add.assert_called_once_with(result)
    mock_db.session.commit.assert_called_once()
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
    """Renaming metadata to a name longer than 255 characters must be rejected."""
    overlong_name = "x" * 256

    with pytest.raises(ValueError, match="cannot exceed 255"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", overlong_name)
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
    mock_db: MagicMock, mock_current_account: MagicMock
) -> None:
    """Renaming metadata to a name already used in the dataset must fail."""
    # Arrange: a truthy first() result means the target name already exists.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    # Act + Assert
    with pytest.raises(ValueError, match="already exists"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
    # Assert
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
    mock_db: MagicMock,
    mock_current_account: MagicMock,
) -> None:
    """Renaming metadata to a built-in field name is rejected."""
    # Arrange: no duplicate row, so only the built-in-name check can fire.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    # Act + Assert
    with pytest.raises(ValueError, match="Built-in fields"):
        MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
    # Assert
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Renaming propagates the new key into every bound document and stamps the editor."""
    # Arrange
    mock_redis_client.get.return_value = None  # no metadata lock currently held
    fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
    mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
    metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
    bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
    query_duplicate = MagicMock()
    query_duplicate.filter_by.return_value.first.return_value = None
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = metadata
    query_bindings = MagicMock()
    query_bindings.filter_by.return_value.all.return_value = bindings
    # side_effect order must mirror the service's three query() calls:
    # duplicate check, metadata lookup, then binding lookup.
    mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
    doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
    doc_2 = _build_document("2", None)
    mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
    mock_get_documents.return_value = [doc_1, doc_2]
    # Act
    result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
    # Assert
    assert result is metadata
    assert metadata.name == "new_name"
    assert metadata.updated_by == "user-1"
    assert metadata.updated_at == fixed_now
    # The old key is renamed in place; documents without metadata get the key with a None value.
    assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
    assert doc_2.doc_metadata == {"new_name": None}
    mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
    mock_current_account.assert_called_once()
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A missing metadata row is logged, returns None, and still releases the lock."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    mock_logger = mocker.patch("services.metadata_service.logger")
    query_duplicate = MagicMock()
    query_duplicate.filter_by.return_value.first.return_value = None
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = None
    # side_effect order mirrors the service's query() calls: duplicate check, then metadata lookup.
    mock_db.session.query.side_effect = [query_duplicate, query_metadata]
    # Act
    result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
    # Assert
    assert result is None
    mock_logger.exception.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
    mock_current_account.assert_called_once()
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Deleting metadata removes the row and strips its key from bound documents."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    metadata = SimpleNamespace(id="metadata-1", name="obsolete")
    bindings = [SimpleNamespace(document_id="doc-1")]
    query_metadata = MagicMock()
    query_metadata.filter_by.return_value.first.return_value = metadata
    query_bindings = MagicMock()
    query_bindings.filter_by.return_value.all.return_value = bindings
    # side_effect order mirrors the service's query() calls: metadata lookup, then binding lookup.
    mock_db.session.query.side_effect = [query_metadata, query_bindings]
    document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
    mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
    # Act
    result = MetadataService.delete_metadata("dataset-1", "metadata-1")
    # Assert
    assert result is metadata
    assert document.doc_metadata == {"remaining": "value"}
    mock_db.session.delete.assert_called_once_with(metadata)
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_delete_metadata_should_return_none_when_metadata_is_missing(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Deleting a nonexistent metadata row logs, returns None, and releases the lock."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
    mock_logger = mocker.patch("services.metadata_service.logger")
    # Act
    result = MetadataService.delete_metadata("dataset-1", "missing-id")
    # Assert
    assert result is None
    mock_logger.exception.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
    """get_built_in_fields exposes every built-in field with its declared type."""
    expected_names = {
        BuiltInField.document_name,
        BuiltInField.uploader,
        BuiltInField.upload_date,
        BuiltInField.last_update_date,
        BuiltInField.source,
    }

    fields = MetadataService.get_built_in_fields()

    actual_names = {field["name"] for field in fields}
    actual_types = [field["type"] for field in fields]
    assert actual_names == expected_names
    assert actual_types == ["string", "string", "time", "time", "string"]
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Enabling is a no-op when the dataset already has built-in fields enabled."""
    # Arrange
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
    # Act
    MetadataService.enable_built_in_field(dataset)
    # Assert: neither document loading nor a commit happened.
    get_docs.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Enabling writes built-in fields onto each working document and flips the flag."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    doc_1 = _build_document("1", {"custom": "value"})
    doc_2 = _build_document("2", None)  # covers documents with no metadata yet
    mocker.patch(
        "services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
        return_value=[doc_1, doc_2],
    )
    # Act
    MetadataService.enable_built_in_field(dataset)
    # Assert
    assert dataset.built_in_field_enabled is True
    assert doc_1.doc_metadata is not None
    assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
    assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
    assert doc_2.doc_metadata is not None
    assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
    mock_db: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Disabling is a no-op when built-in fields are already disabled."""
    # Arrange
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
    # Act
    MetadataService.disable_built_in_field(dataset)
    # Assert: neither document loading nor a commit happened.
    get_docs.assert_not_called()
    mock_db.session.commit.assert_not_called()
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """Disabling strips every built-in key from documents while keeping custom keys."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    # Document carrying all five built-in keys plus one custom key that must survive.
    document = _build_document(
        "1",
        {
            BuiltInField.document_name: "doc",
            BuiltInField.uploader: "user",
            BuiltInField.upload_date: 1.0,
            BuiltInField.last_update_date: 2.0,
            BuiltInField.source: MetadataDataSource.upload_file,
            "custom": "keep",
        },
    )
    mocker.patch(
        "services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
        return_value=[document],
    )
    # Act
    MetadataService.disable_built_in_field(dataset)
    # Assert
    assert dataset.built_in_field_enabled is False
    assert document.doc_metadata == {"custom": "keep"}
    mock_db.session.commit.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A full (non-partial) update replaces doc_metadata wholesale and rebuilds bindings."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    document = _build_document("1", {"legacy": "value"})
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
    # Old bindings are deleted before new ones are created.
    delete_chain = mock_db.session.query.return_value.filter_by.return_value
    delete_chain.delete.return_value = 1
    operation = DocumentMetadataOperation(
        document_id="1",
        metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
        partial_update=False,
    )
    metadata_args = MetadataOperationData(operation_data=[operation])
    # Act
    MetadataService.update_documents_metadata(dataset, metadata_args)
    # Assert: pre-existing "legacy" key is gone; only the new metadata remains.
    assert document.doc_metadata == {"priority": "high"}
    delete_chain.delete.assert_called_once()
    assert mock_db.session.commit.call_count == 1
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
    mock_current_account.assert_called_once()
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mock_current_account: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A partial update merges new keys, keeps existing ones, and skips duplicate bindings."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
    document = _build_document("1", {"existing": "value"})
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
    # A truthy first() means a binding already exists, so no new binding row is added for it.
    mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
    operation = DocumentMetadataOperation(
        document_id="1",
        metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
        partial_update=True,
    )
    metadata_args = MetadataOperationData(operation_data=[operation])
    # Act
    MetadataService.update_documents_metadata(dataset, metadata_args)
    # Assert
    assert document.doc_metadata is not None
    assert document.doc_metadata["existing"] == "value"
    assert document.doc_metadata["new_key"] == "new_value"
    # built_in_field_enabled=True means built-in fields are (re)applied as well.
    assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
    assert mock_db.session.commit.call_count == 1
    assert mock_db.session.add.call_count == 1
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
    mock_current_account.assert_called_once()
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
    mock_db: MagicMock,
    mock_redis_client: MagicMock,
    mocker: MockerFixture,
) -> None:
    """A missing document aborts the update, rolls back, and still releases the lock."""
    # Arrange
    mock_redis_client.get.return_value = None  # no lock currently held
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
    mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
    operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
    metadata_args = MetadataOperationData(operation_data=[operation])
    # Act + Assert
    with pytest.raises(ValueError, match="Document not found"):
        MetadataService.update_documents_metadata(dataset, metadata_args)
    # Assert
    mock_db.session.rollback.assert_called_once()
    mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
@pytest.mark.parametrize(
    ("dataset_id", "document_id", "expected_key"),
    [
        ("dataset-1", None, "dataset_metadata_lock_dataset-1"),
        (None, "doc-1", "document_metadata_lock_doc-1"),
    ],
)
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
    dataset_id: str | None,
    document_id: str | None,
    expected_key: str,
    mock_redis_client: MagicMock,
) -> None:
    """When no lock is held, the check acquires one keyed by dataset or document id."""
    # Arrange
    mock_redis_client.get.return_value = None  # lock key is currently unset
    # Act
    MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
    # Assert: the lock is set with a one-hour expiry.
    mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
    mock_redis_client: MagicMock,
) -> None:
    """A held dataset-level lock makes the check raise instead of proceeding."""
    # Arrange
    mock_redis_client.get.return_value = 1  # truthy → lock is held
    # Act + Assert
    with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
        MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
    mock_redis_client: MagicMock,
) -> None:
    """A held document-level lock makes the check raise instead of proceeding."""
    # Arrange
    mock_redis_client.get.return_value = 1  # truthy → lock is held
    # Act + Assert
    with pytest.raises(ValueError, match="document metadata operation is running"):
        MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
    """Listing metadata skips the "built-in" entry and annotates each item with a binding count."""
    # Arrange
    dataset = _dataset(
        id="dataset-1",
        built_in_field_enabled=True,
        doc_metadata=[
            {"id": "meta-1", "name": "priority", "type": "string"},
            {"id": "built-in", "name": "ignored", "type": "string"},
            {"id": "meta-2", "name": "score", "type": "number"},
        ],
    )
    # One count() per non-built-in item, in list order.
    count_chain = mock_db.session.query.return_value.filter_by.return_value
    count_chain.count.side_effect = [3, 1]
    # Act
    result = MetadataService.get_dataset_metadatas(dataset)
    # Assert
    assert result["built_in_field_enabled"] is True
    assert result["doc_metadata"] == [
        {"id": "meta-1", "name": "priority", "type": "string", "count": 3},
        {"id": "meta-2", "name": "score", "type": "number", "count": 1},
    ]
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
    """A dataset without metadata yields an empty list and touches no DB query."""
    # Arrange
    dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
    # Act
    result = MetadataService.get_dataset_metadatas(dataset)
    # Assert
    assert result == {"doc_metadata": [], "built_in_field_enabled": False}
    mock_db.session.query.assert_not_called()

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,576 @@
from __future__ import annotations
from types import SimpleNamespace
from typing import Any, cast
from unittest.mock import MagicMock
import pytest
from pytest_mock import MockerFixture
from models.account import Tenant
# ---------------------------------------------------------------------------
# Constants used throughout the tests
# ---------------------------------------------------------------------------
TENANT_ID = "tenant-abc"  # tenant id shared across these tests
ACCOUNT_ID = "account-xyz"  # id given to the fake current_user
FILES_BASE_URL = "https://files.example.com"
# mocker.patch targets for WorkspaceService's external boundaries.
DB_PATH = "services.workspace_service.db"
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
CURRENT_USER_PATH = "services.workspace_service.current_user"
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
# ---------------------------------------------------------------------------
# Helpers / factories
# ---------------------------------------------------------------------------
def _make_tenant(
    tenant_id: str = TENANT_ID,
    name: str = "My Workspace",
    plan: str = "sandbox",
    status: str = "active",
    custom_config: dict | None = None,
) -> Tenant:
    """Build a lightweight Tenant stand-in carrying only the attributes the service reads."""
    attrs = {
        "id": tenant_id,
        "name": name,
        "plan": plan,
        "status": status,
        "created_at": "2024-01-01T00:00:00Z",
        "custom_config_dict": custom_config or {},
    }
    return cast(Tenant, SimpleNamespace(**attrs))
def _make_feature(
can_replace_logo: bool = False,
next_credit_reset_date: str | None = None,
billing_plan: str = "sandbox",
) -> MagicMock:
"""Create a feature namespace matching what FeatureService.get_features returns."""
feature = MagicMock()
feature.can_replace_logo = can_replace_logo
feature.next_credit_reset_date = next_credit_reset_date
feature.billing.subscription.plan = billing_plan
return feature
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
pool = MagicMock()
pool.quota_limit = quota_limit
pool.quota_used = quota_used
return pool
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
return SimpleNamespace(role=role)
def _tenant_info(result: object) -> dict[str, Any] | None:
return cast(dict[str, Any] | None, result)
# ---------------------------------------------------------------------------
# Shared fixtures
# ---------------------------------------------------------------------------
@pytest.fixture
def mock_current_user() -> SimpleNamespace:
    """Return a lightweight current_user stand-in."""
    # Only `.id` is populated; tests that need a different id re-patch
    # CURRENT_USER_PATH themselves (see the query-integrity test).
    return SimpleNamespace(id=ACCOUNT_ID)
@pytest.fixture
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
    """
    Patch the common external boundaries used by WorkspaceService.get_tenant_info.
    Returns a dict of named mocks so individual tests can customise them.
    """
    # Replace the module-level current_user proxy with a plain namespace.
    mocker.patch(CURRENT_USER_PATH, mock_current_user)
    # Fake the SQLAlchemy session so query(...).where(...).first() resolves
    # to an owner-role TenantAccountJoin without touching a database.
    mock_db_session = mocker.patch(f"{DB_PATH}.session")
    mock_query_chain = MagicMock()
    mock_db_session.query.return_value = mock_query_chain
    mock_query_chain.where.return_value = mock_query_chain
    mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
    # Feature flags default to "nothing enabled"; tests override return_value.
    mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
    # Role check defaults to "not OWNER/ADMIN"; tests flip it where needed.
    mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
    # Self-hosted edition with a known files base URL.
    mock_config = mocker.patch(DIFY_CONFIG_PATH)
    mock_config.EDITION = "SELF_HOSTED"
    mock_config.FILES_URL = FILES_BASE_URL
    return {
        "db_session": mock_db_session,
        "query_chain": mock_query_chain,
        "get_features": mock_feature,
        "has_roles": mock_has_roles,
        "config": mock_config,
    }
# ---------------------------------------------------------------------------
# 1. None Tenant Handling
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
    """A missing tenant short-circuits the lookup and yields None."""
    from services.workspace_service import WorkspaceService

    # Arrange / Act
    info = WorkspaceService.get_tenant_info(cast(Tenant, None))
    # Assert
    assert info is None
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
    """get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
    from services.workspace_service import WorkspaceService

    # Act / Assert — exercise both falsy examples promised by the docstring
    assert WorkspaceService.get_tenant_info("") is None  # type: ignore[arg-type]
    assert WorkspaceService.get_tenant_info(0) is None  # type: ignore[arg-type]
# ---------------------------------------------------------------------------
# 2. Basic Tenant Info — happy path
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_return_base_fields(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """Every call returns the six scalar fields describing the workspace."""
    from services.workspace_service import WorkspaceService

    # Arrange
    tenant = _make_tenant()
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert — compare each base field against the factory defaults
    assert info is not None
    expected = {
        "id": TENANT_ID,
        "name": "My Workspace",
        "plan": "sandbox",
        "status": "active",
        "created_at": "2024-01-01T00:00:00Z",
        "trial_end_reason": None,
    }
    for field, value in expected.items():
        assert info[field] == value
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """The 'role' field should mirror the TenantAccountJoin row, not a default."""
    from services.workspace_service import WorkspaceService

    # Arrange — make the membership lookup yield an admin join row
    basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert info["role"] == "admin"
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """A missing TenantAccountJoin row trips the service's internal assert."""
    from services.workspace_service import WorkspaceService

    # Arrange — the membership lookup finds nothing
    basic_mocks["query_chain"].first.return_value = None
    tenant = _make_tenant()
    # Act + Assert
    with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
        WorkspaceService.get_tenant_info(tenant)
# ---------------------------------------------------------------------------
# 3. Logo Customisation
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """OWNER/ADMIN users see the custom_config block once can_replace_logo is on."""
    from services.workspace_service import WorkspaceService

    # Arrange — feature flag on, caller holds an admin-level role
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant(
        custom_config={
            "replace_webapp_logo": True,
            "remove_webapp_brand": True,
        }
    )
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert — block present, brand removed, logo URL rendered from FILES_URL
    assert info is not None
    assert "custom_config" in info
    custom = info["custom_config"]
    assert custom["remove_webapp_brand"] is True
    assert custom["replace_webapp_logo"] == f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """Without a replace_webapp_logo key in custom config, the logo URL is None."""
    from services.workspace_service import WorkspaceService

    # Arrange — admin with the feature enabled but an empty custom config
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant(custom_config={})))
    # Assert
    assert info is not None
    assert info["custom_config"]["replace_webapp_logo"] is None
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """No custom_config block when the can_replace_logo feature is off."""
    from services.workspace_service import WorkspaceService

    # Arrange — admin role, but the feature flag is disabled
    basic_mocks["has_roles"].return_value = True
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert "custom_config" not in info
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """Regular members never receive the custom_config block."""
    from services.workspace_service import WorkspaceService

    # Arrange — feature on, but the caller is neither OWNER nor ADMIN
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = False
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert "custom_config" not in info
def test_get_tenant_info_should_use_files_url_for_logo_url(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """The rendered logo URL must be rooted at dify_config.FILES_URL."""
    from services.workspace_service import WorkspaceService

    # Arrange — point FILES_URL at a different base host
    cdn_base = "https://cdn.mycompany.io"
    basic_mocks["config"].FILES_URL = cdn_base
    basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
    basic_mocks["has_roles"].return_value = True
    tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert
    assert info is not None
    assert info["custom_config"]["replace_webapp_logo"].startswith(cdn_base)
# ---------------------------------------------------------------------------
# 4. Cloud-Edition Credit Features
# ---------------------------------------------------------------------------
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional"  # any plan that is not SANDBOX
@pytest.fixture
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
    """Patches for CLOUD edition tests, billing plan = professional by default."""
    mocker.patch(CURRENT_USER_PATH, mock_current_user)
    # Fake DB session: query(...).where(...).first() yields an owner join row.
    mock_db_session = mocker.patch(f"{DB_PATH}.session")
    mock_query_chain = MagicMock()
    mock_db_session.query.return_value = mock_query_chain
    mock_query_chain.where.return_value = mock_query_chain
    mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
    # Features: logo replacement off, a known credit-reset date, non-sandbox plan.
    mock_feature = mocker.patch(
        FEATURE_SERVICE_PATH,
        return_value=_make_feature(
            can_replace_logo=False,
            next_credit_reset_date="2025-02-01",
            billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
        ),
    )
    mocker.patch(TENANT_SERVICE_PATH, return_value=False)
    # CLOUD edition enables the credit-pool code paths under test.
    mock_config = mocker.patch(DIFY_CONFIG_PATH)
    mock_config.EDITION = "CLOUD"
    mock_config.FILES_URL = FILES_BASE_URL
    return {
        "db_session": mock_db_session,
        "query_chain": mock_query_chain,
        "get_features": mock_feature,
        "config": mock_config,
    }
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """CLOUD edition responses carry next_credit_reset_date."""
    from services.workspace_service import WorkspaceService

    # Arrange — neither a paid nor a trial credit pool exists
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert info["next_credit_reset_date"] == "2025-02-01"
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """trial_credits/trial_credits_used come from the paid pool when conditions are met."""
    from services.workspace_service import WorkspaceService

    # Arrange — paid pool available and not exhausted.
    # Use side_effect (paid lookup first, then trial) like the sibling tests:
    # with return_value both lookups return the paid pool, so a wrong fall-back
    # to the trial branch would still satisfy the assertions below.
    paid_pool = _make_pool(quota_limit=1000, quota_used=200)
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
    tenant = _make_tenant()
    # Act
    result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
    # Assert — figures taken from the paid pool
    assert result is not None
    assert result["trial_credits"] == 1000
    assert result["trial_credits_used"] == 200
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """A quota_limit of -1 (unlimited) still counts as a usable paid pool."""
    from services.workspace_service import WorkspaceService

    # Arrange — unlimited paid pool, no trial pool
    unlimited_pool = _make_pool(quota_limit=-1, quota_used=999)
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[unlimited_pool, None])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert info["trial_credits"] == -1
    assert info["trial_credits_used"] == 999
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """An exhausted paid pool (used >= limit) makes the service report the trial pool."""
    from services.workspace_service import WorkspaceService

    # Arrange — paid pool exactly at its limit, trial pool with headroom
    exhausted_pool = _make_pool(quota_limit=500, quota_used=500)
    trial_pool = _make_pool(quota_limit=100, quota_used=10)
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[exhausted_pool, trial_pool])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert — figures come from the trial pool
    assert info is not None
    assert info["trial_credits"] == 100
    assert info["trial_credits_used"] == 10
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """With no paid pool at all, the trial pool supplies the credit figures."""
    from services.workspace_service import WorkspaceService

    # Arrange — paid lookup returns nothing, trial pool exists
    trial_pool = _make_pool(quota_limit=50, quota_used=5)
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert info["trial_credits"] == 50
    assert info["trial_credits_used"] == 5
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """A SANDBOX subscription skips the paid-pool branch and reports the trial pool."""
    from enums.cloud_plan import CloudPlan
    from services.workspace_service import WorkspaceService

    # Arrange — force the billing plan to SANDBOX; both pools exist
    cloud_mocks["get_features"].return_value = _make_feature(
        next_credit_reset_date="2025-02-01",
        billing_plan=CloudPlan.SANDBOX,
    )
    ignored_paid_pool = _make_pool(quota_limit=1000, quota_used=0)
    trial_pool = _make_pool(quota_limit=200, quota_used=20)
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[ignored_paid_pool, trial_pool])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert — only the trial-pool figures surface
    assert info is not None
    assert info["trial_credits"] == 200
    assert info["trial_credits_used"] == 20
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
    mocker: MockerFixture,
    cloud_mocks: dict,
) -> None:
    """No credit fields appear when neither a paid nor a trial pool exists."""
    from services.workspace_service import WorkspaceService

    # Arrange — both pool lookups come back empty
    mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert
    assert info is not None
    assert "trial_credits" not in info
    assert "trial_credits_used" not in info
# ---------------------------------------------------------------------------
# 5. Self-hosted / Non-Cloud Edition
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """SELF_HOSTED responses omit next_credit_reset_date and the credit fields."""
    from services.workspace_service import WorkspaceService

    # Arrange — basic_mocks already pins EDITION to "SELF_HOSTED"
    # Act
    info = _tenant_info(WorkspaceService.get_tenant_info(_make_tenant()))
    # Assert — none of the CLOUD-only keys leak through
    assert info is not None
    for cloud_only_key in ("next_credit_reset_date", "trial_credits", "trial_credits_used"):
        assert cloud_only_key not in info
# ---------------------------------------------------------------------------
# 6. DB query integrity
# ---------------------------------------------------------------------------
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
    mocker: MockerFixture,
    basic_mocks: dict,
) -> None:
    """
    The DB query for TenantAccountJoin must be scoped to the correct
    tenant_id and current_user.id.
    """
    from services.workspace_service import WorkspaceService

    # Arrange — distinct ids so the lookup cannot hit fixture defaults
    tenant = _make_tenant(tenant_id="my-special-tenant")
    mock_current_user = mocker.patch(CURRENT_USER_PATH)
    mock_current_user.id = "special-user-id"
    # Act
    WorkspaceService.get_tenant_info(tenant)
    # Assert — the join lookup was issued AND the filter/fetch chain was
    # exercised (query -> where -> first), not merely constructed.
    basic_mocks["db_session"].query.assert_called()
    basic_mocks["query_chain"].where.assert_called()
    basic_mocks["query_chain"].first.assert_called()

View File

@@ -488,8 +488,7 @@ ALIYUN_OSS_REGION=ap-southeast-1
ALIYUN_OSS_AUTH_VERSION=v4
# Don't start with '/'. OSS doesn't support leading slash in object names.
ALIYUN_OSS_PATH=your-path
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
# Tencent COS Configuration
#

View File

@@ -275,7 +275,6 @@ services:
# Use the shared environment variables.
<<: *shared-api-worker-env
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}

View File

@@ -146,6 +146,7 @@ x-shared-env: &shared-api-worker-env
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
@@ -984,7 +985,6 @@ services:
# Use the shared environment variables.
<<: *shared-api-worker-env
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}