test: migrate dataset service update-dataset SQL tests to testcontainers (#32533)

Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
木之本澪
2026-02-27 06:10:15 +08:00
committed by GitHub
parent b48f36a4e5
commit 5cb1b53b47
2 changed files with 529 additions and 661 deletions

View File

@@ -0,0 +1,529 @@
from unittest.mock import Mock, patch
from uuid import uuid4
import pytest
from core.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetUpdateTestDataFactory:
"""Factory class for creating real test data for dataset update integration tests."""
@staticmethod
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
"""Create a real account and tenant with the given role."""
account = Account(
email=f"{uuid4()}@example.com",
name=f"user-{uuid4()}",
interface_language="en-US",
status="active",
)
db.session.add(account)
db.session.commit()
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
db.session.add(tenant)
db.session.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=role,
current=True,
)
db.session.add(join)
db.session.commit()
account.current_tenant = tenant
return account, tenant
@staticmethod
def create_dataset(
tenant_id: str,
created_by: str,
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
retrieval_model: str = "old_model",
permission: str = "only_me",
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
) -> Dataset:
"""Create a real dataset."""
dataset = Dataset(
tenant_id=tenant_id,
name=name,
description=description,
data_source_type="upload_file",
indexing_technique=indexing_technique,
created_by=created_by,
provider=provider,
retrieval_model=retrieval_model,
permission=permission,
embedding_model_provider=embedding_model_provider,
embedding_model=embedding_model,
collection_binding_id=collection_binding_id,
)
db.session.add(dataset)
db.session.commit()
return dataset
@staticmethod
def create_external_binding(
tenant_id: str,
dataset_id: str,
created_by: str,
external_knowledge_id: str = "old_knowledge_id",
external_knowledge_api_id: str | None = None,
) -> ExternalKnowledgeBindings:
"""Create a real external knowledge binding."""
if external_knowledge_api_id is None:
external_knowledge_api_id = str(uuid4())
binding = ExternalKnowledgeBindings(
tenant_id=tenant_id,
dataset_id=dataset_id,
created_by=created_by,
external_knowledge_id=external_knowledge_id,
external_knowledge_api_id=external_knowledge_api_id,
)
db.session.add(binding)
db.session.commit()
return binding
class TestDatasetServiceUpdateDataset:
"""
Comprehensive integration tests for DatasetService.update_dataset method.
This test suite covers all supported scenarios including:
- External dataset updates
- Internal dataset updates with different indexing techniques
- Embedding model updates
- Permission checks
- Error conditions and edge cases
"""
# ==================== External Dataset Tests ====================
def test_update_external_dataset_success(self, db_session_with_containers):
"""Test successful update of external dataset."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="external",
name="old_name",
description="old_description",
retrieval_model="old_model",
)
binding = DatasetUpdateTestDataFactory.create_external_binding(
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
)
binding_id = binding.id
db.session.expunge(binding)
update_data = {
"name": "new_name",
"description": "new_description",
"external_retrieval_model": "new_model",
"permission": "only_me",
"external_knowledge_id": "new_knowledge_id",
"external_knowledge_api_id": str(uuid4()),
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.retrieval_model == "new_model"
assert updated_binding is not None
assert updated_binding.external_knowledge_id == "new_knowledge_id"
assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
assert result.id == dataset.id
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
"""Test error when external knowledge id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="external",
)
DatasetUpdateTestDataFactory.create_external_binding(
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
)
update_data = {"name": "new_name", "external_knowledge_api_id": str(uuid4())}
with pytest.raises(ValueError) as context:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge id is required" in str(context.value)
db.session.rollback()
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers):
"""Test error when external knowledge api id is missing."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="external",
)
DatasetUpdateTestDataFactory.create_external_binding(
tenant_id=tenant.id,
dataset_id=dataset.id,
created_by=user.id,
)
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
with pytest.raises(ValueError) as context:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge api id is required" in str(context.value)
db.session.rollback()
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers):
"""Test error when external knowledge binding is not found."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="external",
)
update_data = {
"name": "new_name",
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": str(uuid4()),
}
with pytest.raises(ValueError) as context:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "External knowledge binding not found" in str(context.value)
db.session.rollback()
# ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, db_session_with_containers):
"""Test successful update of internal dataset with basic fields."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.description == "new_description"
assert dataset.indexing_technique == "high_quality"
assert dataset.retrieval_model == "new_model"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert result.id == dataset.id
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers):
"""Test that None values are filtered out except for description field."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"name": "new_name",
"description": None,
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": None,
"embedding_model": None,
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.description is None
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
# ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers):
"""Test updating internal dataset indexing technique to economy."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"indexing_technique": "economy",
"retrieval_model": "new_model",
}
with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task:
result = DatasetService.update_dataset(dataset.id, update_data, user)
mock_task.delay.assert_called_once_with(dataset.id, "remove")
db.session.refresh(dataset)
assert dataset.indexing_technique == "economy"
assert dataset.embedding_model is None
assert dataset.embedding_model_provider is None
assert dataset.collection_binding_id is None
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers):
"""Test updating internal dataset indexing technique to high_quality."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
)
embedding_model = Mock()
embedding_model.model = "text-embedding-ada-002"
embedding_model.provider = "openai"
binding = Mock()
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
with (
patch("services.dataset_service.current_user", user),
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
):
mock_model_manager.return_value.get_model_instance.return_value = embedding_model
mock_get_binding.return_value = binding
result = DatasetService.update_dataset(dataset.id, update_data, user)
mock_model_manager.return_value.get_model_instance.assert_called_once_with(
tenant_id=tenant.id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-ada-002",
)
mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002")
mock_task.delay.assert_called_once_with(dataset.id, "add")
db.session.refresh(dataset)
assert dataset.indexing_technique == "high_quality"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
# ==================== Embedding Model Update Tests ====================
def test_update_internal_dataset_keep_existing_embedding_model_when_indexing_technique_unchanged(
self, db_session_with_containers
):
"""Test preserving embedding settings when indexing technique remains unchanged."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
update_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
}
result = DatasetService.update_dataset(dataset.id, update_data, user)
db.session.refresh(dataset)
assert dataset.name == "new_name"
assert dataset.indexing_technique == "high_quality"
assert dataset.embedding_model_provider == "openai"
assert dataset.embedding_model == "text-embedding-ada-002"
assert dataset.collection_binding_id == existing_binding_id
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers):
"""Test updating internal dataset with new embedding model."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
existing_binding_id = str(uuid4())
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id=existing_binding_id,
)
embedding_model = Mock()
embedding_model.model = "text-embedding-3-small"
embedding_model.provider = "openai"
binding = Mock()
binding.id = str(uuid4())
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
with (
patch("services.dataset_service.current_user", user),
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
):
mock_model_manager.return_value.get_model_instance.return_value = embedding_model
mock_get_binding.return_value = binding
result = DatasetService.update_dataset(dataset.id, update_data, user)
mock_model_manager.return_value.get_model_instance.assert_called_once_with(
tenant_id=tenant.id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-3-small",
)
mock_get_binding.assert_called_once_with("openai", "text-embedding-3-small")
mock_task.delay.assert_called_once_with(dataset.id, "update")
mock_regenerate_task.delay.assert_called_once_with(
dataset.id,
regenerate_reason="embedding_model_changed",
regenerate_vectors_only=True,
)
db.session.refresh(dataset)
assert dataset.embedding_model == "text-embedding-3-small"
assert dataset.embedding_model_provider == "openai"
assert dataset.collection_binding_id == binding.id
assert dataset.retrieval_model == "new_model"
assert result.id == dataset.id
# ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, db_session_with_containers):
"""Test error when dataset is not found."""
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant()
update_data = {"name": "new_name"}
with pytest.raises(ValueError) as context:
DatasetService.update_dataset(str(uuid4()), update_data, user)
assert "Dataset not found" in str(context.value)
def test_update_dataset_permission_error(self, db_session_with_containers):
"""Test error when user doesn't have permission."""
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=owner.id,
provider="vendor",
permission="only_me",
)
update_data = {"name": "new_name"}
with pytest.raises(NoPermissionError):
DatasetService.update_dataset(dataset.id, update_data, outsider)
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers):
"""Test error when embedding model is not available."""
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
dataset = DatasetUpdateTestDataFactory.create_dataset(
tenant_id=tenant.id,
created_by=user.id,
provider="vendor",
indexing_technique="economy",
)
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",
}
with (
patch("services.dataset_service.current_user", user),
patch("services.dataset_service.ModelManager") as mock_model_manager,
):
mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available")
with pytest.raises(Exception) as context:
DatasetService.update_dataset(dataset.id, update_data, user)
assert "No Embedding Model available".lower() in str(context.value).lower()

