mirror of
https://github.com/langgenius/dify.git
synced 2026-03-28 19:26:47 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a84ad7cf0f |
13
.github/workflows/pyrefly-diff.yml
vendored
13
.github/workflows/pyrefly-diff.yml
vendored
@@ -50,6 +50,17 @@ jobs:
|
||||
run: |
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Check if line counts match
|
||||
id: line_count_check
|
||||
run: |
|
||||
base_lines=$(wc -l < /tmp/pyrefly_base.txt)
|
||||
pr_lines=$(wc -l < /tmp/pyrefly_pr.txt)
|
||||
if [ "$base_lines" -eq "$pr_lines" ]; then
|
||||
echo "same=true" >> $GITHUB_OUTPUT
|
||||
else
|
||||
echo "same=false" >> $GITHUB_OUTPUT
|
||||
fi
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
echo ${{ github.event.pull_request.number }} > pr_number.txt
|
||||
@@ -63,7 +74,7 @@ jobs:
|
||||
pr_number.txt
|
||||
|
||||
- name: Comment PR with pyrefly diff
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository }}
|
||||
if: ${{ github.event.pull_request.head.repo.full_name == github.repository && steps.line_count_check.outputs.same == 'false' }}
|
||||
uses: actions/github-script@ed597411d8f924073f98dfc5c65a23a2325f34cd # v8.0.0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
|
||||
2
.github/workflows/style.yml
vendored
2
.github/workflows/style.yml
vendored
@@ -49,7 +49,7 @@ jobs:
|
||||
|
||||
- name: Run Type Checks
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
run: make type-check-core
|
||||
run: make type-check
|
||||
|
||||
- name: Dotenv check
|
||||
if: steps.changed-files.outputs.any_changed == 'true'
|
||||
|
||||
5
.github/workflows/web-tests.yml
vendored
5
.github/workflows/web-tests.yml
vendored
@@ -22,8 +22,8 @@ jobs:
|
||||
strategy:
|
||||
fail-fast: false
|
||||
matrix:
|
||||
shardIndex: [1, 2, 3, 4]
|
||||
shardTotal: [4]
|
||||
shardIndex: [1, 2, 3, 4, 5, 6]
|
||||
shardTotal: [6]
|
||||
defaults:
|
||||
run:
|
||||
shell: bash
|
||||
@@ -66,6 +66,7 @@ jobs:
|
||||
- name: Checkout code
|
||||
uses: actions/checkout@de0fac2e4500dabe0009e67214ff5f5447ce83dd # v6.0.2
|
||||
with:
|
||||
fetch-depth: 0
|
||||
persist-credentials: false
|
||||
|
||||
- name: Setup web environment
|
||||
|
||||
7
Makefile
7
Makefile
@@ -74,12 +74,6 @@ type-check:
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Type checks complete"
|
||||
|
||||
type-check-core:
|
||||
@echo "📝 Running core type checks (basedpyright + mypy)..."
|
||||
@./dev/basedpyright-check $(PATH_TO_CHECK)
|
||||
@uv --directory api run mypy --exclude-gitignore --exclude 'tests/' --exclude 'migrations/' --check-untyped-defs --disable-error-code=import-untyped .
|
||||
@echo "✅ Core type checks complete"
|
||||
|
||||
test:
|
||||
@echo "🧪 Running backend unit tests..."
|
||||
@if [ -n "$(TARGET_TESTS)" ]; then \
|
||||
@@ -139,7 +133,6 @@ help:
|
||||
@echo " make check - Check code with ruff"
|
||||
@echo " make lint - Format, fix, and lint code (ruff, imports, dotenv)"
|
||||
@echo " make type-check - Run type checks (basedpyright, pyrefly, mypy)"
|
||||
@echo " make type-check-core - Run core type checks (basedpyright, mypy)"
|
||||
@echo " make test - Run backend unit tests (or TARGET_TESTS=./api/tests/<target_tests>)"
|
||||
@echo ""
|
||||
@echo "Docker Build Targets:"
|
||||
|
||||
@@ -127,8 +127,7 @@ ALIYUN_OSS_AUTH_VERSION=v1
|
||||
ALIYUN_OSS_REGION=your-region
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
|
||||
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Google Storage configuration
|
||||
GOOGLE_STORAGE_BUCKET_NAME=your-bucket-name
|
||||
|
||||
@@ -8,7 +8,6 @@ Go admin-api caller.
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from controllers.common.schema import register_schema_model
|
||||
@@ -88,7 +87,7 @@ class EnterpriseAppDSLExport(Resource):
|
||||
"""Export an app's DSL as YAML."""
|
||||
include_secret = request.args.get("include_secret", "false").lower() == "true"
|
||||
|
||||
app_model = db.session.get(App, app_id)
|
||||
app_model = db.session.query(App).filter_by(id=app_id).first()
|
||||
if not app_model:
|
||||
return {"message": "app not found"}, 404
|
||||
|
||||
@@ -105,7 +104,7 @@ def _get_active_account(email: str) -> Account | None:
|
||||
|
||||
Workspace membership is already validated by the Go admin-api caller.
|
||||
"""
|
||||
account = db.session.scalar(select(Account).where(Account.email == email).limit(1))
|
||||
account = db.session.query(Account).filter_by(email=email).first()
|
||||
if account is None or account.status != AccountStatus.ACTIVE:
|
||||
return None
|
||||
return account
|
||||
|
||||
@@ -18,7 +18,7 @@ from graphon.model_runtime.entities import (
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.entities import AgentEntity, AgentToolEntity
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
@@ -104,14 +104,11 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
# get how many agent thoughts have been created
|
||||
self.agent_thought_count = (
|
||||
db.session.scalar(
|
||||
select(func.count())
|
||||
.select_from(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
db.session.query(MessageAgentThought)
|
||||
.where(
|
||||
MessageAgentThought.message_id == self.message.id,
|
||||
)
|
||||
or 0
|
||||
.count()
|
||||
)
|
||||
db.session.close()
|
||||
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -70,21 +70,23 @@ class DatasetIndexToolCallbackHandler:
|
||||
)
|
||||
child_chunk = db.session.scalar(child_chunk_stmt)
|
||||
if child_chunk:
|
||||
db.session.execute(
|
||||
update(DocumentSegment)
|
||||
_ = (
|
||||
db.session.query(DocumentSegment)
|
||||
.where(DocumentSegment.id == child_chunk.segment_id)
|
||||
.values(hit_count=DocumentSegment.hit_count + 1)
|
||||
.update(
|
||||
{DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False
|
||||
)
|
||||
)
|
||||
else:
|
||||
conditions = [DocumentSegment.index_node_id == document.metadata["doc_id"]]
|
||||
query = db.session.query(DocumentSegment).where(
|
||||
DocumentSegment.index_node_id == document.metadata["doc_id"]
|
||||
)
|
||||
|
||||
if "dataset_id" in document.metadata:
|
||||
conditions.append(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
query = query.where(DocumentSegment.dataset_id == document.metadata["dataset_id"])
|
||||
|
||||
# add hit count to document segment
|
||||
db.session.execute(
|
||||
update(DocumentSegment).where(*conditions).values(hit_count=DocumentSegment.hit_count + 1)
|
||||
)
|
||||
query.update({DocumentSegment.hit_count: DocumentSegment.hit_count + 1}, synchronize_session=False)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -19,7 +19,7 @@ def encrypt_token(tenant_id: str, token: str):
|
||||
from extensions.ext_database import db
|
||||
from models.account import Tenant
|
||||
|
||||
if not (tenant := db.session.get(Tenant, tenant_id)):
|
||||
if not (tenant := db.session.query(Tenant).where(Tenant.id == tenant_id).first()):
|
||||
raise ValueError(f"Tenant with id {tenant_id} not found")
|
||||
assert tenant.encrypt_public_key is not None
|
||||
encrypted_token = rsa.encrypt(token, tenant.encrypt_public_key)
|
||||
|
||||
@@ -10,7 +10,6 @@ from graphon.model_runtime.entities.llm_entities import LLMResult
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
from graphon.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
@@ -411,8 +410,8 @@ class LLMGenerator:
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
):
|
||||
last_run: Message | None = db.session.scalar(
|
||||
select(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).limit(1)
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
|
||||
@@ -227,7 +227,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
get app
|
||||
"""
|
||||
try:
|
||||
app = db.session.scalar(select(App).where(App.id == app_id, App.tenant_id == tenant_id).limit(1))
|
||||
app = db.session.query(App).where(App.id == app_id).where(App.tenant_id == tenant_id).first()
|
||||
except Exception:
|
||||
raise ValueError("app not found")
|
||||
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from sqlalchemy import delete, select
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.tools.__base.tool_provider import ToolProviderController
|
||||
from core.tools.builtin_tool.provider import BuiltinToolProviderController
|
||||
@@ -31,7 +31,7 @@ class ToolLabelManager:
|
||||
raise ValueError("Unsupported tool type")
|
||||
|
||||
# delete old labels
|
||||
db.session.execute(delete(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id))
|
||||
db.session.query(ToolLabelBinding).where(ToolLabelBinding.tool_id == provider_id).delete()
|
||||
|
||||
# insert new labels
|
||||
for label in labels:
|
||||
|
||||
@@ -255,11 +255,11 @@ class ToolManager:
|
||||
if builtin_provider is None:
|
||||
raise ToolProviderNotFoundError(f"no default provider for {provider_id}")
|
||||
else:
|
||||
builtin_provider = db.session.scalar(
|
||||
select(BuiltinToolProvider)
|
||||
builtin_provider = (
|
||||
db.session.query(BuiltinToolProvider)
|
||||
.where(BuiltinToolProvider.tenant_id == tenant_id, (BuiltinToolProvider.provider == provider_id))
|
||||
.order_by(BuiltinToolProvider.is_default.desc(), BuiltinToolProvider.created_at.asc())
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if builtin_provider is None:
|
||||
@@ -818,13 +818,13 @@ class ToolManager:
|
||||
|
||||
:return: the provider controller, the credentials
|
||||
"""
|
||||
provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.id == provider_id,
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider is None:
|
||||
@@ -872,13 +872,13 @@ class ToolManager:
|
||||
get api provider
|
||||
"""
|
||||
provider_name = provider
|
||||
provider_obj: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
provider_obj: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(
|
||||
ApiToolProvider.tenant_id == tenant_id,
|
||||
ApiToolProvider.name == provider,
|
||||
)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if provider_obj is None:
|
||||
@@ -964,10 +964,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_workflow_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
workflow_provider: WorkflowToolProvider | None = db.session.scalar(
|
||||
select(WorkflowToolProvider)
|
||||
workflow_provider: WorkflowToolProvider | None = (
|
||||
db.session.query(WorkflowToolProvider)
|
||||
.where(WorkflowToolProvider.tenant_id == tenant_id, WorkflowToolProvider.id == provider_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if workflow_provider is None:
|
||||
@@ -981,10 +981,10 @@ class ToolManager:
|
||||
@classmethod
|
||||
def generate_api_tool_icon_url(cls, tenant_id: str, provider_id: str) -> EmojiIconDict:
|
||||
try:
|
||||
api_provider: ApiToolProvider | None = db.session.scalar(
|
||||
select(ApiToolProvider)
|
||||
api_provider: ApiToolProvider | None = (
|
||||
db.session.query(ApiToolProvider)
|
||||
.where(ApiToolProvider.tenant_id == tenant_id, ApiToolProvider.id == provider_id)
|
||||
.limit(1)
|
||||
.first()
|
||||
)
|
||||
|
||||
if api_provider is None:
|
||||
|
||||
@@ -110,7 +110,7 @@ class DatasetMultiRetrieverTool(DatasetRetrieverBaseTool):
|
||||
context_list: list[RetrievalSourceMetadata] = []
|
||||
resource_number = 1
|
||||
for segment in sorted_segments:
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
document_stmt = select(Document).where(
|
||||
Document.id == segment.document_id,
|
||||
Document.enabled == True,
|
||||
|
||||
@@ -205,7 +205,7 @@ class DatasetRetrieverTool(DatasetRetrieverBaseTool):
|
||||
if self.return_resource:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = db.session.get(Dataset, segment.dataset_id)
|
||||
dataset = db.session.query(Dataset).filter_by(id=segment.dataset_id).first()
|
||||
dataset_document_stmt = select(DatasetDocument).where(
|
||||
DatasetDocument.id == segment.document_id,
|
||||
DatasetDocument.enabled == True,
|
||||
|
||||
@@ -35,13 +35,15 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def get_auth_credentials(tenant_id: str, category: str, provider: str):
|
||||
data_source_api_key_bindings = db.session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
data_source_api_key_bindings = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.category == category,
|
||||
DataSourceApiKeyAuthBinding.provider == provider,
|
||||
DataSourceApiKeyAuthBinding.disabled.is_(False),
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if not data_source_api_key_bindings:
|
||||
return None
|
||||
@@ -52,11 +54,10 @@ class ApiKeyAuthService:
|
||||
|
||||
@staticmethod
|
||||
def delete_provider_auth(tenant_id: str, binding_id: str):
|
||||
data_source_api_key_binding = db.session.scalar(
|
||||
select(DataSourceApiKeyAuthBinding).where(
|
||||
DataSourceApiKeyAuthBinding.tenant_id == tenant_id,
|
||||
DataSourceApiKeyAuthBinding.id == binding_id,
|
||||
)
|
||||
data_source_api_key_binding = (
|
||||
db.session.query(DataSourceApiKeyAuthBinding)
|
||||
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
|
||||
.first()
|
||||
)
|
||||
if data_source_api_key_binding:
|
||||
db.session.delete(data_source_api_key_binding)
|
||||
|
||||
@@ -1,6 +1,4 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
from faker import Faker
|
||||
@@ -536,283 +534,3 @@ class TestWorkspaceService:
|
||||
# Verify database state
|
||||
db_session_with_containers.refresh(tenant)
|
||||
assert tenant.id is not None
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_join_missing(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""TenantAccountJoin must exist; missing join should raise AssertionError."""
|
||||
fake = Faker()
|
||||
account = Account(email=fake.email(), name=fake.name(), interface_language="en-US", status="active")
|
||||
db_session_with_containers.add(account)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
tenant = Tenant(name=fake.company(), status="normal", plan="basic")
|
||||
db_session_with_containers.add(tenant)
|
||||
db_session_with_containers.commit()
|
||||
|
||||
# No TenantAccountJoin created
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
import json
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tenant.custom_config = json.dumps({})
|
||||
db_session_with_containers.commit()
|
||||
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
import json
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
tenant.custom_config = json.dumps({"replace_webapp_logo": True})
|
||||
db_session_with_containers.commit()
|
||||
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
mock_external_service_dependencies["dify_config"].FILES_URL = custom_base
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = True
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = True
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "SELF_HOSTED"
|
||||
mock_external_service_dependencies["feature_service"].get_features.return_value.can_replace_logo = False
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with patch("services.workspace_service.current_user", account):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
def test_get_tenant_info_cloud_credit_reset_date(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=None),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
def test_get_tenant_info_cloud_paid_pool_not_full(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""trial_credits come from paid pool when plan is not sandbox and pool is not full."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=1000, quota_used=200)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", return_value=paid_pool),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
def test_get_tenant_info_cloud_paid_pool_unlimited(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""quota_limit == -1 means unlimited; service should use paid pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=-1, quota_used=999)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, None]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_full(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When paid pool is exhausted, switch to trial pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=500, quota_used=500)
|
||||
trial_pool = MagicMock(quota_limit=100, quota_used=10)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
def test_get_tenant_info_cloud_fall_back_to_trial_when_paid_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
trial_pool = MagicMock(quota_limit=50, quota_used=5)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
def test_get_tenant_info_cloud_sandbox_uses_trial_pool(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When plan is SANDBOX, skip paid pool and use trial pool."""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
paid_pool = MagicMock(quota_limit=1000, quota_used=0)
|
||||
trial_pool = MagicMock(quota_limit=200, quota_used=20)
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[paid_pool, trial_pool]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
def test_get_tenant_info_cloud_both_pools_none(
|
||||
self, db_session_with_containers: Session, mock_external_service_dependencies
|
||||
):
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
fake = Faker()
|
||||
account, tenant = self._create_test_account_and_tenant(
|
||||
db_session_with_containers, mock_external_service_dependencies
|
||||
)
|
||||
|
||||
mock_external_service_dependencies["dify_config"].EDITION = "CLOUD"
|
||||
feature = mock_external_service_dependencies["feature_service"].get_features.return_value
|
||||
feature.can_replace_logo = False
|
||||
feature.next_credit_reset_date = "2025-02-01"
|
||||
feature.billing.subscription.plan = "professional"
|
||||
mock_external_service_dependencies["tenant_service"].has_roles.return_value = False
|
||||
|
||||
with (
|
||||
patch("services.workspace_service.current_user", account),
|
||||
patch("services.credit_pool_service.CreditPoolService.get_pool", side_effect=[None, None]),
|
||||
):
|
||||
result = WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
@@ -64,18 +64,18 @@ class TestGetActiveAccount:
|
||||
def test_returns_active_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "active"
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
result = _get_active_account("user@example.com")
|
||||
|
||||
assert result is mock_account
|
||||
mock_db.session.scalar.assert_called_once()
|
||||
mock_db.session.query.return_value.filter_by.assert_called_once_with(email="user@example.com")
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_inactive_account(self, mock_db):
|
||||
mock_account = MagicMock()
|
||||
mock_account.status = "banned"
|
||||
mock_db.session.scalar.return_value = mock_account
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_account
|
||||
|
||||
result = _get_active_account("banned@example.com")
|
||||
|
||||
@@ -83,7 +83,7 @@ class TestGetActiveAccount:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_returns_none_for_nonexistent_email(self, mock_db):
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
result = _get_active_account("missing@example.com")
|
||||
|
||||
@@ -205,7 +205,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_success_returns_200(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "version: 0.6.0\nkind: app\n"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@@ -221,7 +221,7 @@ class TestEnterpriseAppDSLExport:
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_with_secret(self, mock_db, mock_dsl_cls, api_instance, app: Flask):
|
||||
mock_app = MagicMock()
|
||||
mock_db.session.get.return_value = mock_app
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = mock_app
|
||||
mock_dsl_cls.export_dsl.return_value = "yaml-data"
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
@@ -234,7 +234,7 @@ class TestEnterpriseAppDSLExport:
|
||||
|
||||
@patch("controllers.inner_api.app.dsl.db")
|
||||
def test_export_app_not_found_returns_404(self, mock_db, api_instance, app: Flask):
|
||||
mock_db.session.get.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
unwrapped = inspect.unwrap(api_instance.get)
|
||||
with app.test_request_context("?include_secret=false"):
|
||||
|
||||
@@ -621,7 +621,7 @@ class TestConvertDatasetRetrieverTool:
|
||||
class TestBaseAgentRunnerInit:
|
||||
def test_init_sets_stream_tool_call_and_files(self, mocker):
|
||||
session = mocker.MagicMock()
|
||||
session.scalar.return_value = 2
|
||||
session.query.return_value.where.return_value.count.return_value = 2
|
||||
mocker.patch.object(module.db, "session", session)
|
||||
|
||||
mocker.patch.object(BaseAgentRunner, "organize_agent_history", return_value=[])
|
||||
|
||||
@@ -114,9 +114,13 @@ class TestOnToolEnd:
|
||||
document = mocker.Mock()
|
||||
document.metadata = {"document_id": "doc-1", "doc_id": "node-1"}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_non_parent_child_index(self, handler, mocker):
|
||||
@@ -134,9 +138,13 @@ class TestOnToolEnd:
|
||||
"dataset_id": "dataset-1",
|
||||
}
|
||||
|
||||
mock_query = mocker.Mock()
|
||||
mock_db.session.query.return_value = mock_query
|
||||
mock_query.where.return_value = mock_query
|
||||
|
||||
handler.on_tool_end([document])
|
||||
|
||||
mock_db.session.execute.assert_called_once()
|
||||
mock_query.update.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_on_tool_end_empty_documents(self, handler):
|
||||
|
||||
@@ -38,13 +38,13 @@ class TestObfuscatedToken:
|
||||
|
||||
|
||||
class TestEncryptToken:
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_successful_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test successful token encryption"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_data"
|
||||
|
||||
result = encrypt_token("tenant-123", "test_token")
|
||||
@@ -52,10 +52,10 @@ class TestEncryptToken:
|
||||
assert result == base64.b64encode(b"encrypted_data").decode()
|
||||
mock_encrypt.assert_called_with("test_token", "mock_public_key")
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
def test_tenant_not_found(self, mock_query):
|
||||
"""Test error when tenant doesn't exist"""
|
||||
mock_query.return_value = None
|
||||
mock_query.return_value.where.return_value.first.return_value = None
|
||||
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
encrypt_token("invalid-tenant", "test_token")
|
||||
@@ -119,7 +119,7 @@ class TestGetDecryptDecoding:
|
||||
|
||||
|
||||
class TestEncryptDecryptIntegration:
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
@patch("libs.rsa.decrypt")
|
||||
def test_should_encrypt_and_decrypt_consistently(self, mock_decrypt, mock_encrypt, mock_query):
|
||||
@@ -127,7 +127,7 @@ class TestEncryptDecryptIntegration:
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Setup mock encryption/decryption
|
||||
original_token = "test_token_123"
|
||||
@@ -146,14 +146,14 @@ class TestEncryptDecryptIntegration:
|
||||
class TestSecurity:
|
||||
"""Critical security tests for encryption system"""
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_cross_tenant_isolation(self, mock_encrypt, mock_query):
|
||||
"""Ensure tokens encrypted for one tenant cannot be used by another"""
|
||||
# Setup mock tenant
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "tenant1_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_for_tenant1"
|
||||
|
||||
# Encrypt token for tenant1
|
||||
@@ -181,12 +181,12 @@ class TestSecurity:
|
||||
with pytest.raises(Exception, match="Decryption error"):
|
||||
decrypt_token("tenant-123", tampered)
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_encryption_randomness(self, mock_encrypt, mock_query):
|
||||
"""Ensure same plaintext produces different ciphertext"""
|
||||
mock_tenant = MagicMock(encrypt_public_key="key")
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# Different outputs for same input
|
||||
mock_encrypt.side_effect = [b"enc1", b"enc2", b"enc3"]
|
||||
@@ -205,13 +205,13 @@ class TestEdgeCases:
|
||||
# Test empty string (which is a valid str type)
|
||||
assert obfuscated_token("") == ""
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_empty_token_encryption(self, mock_encrypt, mock_query):
|
||||
"""Test encryption of empty token"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_empty"
|
||||
|
||||
result = encrypt_token("tenant-123", "")
|
||||
@@ -219,13 +219,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_empty").decode()
|
||||
mock_encrypt.assert_called_with("", "mock_public_key")
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_special_characters_in_token(self, mock_encrypt, mock_query):
|
||||
"""Test tokens containing special/unicode characters"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
mock_encrypt.return_value = b"encrypted_special"
|
||||
|
||||
# Test various special characters
|
||||
@@ -242,13 +242,13 @@ class TestEdgeCases:
|
||||
assert result == base64.b64encode(b"encrypted_special").decode()
|
||||
mock_encrypt.assert_called_with(token, "mock_public_key")
|
||||
|
||||
@patch("extensions.ext_database.db.session.get")
|
||||
@patch("models.engine.db.session.query")
|
||||
@patch("libs.rsa.encrypt")
|
||||
def test_should_handle_rsa_size_limits(self, mock_encrypt, mock_query):
|
||||
"""Test behavior when token exceeds RSA encryption limits"""
|
||||
mock_tenant = MagicMock()
|
||||
mock_tenant.encrypt_public_key = "mock_public_key"
|
||||
mock_query.return_value = mock_tenant
|
||||
mock_query.return_value.where.return_value.first.return_value = mock_tenant
|
||||
|
||||
# RSA 2048-bit can only encrypt ~245 bytes
|
||||
# The actual limit depends on padding scheme
|
||||
|
||||
@@ -314,8 +314,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_legacy_no_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
# Mock __instruction_modify_common call via invoke_llm
|
||||
mock_response = MagicMock()
|
||||
@@ -328,12 +328,12 @@ class TestLLMGenerator:
|
||||
assert result == {"modified": "prompt"}
|
||||
|
||||
def test_instruction_modify_legacy_with_last_run(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
last_run = MagicMock()
|
||||
last_run.query = "q"
|
||||
last_run.answer = "a"
|
||||
last_run.error = "e"
|
||||
mock_scalar.return_value = last_run
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = last_run
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"modified": "prompt"}'
|
||||
@@ -483,8 +483,8 @@ class TestLLMGenerator:
|
||||
|
||||
def test_instruction_modify_common_placeholders(self, mock_model_instance, model_config_entity):
|
||||
# Testing placeholders replacement via instruction_modify_legacy for convenience
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = '{"ok": true}'
|
||||
@@ -504,8 +504,8 @@ class TestLLMGenerator:
|
||||
assert "current_val" in user_msg_dict["instruction"]
|
||||
|
||||
def test_instruction_modify_common_no_braces(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No braces here"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@@ -516,8 +516,8 @@ class TestLLMGenerator:
|
||||
assert "Could not find a valid JSON object" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_not_dict(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "[1, 2, 3]"
|
||||
mock_model_instance.invoke_llm.return_value = mock_response
|
||||
@@ -556,8 +556,8 @@ class TestLLMGenerator:
|
||||
)
|
||||
|
||||
def test_instruction_modify_common_invoke_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = InvokeError("Invoke Failed")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@@ -566,8 +566,8 @@ class TestLLMGenerator:
|
||||
assert "Failed to generate code" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_exception(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
mock_model_instance.invoke_llm.side_effect = Exception("Random error")
|
||||
|
||||
result = LLMGenerator.instruction_modify_legacy(
|
||||
@@ -576,8 +576,8 @@ class TestLLMGenerator:
|
||||
assert "An unexpected error occurred" in result["error"]
|
||||
|
||||
def test_instruction_modify_common_json_error(self, mock_model_instance, model_config_entity):
|
||||
with patch("extensions.ext_database.db.session.scalar") as mock_scalar:
|
||||
mock_scalar.return_value = None
|
||||
with patch("extensions.ext_database.db.session.query") as mock_query:
|
||||
mock_query.return_value.where.return_value.order_by.return_value.first.return_value = None
|
||||
|
||||
mock_response = MagicMock()
|
||||
mock_response.message.get_text_content.return_value = "No JSON here"
|
||||
|
||||
@@ -332,21 +332,27 @@ class TestPluginAppBackwardsInvocation:
|
||||
PluginAppBackwardsInvocation._get_user("uid")
|
||||
|
||||
def test_get_app_returns_app(self, mocker):
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
app_obj = MagicMock(id="app")
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=app_obj)))
|
||||
query_chain.first.return_value = app_obj
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
assert PluginAppBackwardsInvocation._get_app("app", "tenant") is app_obj
|
||||
|
||||
def test_get_app_raises_when_missing(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(return_value=None)))
|
||||
query_chain = MagicMock()
|
||||
query_chain.where.return_value = query_chain
|
||||
query_chain.first.return_value = None
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(return_value=query_chain)))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
PluginAppBackwardsInvocation._get_app("app", "tenant")
|
||||
|
||||
def test_get_app_raises_when_query_fails(self, mocker):
|
||||
db = SimpleNamespace(session=MagicMock(scalar=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
db = SimpleNamespace(session=MagicMock(query=MagicMock(side_effect=RuntimeError("db down"))))
|
||||
mocker.patch("core.plugin.backwards_invocation.app.db", db)
|
||||
|
||||
with pytest.raises(ValueError, match="app not found"):
|
||||
|
||||
@@ -38,9 +38,11 @@ def test_tool_label_manager_filter_tool_labels():
|
||||
def test_tool_label_manager_update_tool_labels_db():
|
||||
controller = _api_controller("api-1")
|
||||
with patch("core.tools.tool_label_manager.db") as mock_db:
|
||||
delete_query = mock_db.session.query.return_value.where.return_value
|
||||
delete_query.delete.return_value = None
|
||||
ToolLabelManager.update_tool_labels(controller, ["search", "search", "invalid"])
|
||||
|
||||
mock_db.session.execute.assert_called_once()
|
||||
delete_query.delete.assert_called_once()
|
||||
# only one valid unique label should be inserted.
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@@ -220,7 +220,9 @@ def test_get_tool_runtime_builtin_with_credentials_decrypts_and_forks():
|
||||
with patch.object(ToolManager, "get_builtin_provider", return_value=controller):
|
||||
with patch("core.helper.credential_utils.check_credential_policy_compliance"):
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = (
|
||||
builtin_provider
|
||||
)
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key": "secret"}
|
||||
cache = Mock()
|
||||
@@ -272,7 +274,7 @@ def test_get_tool_runtime_builtin_refreshes_expired_oauth_credentials(
|
||||
)
|
||||
refreshed = SimpleNamespace(credentials={"token": "new"}, expires_at=123456)
|
||||
|
||||
mock_db.session.scalar.return_value = builtin_provider
|
||||
mock_db.session.query.return_value.where.return_value.order_by.return_value.first.return_value = builtin_provider
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"token": "old"}
|
||||
encrypter.encrypt.return_value = {"token": "encrypted"}
|
||||
@@ -696,10 +698,12 @@ def test_get_api_provider_controller_returns_controller_and_credentials():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = provider
|
||||
mock_db.session.query.return_value = db_query
|
||||
with patch(
|
||||
"core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller
|
||||
) as mock_from_db:
|
||||
@@ -726,10 +730,12 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
privacy_policy="privacy",
|
||||
custom_disclaimer="disclaimer",
|
||||
)
|
||||
db_query = Mock()
|
||||
db_query.where.return_value.first.return_value = provider
|
||||
controller = Mock()
|
||||
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = provider
|
||||
mock_db.session.query.return_value = db_query
|
||||
with patch("core.tools.tool_manager.ApiToolProviderController.from_db", return_value=controller):
|
||||
encrypter = Mock()
|
||||
encrypter.decrypt.return_value = {"api_key_value": "secret"}
|
||||
@@ -744,7 +750,7 @@ def test_user_get_api_provider_masks_credentials_and_adds_labels():
|
||||
|
||||
def test_get_api_provider_controller_not_found_raises():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
with pytest.raises(ToolProviderNotFoundError, match="api provider missing not found"):
|
||||
ToolManager.get_api_provider_controller("tenant-1", "missing")
|
||||
|
||||
@@ -803,14 +809,14 @@ def test_generate_tool_icon_urls_for_workflow_and_api():
|
||||
workflow_provider = SimpleNamespace(icon='{"background": "#222", "content": "W"}')
|
||||
api_provider = SimpleNamespace(icon='{"background": "#333", "content": "A"}')
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.side_effect = [workflow_provider, api_provider]
|
||||
mock_db.session.query.return_value.where.return_value.first.side_effect = [workflow_provider, api_provider]
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "wf-1") == {"background": "#222", "content": "W"}
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "api-1") == {"background": "#333", "content": "A"}
|
||||
|
||||
|
||||
def test_generate_tool_icon_urls_missing_workflow_and_api_use_default():
|
||||
with patch("core.tools.tool_manager.db") as mock_db:
|
||||
mock_db.session.scalar.return_value = None
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
assert ToolManager.generate_workflow_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
assert ToolManager.generate_api_tool_icon_url("tenant-1", "missing")["background"] == "#252525"
|
||||
|
||||
|
||||
@@ -263,7 +263,7 @@ def test_single_dataset_retriever_non_economy_run_sorts_context_and_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalar.side_effect = [dataset, lookup_doc_low, lookup_doc_high]
|
||||
db_session.get.return_value = dataset
|
||||
db_session.query.return_value.filter_by.return_value.first.return_value = dataset
|
||||
|
||||
tool = SingleDatasetRetrieverTool(
|
||||
tenant_id="tenant-1",
|
||||
@@ -444,7 +444,7 @@ def test_multi_dataset_retriever_run_orders_segments_and_returns_resources():
|
||||
)
|
||||
db_session = Mock()
|
||||
db_session.scalars.return_value.all.return_value = [segment_for_node_2, segment_for_node_1]
|
||||
db_session.get.side_effect = [
|
||||
db_session.query.return_value.filter_by.return_value.first.side_effect = [
|
||||
SimpleNamespace(id="dataset-2", name="Dataset Two"),
|
||||
SimpleNamespace(id="dataset-1", name="Dataset One"),
|
||||
]
|
||||
|
||||
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
558
api/tests/unit_tests/services/test_metadata_service.py
Normal file
@@ -0,0 +1,558 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import UTC, datetime
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from core.rag.index_processor.constant.built_in_field import BuiltInField, MetadataDataSource
|
||||
from models.dataset import Dataset
|
||||
from services.entities.knowledge_entities.knowledge_entities import (
|
||||
DocumentMetadataOperation,
|
||||
MetadataArgs,
|
||||
MetadataDetail,
|
||||
MetadataOperationData,
|
||||
)
|
||||
from services.metadata_service import MetadataService
|
||||
|
||||
|
||||
@dataclass
|
||||
class _DocumentStub:
|
||||
id: str
|
||||
name: str
|
||||
uploader: str
|
||||
upload_date: datetime
|
||||
last_update_date: datetime
|
||||
data_source_type: str
|
||||
doc_metadata: dict[str, object] | None
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db(mocker: MockerFixture) -> MagicMock:
|
||||
mocked_db = mocker.patch("services.metadata_service.db")
|
||||
mocked_db.session = MagicMock()
|
||||
return mocked_db
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_redis_client(mocker: MockerFixture) -> MagicMock:
|
||||
return mocker.patch("services.metadata_service.redis_client")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_account(mocker: MockerFixture) -> MagicMock:
|
||||
mock_user = SimpleNamespace(id="user-1")
|
||||
return mocker.patch("services.metadata_service.current_account_with_tenant", return_value=(mock_user, "tenant-1"))
|
||||
|
||||
|
||||
def _build_document(document_id: str, doc_metadata: dict[str, object] | None = None) -> _DocumentStub:
|
||||
now = datetime(2025, 1, 1, 10, 30, tzinfo=UTC)
|
||||
return _DocumentStub(
|
||||
id=document_id,
|
||||
name=f"doc-{document_id}",
|
||||
uploader="qa@example.com",
|
||||
upload_date=now,
|
||||
last_update_date=now,
|
||||
data_source_type="upload_file",
|
||||
doc_metadata=doc_metadata,
|
||||
)
|
||||
|
||||
|
||||
def _dataset(**kwargs: Any) -> Dataset:
|
||||
return cast(Dataset, SimpleNamespace(**kwargs))
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="x" * 256)
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_metadata_name_already_exists(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name="priority")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_create_metadata_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="string", name=BuiltInField.document_name)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
|
||||
def test_create_metadata_should_persist_metadata_when_input_is_valid(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
metadata_args = MetadataArgs(type="number", name="score")
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act
|
||||
result = MetadataService.create_metadata("dataset-1", metadata_args)
|
||||
|
||||
# Assert
|
||||
assert result.tenant_id == "tenant-1"
|
||||
assert result.dataset_id == "dataset-1"
|
||||
assert result.type == "number"
|
||||
assert result.name == "score"
|
||||
assert result.created_by == "user-1"
|
||||
mock_db.session.add.assert_called_once_with(result)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_exceeds_limit() -> None:
|
||||
# Arrange
|
||||
too_long_name = "x" * 256
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="cannot exceed 255"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", too_long_name)
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_duplicate_name_exists(
|
||||
mock_db: MagicMock, mock_current_account: MagicMock
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", "duplicate")
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_raise_value_error_when_name_collides_with_builtin(
|
||||
mock_db: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Built-in fields"):
|
||||
MetadataService.update_metadata_name("dataset-1", "metadata-1", BuiltInField.source)
|
||||
|
||||
# Assert
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_update_bound_documents_and_return_metadata(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
fixed_now = datetime(2025, 2, 1, 0, 0, tzinfo=UTC)
|
||||
mocker.patch("services.metadata_service.naive_utc_now", return_value=fixed_now)
|
||||
|
||||
metadata = SimpleNamespace(id="metadata-1", name="old_name", updated_by=None, updated_at=None)
|
||||
bindings = [SimpleNamespace(document_id="doc-1"), SimpleNamespace(document_id="doc-2")]
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata, query_bindings]
|
||||
|
||||
doc_1 = _build_document("1", {"old_name": "value", "other": "keep"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mock_get_documents = mocker.patch("services.metadata_service.DocumentService.get_document_by_ids")
|
||||
mock_get_documents.return_value = [doc_1, doc_2]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "metadata-1", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert metadata.name == "new_name"
|
||||
assert metadata.updated_by == "user-1"
|
||||
assert metadata.updated_at == fixed_now
|
||||
assert doc_1.doc_metadata == {"other": "keep", "new_name": "value"}
|
||||
assert doc_2.doc_metadata == {"new_name": None}
|
||||
mock_get_documents.assert_called_once_with(["doc-1", "doc-2"])
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_metadata_name_should_return_none_when_metadata_does_not_exist(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
query_duplicate = MagicMock()
|
||||
query_duplicate.filter_by.return_value.first.return_value = None
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = None
|
||||
mock_db.session.query.side_effect = [query_duplicate, query_metadata]
|
||||
|
||||
# Act
|
||||
result = MetadataService.update_metadata_name("dataset-1", "missing-id", "new_name")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_delete_metadata_should_remove_metadata_and_related_document_fields(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
metadata = SimpleNamespace(id="metadata-1", name="obsolete")
|
||||
bindings = [SimpleNamespace(document_id="doc-1")]
|
||||
query_metadata = MagicMock()
|
||||
query_metadata.filter_by.return_value.first.return_value = metadata
|
||||
query_bindings = MagicMock()
|
||||
query_bindings.filter_by.return_value.all.return_value = bindings
|
||||
mock_db.session.query.side_effect = [query_metadata, query_bindings]
|
||||
|
||||
document = _build_document("1", {"obsolete": "legacy", "remaining": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document_by_ids", return_value=[document])
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "metadata-1")
|
||||
|
||||
# Assert
|
||||
assert result is metadata
|
||||
assert document.doc_metadata == {"remaining": "value"}
|
||||
mock_db.session.delete.assert_called_once_with(metadata)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_delete_metadata_should_return_none_when_metadata_is_missing(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = None
|
||||
mock_logger = mocker.patch("services.metadata_service.logger")
|
||||
|
||||
# Act
|
||||
result = MetadataService.delete_metadata("dataset-1", "missing-id")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
mock_logger.exception.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_get_built_in_fields_should_return_all_expected_fields() -> None:
|
||||
# Arrange
|
||||
expected_names = {
|
||||
BuiltInField.document_name,
|
||||
BuiltInField.uploader,
|
||||
BuiltInField.upload_date,
|
||||
BuiltInField.last_update_date,
|
||||
BuiltInField.source,
|
||||
}
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_built_in_fields()
|
||||
|
||||
# Assert
|
||||
assert {item["name"] for item in result} == expected_names
|
||||
assert [item["type"] for item in result] == ["string", "string", "time", "time", "string"]
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_return_immediately_when_already_enabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_enable_built_in_field_should_populate_documents_and_enable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
doc_1 = _build_document("1", {"custom": "value"})
|
||||
doc_2 = _build_document("2", None)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[doc_1, doc_2],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.enable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is True
|
||||
assert doc_1.doc_metadata is not None
|
||||
assert doc_1.doc_metadata[BuiltInField.document_name] == "doc-1"
|
||||
assert doc_1.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert doc_2.doc_metadata is not None
|
||||
assert doc_2.doc_metadata[BuiltInField.uploader] == "qa@example.com"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_return_immediately_when_already_disabled(
|
||||
mock_db: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
get_docs = mocker.patch("services.metadata_service.DocumentService.get_working_documents_by_dataset_id")
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
get_docs.assert_not_called()
|
||||
mock_db.session.commit.assert_not_called()
|
||||
|
||||
|
||||
def test_disable_built_in_field_should_remove_builtin_keys_and_disable_flag(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document(
|
||||
"1",
|
||||
{
|
||||
BuiltInField.document_name: "doc",
|
||||
BuiltInField.uploader: "user",
|
||||
BuiltInField.upload_date: 1.0,
|
||||
BuiltInField.last_update_date: 2.0,
|
||||
BuiltInField.source: MetadataDataSource.upload_file,
|
||||
"custom": "keep",
|
||||
},
|
||||
)
|
||||
mocker.patch(
|
||||
"services.metadata_service.DocumentService.get_working_documents_by_dataset_id",
|
||||
return_value=[document],
|
||||
)
|
||||
|
||||
# Act
|
||||
MetadataService.disable_built_in_field(dataset)
|
||||
|
||||
# Assert
|
||||
assert dataset.built_in_field_enabled is False
|
||||
assert document.doc_metadata == {"custom": "keep"}
|
||||
mock_db.session.commit.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("dataset_metadata_lock_dataset-1")
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_replace_metadata_and_create_bindings_on_full_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
document = _build_document("1", {"legacy": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
delete_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
delete_chain.delete.return_value = 1
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="priority", value="high")],
|
||||
partial_update=False,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata == {"priority": "high"}
|
||||
delete_chain.delete.assert_called_once()
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_skip_existing_binding_and_preserve_existing_fields_on_partial_update(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mock_current_account: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=True)
|
||||
document = _build_document("1", {"existing": "value"})
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=document)
|
||||
mock_db.session.query.return_value.filter_by.return_value.first.return_value = object()
|
||||
operation = DocumentMetadataOperation(
|
||||
document_id="1",
|
||||
metadata_list=[MetadataDetail(id="meta-1", name="new_key", value="new_value")],
|
||||
partial_update=True,
|
||||
)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
assert document.doc_metadata is not None
|
||||
assert document.doc_metadata["existing"] == "value"
|
||||
assert document.doc_metadata["new_key"] == "new_value"
|
||||
assert document.doc_metadata[BuiltInField.source] == MetadataDataSource.upload_file
|
||||
assert mock_db.session.commit.call_count == 1
|
||||
assert mock_db.session.add.call_count == 1
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_1")
|
||||
mock_current_account.assert_called_once()
|
||||
|
||||
|
||||
def test_update_documents_metadata_should_raise_and_rollback_when_document_not_found(
|
||||
mock_db: MagicMock,
|
||||
mock_redis_client: MagicMock,
|
||||
mocker: MockerFixture,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False)
|
||||
mocker.patch("services.metadata_service.DocumentService.get_document", return_value=None)
|
||||
operation = DocumentMetadataOperation(document_id="404", metadata_list=[], partial_update=True)
|
||||
metadata_args = MetadataOperationData(operation_data=[operation])
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="Document not found"):
|
||||
MetadataService.update_documents_metadata(dataset, metadata_args)
|
||||
|
||||
# Assert
|
||||
mock_db.session.rollback.assert_called_once()
|
||||
mock_redis_client.delete.assert_called_once_with("document_metadata_lock_404")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("dataset_id", "document_id", "expected_key"),
|
||||
[
|
||||
("dataset-1", None, "dataset_metadata_lock_dataset-1"),
|
||||
(None, "doc-1", "document_metadata_lock_doc-1"),
|
||||
],
|
||||
)
|
||||
def test_knowledge_base_metadata_lock_check_should_set_lock_when_not_already_locked(
|
||||
dataset_id: str | None,
|
||||
document_id: str | None,
|
||||
expected_key: str,
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = None
|
||||
|
||||
# Act
|
||||
MetadataService.knowledge_base_metadata_lock_check(dataset_id, document_id)
|
||||
|
||||
# Assert
|
||||
mock_redis_client.set.assert_called_once_with(expected_key, 1, ex=3600)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_dataset_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="knowledge base metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check("dataset-1", None)
|
||||
|
||||
|
||||
def test_knowledge_base_metadata_lock_check_should_raise_when_document_lock_exists(
|
||||
mock_redis_client: MagicMock,
|
||||
) -> None:
|
||||
# Arrange
|
||||
mock_redis_client.get.return_value = 1
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(ValueError, match="document metadata operation is running"):
|
||||
MetadataService.knowledge_base_metadata_lock_check(None, "doc-1")
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_exclude_builtin_and_include_binding_counts(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(
|
||||
id="dataset-1",
|
||||
built_in_field_enabled=True,
|
||||
doc_metadata=[
|
||||
{"id": "meta-1", "name": "priority", "type": "string"},
|
||||
{"id": "built-in", "name": "ignored", "type": "string"},
|
||||
{"id": "meta-2", "name": "score", "type": "number"},
|
||||
],
|
||||
)
|
||||
count_chain = mock_db.session.query.return_value.filter_by.return_value
|
||||
count_chain.count.side_effect = [3, 1]
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result["built_in_field_enabled"] is True
|
||||
assert result["doc_metadata"] == [
|
||||
{"id": "meta-1", "name": "priority", "type": "string", "count": 3},
|
||||
{"id": "meta-2", "name": "score", "type": "number", "count": 1},
|
||||
]
|
||||
|
||||
|
||||
def test_get_dataset_metadatas_should_return_empty_list_when_no_metadata(mock_db: MagicMock) -> None:
|
||||
# Arrange
|
||||
dataset = _dataset(id="dataset-1", built_in_field_enabled=False, doc_metadata=None)
|
||||
|
||||
# Act
|
||||
result = MetadataService.get_dataset_metadatas(dataset)
|
||||
|
||||
# Assert
|
||||
assert result == {"doc_metadata": [], "built_in_field_enabled": False}
|
||||
mock_db.session.query.assert_not_called()
|
||||
1336
api/tests/unit_tests/services/test_tag_service.py
Normal file
1336
api/tests/unit_tests/services/test_tag_service.py
Normal file
File diff suppressed because it is too large
Load Diff
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
576
api/tests/unit_tests/services/test_workspace_service.py
Normal file
@@ -0,0 +1,576 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from types import SimpleNamespace
|
||||
from typing import Any, cast
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from pytest_mock import MockerFixture
|
||||
|
||||
from models.account import Tenant
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Constants used throughout the tests
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
TENANT_ID = "tenant-abc"
|
||||
ACCOUNT_ID = "account-xyz"
|
||||
FILES_BASE_URL = "https://files.example.com"
|
||||
|
||||
DB_PATH = "services.workspace_service.db"
|
||||
FEATURE_SERVICE_PATH = "services.workspace_service.FeatureService.get_features"
|
||||
TENANT_SERVICE_PATH = "services.workspace_service.TenantService.has_roles"
|
||||
DIFY_CONFIG_PATH = "services.workspace_service.dify_config"
|
||||
CURRENT_USER_PATH = "services.workspace_service.current_user"
|
||||
CREDIT_POOL_SERVICE_PATH = "services.credit_pool_service.CreditPoolService.get_pool"
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers / factories
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _make_tenant(
|
||||
tenant_id: str = TENANT_ID,
|
||||
name: str = "My Workspace",
|
||||
plan: str = "sandbox",
|
||||
status: str = "active",
|
||||
custom_config: dict | None = None,
|
||||
) -> Tenant:
|
||||
"""Create a minimal Tenant-like namespace."""
|
||||
return cast(
|
||||
Tenant,
|
||||
SimpleNamespace(
|
||||
id=tenant_id,
|
||||
name=name,
|
||||
plan=plan,
|
||||
status=status,
|
||||
created_at="2024-01-01T00:00:00Z",
|
||||
custom_config_dict=custom_config or {},
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
def _make_feature(
|
||||
can_replace_logo: bool = False,
|
||||
next_credit_reset_date: str | None = None,
|
||||
billing_plan: str = "sandbox",
|
||||
) -> MagicMock:
|
||||
"""Create a feature namespace matching what FeatureService.get_features returns."""
|
||||
feature = MagicMock()
|
||||
feature.can_replace_logo = can_replace_logo
|
||||
feature.next_credit_reset_date = next_credit_reset_date
|
||||
feature.billing.subscription.plan = billing_plan
|
||||
return feature
|
||||
|
||||
|
||||
def _make_pool(quota_limit: int, quota_used: int) -> MagicMock:
|
||||
pool = MagicMock()
|
||||
pool.quota_limit = quota_limit
|
||||
pool.quota_used = quota_used
|
||||
return pool
|
||||
|
||||
|
||||
def _make_tenant_account_join(role: str = "normal") -> SimpleNamespace:
|
||||
return SimpleNamespace(role=role)
|
||||
|
||||
|
||||
def _tenant_info(result: object) -> dict[str, Any] | None:
|
||||
return cast(dict[str, Any] | None, result)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Shared fixtures
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_current_user() -> SimpleNamespace:
|
||||
"""Return a lightweight current_user stand-in."""
|
||||
return SimpleNamespace(id=ACCOUNT_ID)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""
|
||||
Patch the common external boundaries used by WorkspaceService.get_tenant_info.
|
||||
|
||||
Returns a dict of named mocks so individual tests can customise them.
|
||||
"""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(FEATURE_SERVICE_PATH, return_value=_make_feature())
|
||||
mock_has_roles = mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "SELF_HOSTED"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"has_roles": mock_has_roles,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 1. None Tenant Handling
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_none() -> None:
|
||||
"""get_tenant_info should short-circuit and return None for a falsy tenant."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = None
|
||||
|
||||
# Act
|
||||
result = WorkspaceService.get_tenant_info(cast(Tenant, tenant))
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_none_when_tenant_is_falsy() -> None:
|
||||
"""get_tenant_info treats any falsy value as absent (e.g. empty string, 0)."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange / Act / Assert
|
||||
assert WorkspaceService.get_tenant_info("") is None # type: ignore[arg-type]
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 2. Basic Tenant Info — happy path
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_return_base_fields(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""get_tenant_info should always return the six base scalar fields."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["id"] == TENANT_ID
|
||||
assert result["name"] == "My Workspace"
|
||||
assert result["plan"] == "sandbox"
|
||||
assert result["status"] == "active"
|
||||
assert result["created_at"] == "2024-01-01T00:00:00Z"
|
||||
assert result["trial_end_reason"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_populate_role_from_tenant_account_join(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The 'role' field should be taken from TenantAccountJoin, not the default."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = _make_tenant_account_join(role="admin")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["role"] == "admin"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_raise_assertion_when_tenant_account_join_missing(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The service asserts that TenantAccountJoin exists.
|
||||
Missing join should raise AssertionError.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["query_chain"].first.return_value = None
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act + Assert
|
||||
with pytest.raises(AssertionError, match="TenantAccountJoin not found"):
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 3. Logo Customisation
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_include_custom_config_when_logo_allowed_and_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block should appear for OWNER/ADMIN when can_replace_logo is True."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(
|
||||
custom_config={
|
||||
"replace_webapp_logo": True,
|
||||
"remove_webapp_brand": True,
|
||||
}
|
||||
)
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" in result
|
||||
assert result["custom_config"]["remove_webapp_brand"] is True
|
||||
expected_logo_url = f"{FILES_BASE_URL}/files/workspaces/{TENANT_ID}/webapp-logo"
|
||||
assert result["custom_config"]["replace_webapp_logo"] == expected_logo_url
|
||||
|
||||
|
||||
def test_get_tenant_info_should_set_replace_webapp_logo_to_none_when_flag_absent(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""replace_webapp_logo should be None when custom_config_dict does not have the key."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={}) # no replace_webapp_logo key
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"] is None
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_logo_not_allowed(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config should be absent when can_replace_logo is False."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=False)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_custom_config_when_user_not_admin(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""custom_config block is gated on OWNER or ADMIN role."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = False # regular member
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "custom_config" not in result
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_files_url_for_logo_url(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""The logo URL should use dify_config.FILES_URL as the base."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
custom_base = "https://cdn.mycompany.io"
|
||||
basic_mocks["config"].FILES_URL = custom_base
|
||||
basic_mocks["get_features"].return_value = _make_feature(can_replace_logo=True)
|
||||
basic_mocks["has_roles"].return_value = True
|
||||
tenant = _make_tenant(custom_config={"replace_webapp_logo": True})
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["custom_config"]["replace_webapp_logo"].startswith(custom_base)
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 4. Cloud-Edition Credit Features
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
CLOUD_BILLING_PLAN_NON_SANDBOX = "professional" # any plan that is not SANDBOX
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def cloud_mocks(mocker: MockerFixture, mock_current_user: SimpleNamespace) -> dict:
|
||||
"""Patches for CLOUD edition tests, billing plan = professional by default."""
|
||||
mocker.patch(CURRENT_USER_PATH, mock_current_user)
|
||||
|
||||
mock_db_session = mocker.patch(f"{DB_PATH}.session")
|
||||
mock_query_chain = MagicMock()
|
||||
mock_db_session.query.return_value = mock_query_chain
|
||||
mock_query_chain.where.return_value = mock_query_chain
|
||||
mock_query_chain.first.return_value = _make_tenant_account_join(role="owner")
|
||||
|
||||
mock_feature = mocker.patch(
|
||||
FEATURE_SERVICE_PATH,
|
||||
return_value=_make_feature(
|
||||
can_replace_logo=False,
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CLOUD_BILLING_PLAN_NON_SANDBOX,
|
||||
),
|
||||
)
|
||||
mocker.patch(TENANT_SERVICE_PATH, return_value=False)
|
||||
mock_config = mocker.patch(DIFY_CONFIG_PATH)
|
||||
mock_config.EDITION = "CLOUD"
|
||||
mock_config.FILES_URL = FILES_BASE_URL
|
||||
|
||||
return {
|
||||
"db_session": mock_db_session,
|
||||
"query_chain": mock_query_chain,
|
||||
"get_features": mock_feature,
|
||||
"config": mock_config,
|
||||
}
|
||||
|
||||
|
||||
def test_get_tenant_info_should_add_next_credit_reset_date_in_cloud_edition(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date should be present in CLOUD edition."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(
|
||||
CREDIT_POOL_SERVICE_PATH,
|
||||
side_effect=[None, None], # both paid and trial pools absent
|
||||
)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["next_credit_reset_date"] == "2025-02-01"
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_plan_is_not_sandbox_and_pool_not_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""trial_credits/trial_credits_used come from the paid pool when conditions are met."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=200)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, return_value=paid_pool)
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 1000
|
||||
assert result["trial_credits_used"] == 200
|
||||
|
||||
|
||||
def test_get_tenant_info_should_use_paid_pool_when_quota_limit_is_infinite(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""quota_limit == -1 means unlimited; service should still use the paid pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=-1, quota_used=999)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == -1
|
||||
assert result["trial_credits_used"] == 999
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_full(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid pool is exhausted (used >= limit), switch to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
paid_pool = _make_pool(quota_limit=500, quota_used=500) # exactly full
|
||||
trial_pool = _make_pool(quota_limit=100, quota_used=10)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 100
|
||||
assert result["trial_credits_used"] == 10
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_when_paid_pool_is_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When paid_pool is None, fall back to trial pool."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
trial_pool = _make_pool(quota_limit=50, quota_used=5)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 50
|
||||
assert result["trial_credits_used"] == 5
|
||||
|
||||
|
||||
def test_get_tenant_info_should_fall_back_to_trial_pool_for_sandbox_plan(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
When the subscription plan IS SANDBOX, the paid pool branch is skipped
|
||||
entirely and we fall back to the trial pool.
|
||||
"""
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange — override billing plan to SANDBOX
|
||||
cloud_mocks["get_features"].return_value = _make_feature(
|
||||
next_credit_reset_date="2025-02-01",
|
||||
billing_plan=CloudPlan.SANDBOX,
|
||||
)
|
||||
paid_pool = _make_pool(quota_limit=1000, quota_used=0)
|
||||
trial_pool = _make_pool(quota_limit=200, quota_used=20)
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[paid_pool, trial_pool])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert result["trial_credits"] == 200
|
||||
assert result["trial_credits_used"] == 20
|
||||
|
||||
|
||||
def test_get_tenant_info_should_omit_trial_credits_when_both_pools_are_none(
|
||||
mocker: MockerFixture,
|
||||
cloud_mocks: dict,
|
||||
) -> None:
|
||||
"""When both paid and trial pools are absent, trial_credits should not be set."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
mocker.patch(CREDIT_POOL_SERVICE_PATH, side_effect=[None, None])
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 5. Self-hosted / Non-Cloud Edition
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_not_include_cloud_fields_in_self_hosted(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""next_credit_reset_date and trial_credits should NOT appear in SELF_HOSTED mode."""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange (basic_mocks already sets EDITION = "SELF_HOSTED")
|
||||
tenant = _make_tenant()
|
||||
|
||||
# Act
|
||||
result = _tenant_info(WorkspaceService.get_tenant_info(tenant))
|
||||
|
||||
# Assert
|
||||
assert result is not None
|
||||
assert "next_credit_reset_date" not in result
|
||||
assert "trial_credits" not in result
|
||||
assert "trial_credits_used" not in result
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# 6. DB query integrity
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_get_tenant_info_should_query_tenant_account_join_with_correct_ids(
|
||||
mocker: MockerFixture,
|
||||
basic_mocks: dict,
|
||||
) -> None:
|
||||
"""
|
||||
The DB query for TenantAccountJoin must be scoped to the correct
|
||||
tenant_id and current_user.id.
|
||||
"""
|
||||
from services.workspace_service import WorkspaceService
|
||||
|
||||
# Arrange
|
||||
tenant = _make_tenant(tenant_id="my-special-tenant")
|
||||
mock_current_user = mocker.patch(CURRENT_USER_PATH)
|
||||
mock_current_user.id = "special-user-id"
|
||||
|
||||
# Act
|
||||
WorkspaceService.get_tenant_info(tenant)
|
||||
|
||||
# Assert — db.session.query was invoked (at least once)
|
||||
basic_mocks["db_session"].query.assert_called()
|
||||
@@ -488,8 +488,7 @@ ALIYUN_OSS_REGION=ap-southeast-1
|
||||
ALIYUN_OSS_AUTH_VERSION=v4
|
||||
# Don't start with '/'. OSS doesn't support leading slash in object names.
|
||||
ALIYUN_OSS_PATH=your-path
|
||||
# Optional CloudBox ID for Aliyun OSS, DO NOT enable it if you are not using CloudBox.
|
||||
#ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
ALIYUN_CLOUDBOX_ID=your-cloudbox-id
|
||||
|
||||
# Tencent COS Configuration
|
||||
#
|
||||
|
||||
@@ -275,7 +275,6 @@ services:
|
||||
# Use the shared environment variables.
|
||||
<<: *shared-api-worker-env
|
||||
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
|
||||
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
|
||||
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
|
||||
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
|
||||
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
|
||||
@@ -146,6 +146,7 @@ x-shared-env: &shared-api-worker-env
|
||||
ALIYUN_OSS_REGION: ${ALIYUN_OSS_REGION:-ap-southeast-1}
|
||||
ALIYUN_OSS_AUTH_VERSION: ${ALIYUN_OSS_AUTH_VERSION:-v4}
|
||||
ALIYUN_OSS_PATH: ${ALIYUN_OSS_PATH:-your-path}
|
||||
ALIYUN_CLOUDBOX_ID: ${ALIYUN_CLOUDBOX_ID:-your-cloudbox-id}
|
||||
TENCENT_COS_BUCKET_NAME: ${TENCENT_COS_BUCKET_NAME:-your-bucket-name}
|
||||
TENCENT_COS_SECRET_KEY: ${TENCENT_COS_SECRET_KEY:-your-secret-key}
|
||||
TENCENT_COS_SECRET_ID: ${TENCENT_COS_SECRET_ID:-your-secret-id}
|
||||
@@ -984,7 +985,6 @@ services:
|
||||
# Use the shared environment variables.
|
||||
<<: *shared-api-worker-env
|
||||
DB_DATABASE: ${DB_PLUGIN_DATABASE:-dify_plugin}
|
||||
DB_SSL_MODE: ${DB_SSL_MODE:-disable}
|
||||
SERVER_PORT: ${PLUGIN_DAEMON_PORT:-5002}
|
||||
SERVER_KEY: ${PLUGIN_DAEMON_KEY:-lYkiYYT6owG+71oLerGzA7GXCgOT++6ovaezWAjpCjf+Sjc3ZtU+qUEi}
|
||||
MAX_PLUGIN_PACKAGE_SIZE: ${PLUGIN_MAX_PACKAGE_SIZE:-52428800}
|
||||
|
||||
Reference in New Issue
Block a user