View File

@@ -1,661 +0,0 @@
import datetime
from typing import Any
# Mock redis_client before importing dataset_service
from unittest.mock import Mock, create_autospec, patch
import pytest
from core.model_runtime.entities.model_entities import ModelType
from models.account import Account
from models.dataset import Dataset, ExternalKnowledgeBindings
from services.dataset_service import DatasetService
from services.errors.account import NoPermissionError
class DatasetUpdateTestDataFactory:
"""Factory class for creating test data and mock objects for dataset update tests."""
@staticmethod
def create_dataset_mock(
dataset_id: str = "dataset-123",
provider: str = "vendor",
name: str = "old_name",
description: str = "old_description",
indexing_technique: str = "high_quality",
retrieval_model: str = "old_model",
embedding_model_provider: str | None = None,
embedding_model: str | None = None,
collection_binding_id: str | None = None,
**kwargs,
) -> Mock:
"""Create a mock dataset with specified attributes."""
dataset = Mock(spec=Dataset)
dataset.id = dataset_id
dataset.provider = provider
dataset.name = name
dataset.description = description
dataset.indexing_technique = indexing_technique
dataset.retrieval_model = retrieval_model
dataset.embedding_model_provider = embedding_model_provider
dataset.embedding_model = embedding_model
dataset.collection_binding_id = collection_binding_id
for key, value in kwargs.items():
setattr(dataset, key, value)
return dataset
@staticmethod
def create_user_mock(user_id: str = "user-789") -> Mock:
"""Create a mock user."""
user = Mock()
user.id = user_id
return user
@staticmethod
def create_external_binding_mock(
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
) -> Mock:
"""Create a mock external knowledge binding."""
binding = Mock(spec=ExternalKnowledgeBindings)
binding.external_knowledge_id = external_knowledge_id
binding.external_knowledge_api_id = external_knowledge_api_id
return binding
@staticmethod
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
"""Create a mock embedding model."""
embedding_model = Mock()
embedding_model.model = model
embedding_model.provider = provider
return embedding_model
@staticmethod
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
"""Create a mock collection binding."""
binding = Mock()
binding.id = binding_id
return binding
@staticmethod
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
"""Create a mock current user."""
current_user = create_autospec(Account, instance=True)
current_user.current_tenant_id = tenant_id
return current_user
class TestDatasetServiceUpdateDataset:
"""
Comprehensive unit tests for DatasetService.update_dataset method.
This test suite covers all supported scenarios including:
- External dataset updates
- Internal dataset updates with different indexing techniques
- Embedding model updates
- Permission checks
- Error conditions and edge cases
"""
@pytest.fixture
def mock_dataset_service_dependencies(self):
"""Common mock setup for dataset service dependencies."""
with (
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
patch("extensions.ext_database.db.session") as mock_db,
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
):
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
mock_naive_utc_now.return_value = current_time
yield {
"get_dataset": mock_get_dataset,
"check_permission": mock_check_perm,
"db_session": mock_db,
"naive_utc_now": mock_naive_utc_now,
"current_time": current_time,
"has_dataset_same_name": has_dataset_same_name,
}
@pytest.fixture
def mock_external_provider_dependencies(self):
"""Mock setup for external provider tests."""
with patch("services.dataset_service.Session") as mock_session:
from extensions.ext_database import db
with patch.object(db.__class__, "engine", new_callable=Mock):
session_mock = Mock()
mock_session.return_value.__enter__.return_value = session_mock
yield session_mock
@pytest.fixture
def mock_internal_provider_dependencies(self):
"""Mock setup for internal provider tests."""
with (
patch("services.dataset_service.ModelManager") as mock_model_manager,
patch(
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
) as mock_get_binding,
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
patch(
"services.dataset_service.current_user", create_autospec(Account, instance=True)
) as mock_current_user,
):
mock_current_user.current_tenant_id = "tenant-123"
yield {
"model_manager": mock_model_manager,
"get_binding": mock_get_binding,
"task": mock_task,
"regenerate_task": mock_regenerate_task,
"current_user": mock_current_user,
}
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
"""Helper method to verify database update calls."""
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates)
mock_db.commit.assert_called_once()
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
"""Helper method to verify external dataset updates."""
assert mock_dataset.name == update_data.get("name", mock_dataset.name)
assert mock_dataset.description == update_data.get("description", mock_dataset.description)
assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
if "external_knowledge_id" in update_data:
assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
if "external_knowledge_api_id" in update_data:
assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
# ==================== External Dataset Tests ====================
def test_update_external_dataset_success(
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
"""Test successful update of external dataset."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="external", name="old_name", description="old_description", retrieval_model="old_model"
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
# Mock external knowledge binding query
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding
update_data = {
"name": "new_name",
"description": "new_description",
"external_retrieval_model": "new_model",
"permission": "only_me",
"external_knowledge_id": "new_knowledge_id",
"external_knowledge_api_id": "new_api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify dataset and binding updates
self._assert_external_dataset_update(dataset, binding, update_data)
# Verify database operations
mock_db = mock_dataset_service_dependencies["db_session"]
mock_db.add.assert_any_call(dataset)
mock_db.add.assert_any_call(binding)
mock_db.commit.assert_called_once()
# Verify return value
assert result == dataset
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge id is missing."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge id is required" in str(context.value)
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
"""Test error when external knowledge api id is missing."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge api id is required" in str(context.value)
def test_update_external_dataset_binding_not_found_error(
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
):
"""Test error when external knowledge binding is not found."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock external knowledge binding query returning None
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None
update_data = {
"name": "new_name",
"external_knowledge_id": "knowledge_id",
"external_knowledge_api_id": "api_id",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "External knowledge binding not found" in str(context.value)
# ==================== Internal Dataset Basic Tests ====================
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
"""Test successful update of internal dataset with basic fields."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify permission check was called
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
# Verify database update was called with correct filtered data
expected_filtered_data = {
"name": "new_name",
"description": "new_description",
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
"""Test that None values are filtered out except for description field."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"description": None, # Should be included
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"embedding_model_provider": None, # Should be filtered out
"embedding_model": None, # Should be filtered out
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with filtered data
expected_filtered_data = {
"name": "new_name",
"description": None, # Description should be included even if None
"indexing_technique": "high_quality",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
actual_call_args = mock_dataset_service_dependencies[
"db_session"
].query.return_value.filter_by.return_value.update.call_args[0][0]
# Remove timestamp for comparison as it's dynamic
del actual_call_args["updated_at"]
del expected_filtered_data["updated_at"]
assert actual_call_args == expected_filtered_data
# Verify return value
assert result == dataset
# ==================== Indexing Technique Switch Tests ====================
def test_update_internal_dataset_indexing_technique_to_economy(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset indexing technique to economy."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with embedding model fields cleared
expected_filtered_data = {
"indexing_technique": "economy",
"embedding_model": None,
"embedding_model_provider": None,
"collection_binding_id": None,
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_indexing_technique_to_high_quality(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset indexing technique to high_quality."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-ada-002",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-ada-002",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-456",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
# Verify return value
assert result == dataset
# ==================== Embedding Model Update Tests ====================
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
"""Test updating internal dataset without changing embedding model."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with existing embedding model preserved
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_embedding_model_update(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test updating internal dataset with new embedding model."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock embedding model
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
mock_internal_provider_dependencies[
"model_manager"
].return_value.get_model_instance.return_value = embedding_model
# Mock collection binding
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
mock_internal_provider_dependencies["get_binding"].return_value = binding
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-3-small",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify embedding model was validated
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
provider="openai",
model_type=ModelType.TEXT_EMBEDDING,
model="text-embedding-3-small",
)
# Verify collection binding was retrieved
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
# Verify database update was called with correct data
expected_filtered_data = {
"indexing_technique": "high_quality",
"embedding_model": "text-embedding-3-small",
"embedding_model_provider": "openai",
"collection_binding_id": "binding-789",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify vector index task was triggered
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
# Verify regenerate summary index task was triggered (when embedding_model changes)
mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with(
"dataset-123",
regenerate_reason="embedding_model_changed",
regenerate_vectors_only=True,
)
# Verify return value
assert result == dataset
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
"""Test updating internal dataset without changing indexing technique."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
provider="vendor",
indexing_technique="high_quality",
embedding_model_provider="openai",
embedding_model="text-embedding-ada-002",
collection_binding_id="binding-123",
)
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {
"name": "new_name",
"indexing_technique": "high_quality", # Same as current
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
result = DatasetService.update_dataset("dataset-123", update_data, user)
# Verify database update was called with correct data
expected_filtered_data = {
"name": "new_name",
"indexing_technique": "high_quality",
"embedding_model_provider": "openai",
"embedding_model": "text-embedding-ada-002",
"collection_binding_id": "binding-123",
"retrieval_model": "new_model",
"updated_by": user.id,
"updated_at": mock_dataset_service_dependencies["current_time"],
}
self._assert_database_update_called(
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
)
# Verify return value
assert result == dataset
# ==================== Error Handling Tests ====================
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
"""Test error when dataset is not found."""
mock_dataset_service_dependencies["get_dataset"].return_value = None
user = DatasetUpdateTestDataFactory.create_user_mock()
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(ValueError) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "Dataset not found" in str(context.value)
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
"""Test error when user doesn't have permission."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock()
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
update_data = {"name": "new_name"}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(NoPermissionError):
DatasetService.update_dataset("dataset-123", update_data, user)
def test_update_internal_dataset_embedding_model_error(
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
):
"""Test error when embedding model is not available."""
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
user = DatasetUpdateTestDataFactory.create_user_mock()
# Mock model manager to raise error
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
"No Embedding Model available"
)
update_data = {
"indexing_technique": "high_quality",
"embedding_model_provider": "invalid_provider",
"embedding_model": "invalid_model",
"retrieval_model": "new_model",
}
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
with pytest.raises(Exception) as context:
DatasetService.update_dataset("dataset-123", update_data, user)
assert "No Embedding Model available".lower() in str(context.value).lower()