mirror of
https://github.com/langgenius/dify.git
synced 2026-02-27 03:45:09 +00:00
test: migrate dataset service update-dataset SQL tests to testcontainers (#32533)
Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,529 @@
|
||||
from unittest.mock import Mock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetUpdateTestDataFactory:
|
||||
"""Factory class for creating real test data for dataset update integration tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]:
|
||||
"""Create a real account and tenant with the given role."""
|
||||
account = Account(
|
||||
email=f"{uuid4()}@example.com",
|
||||
name=f"user-{uuid4()}",
|
||||
interface_language="en-US",
|
||||
status="active",
|
||||
)
|
||||
db.session.add(account)
|
||||
db.session.commit()
|
||||
|
||||
tenant = Tenant(name=f"tenant-{account.id}", status="normal")
|
||||
db.session.add(tenant)
|
||||
db.session.commit()
|
||||
|
||||
join = TenantAccountJoin(
|
||||
tenant_id=tenant.id,
|
||||
account_id=account.id,
|
||||
role=role,
|
||||
current=True,
|
||||
)
|
||||
db.session.add(join)
|
||||
db.session.commit()
|
||||
|
||||
account.current_tenant = tenant
|
||||
return account, tenant
|
||||
|
||||
@staticmethod
|
||||
def create_dataset(
|
||||
tenant_id: str,
|
||||
created_by: str,
|
||||
provider: str = "vendor",
|
||||
name: str = "old_name",
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
permission: str = "only_me",
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
) -> Dataset:
|
||||
"""Create a real dataset."""
|
||||
dataset = Dataset(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description,
|
||||
data_source_type="upload_file",
|
||||
indexing_technique=indexing_technique,
|
||||
created_by=created_by,
|
||||
provider=provider,
|
||||
retrieval_model=retrieval_model,
|
||||
permission=permission,
|
||||
embedding_model_provider=embedding_model_provider,
|
||||
embedding_model=embedding_model,
|
||||
collection_binding_id=collection_binding_id,
|
||||
)
|
||||
db.session.add(dataset)
|
||||
db.session.commit()
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding(
|
||||
tenant_id: str,
|
||||
dataset_id: str,
|
||||
created_by: str,
|
||||
external_knowledge_id: str = "old_knowledge_id",
|
||||
external_knowledge_api_id: str | None = None,
|
||||
) -> ExternalKnowledgeBindings:
|
||||
"""Create a real external knowledge binding."""
|
||||
if external_knowledge_api_id is None:
|
||||
external_knowledge_api_id = str(uuid4())
|
||||
binding = ExternalKnowledgeBindings(
|
||||
tenant_id=tenant_id,
|
||||
dataset_id=dataset_id,
|
||||
created_by=created_by,
|
||||
external_knowledge_id=external_knowledge_id,
|
||||
external_knowledge_api_id=external_knowledge_api_id,
|
||||
)
|
||||
db.session.add(binding)
|
||||
db.session.commit()
|
||||
return binding
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDataset:
|
||||
"""
|
||||
Comprehensive integration tests for DatasetService.update_dataset method.
|
||||
|
||||
This test suite covers all supported scenarios including:
|
||||
- External dataset updates
|
||||
- Internal dataset updates with different indexing techniques
|
||||
- Embedding model updates
|
||||
- Permission checks
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
# ==================== External Dataset Tests ====================
|
||||
|
||||
def test_update_external_dataset_success(self, db_session_with_containers):
|
||||
"""Test successful update of external dataset."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
name="old_name",
|
||||
description="old_description",
|
||||
retrieval_model="old_model",
|
||||
)
|
||||
binding = DatasetUpdateTestDataFactory.create_external_binding(
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
binding_id = binding.id
|
||||
db.session.expunge(binding)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"external_retrieval_model": "new_model",
|
||||
"permission": "only_me",
|
||||
"external_knowledge_id": "new_knowledge_id",
|
||||
"external_knowledge_api_id": str(uuid4()),
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
db.session.refresh(dataset)
|
||||
updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first()
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description == "new_description"
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert updated_binding is not None
|
||||
assert updated_binding.external_knowledge_id == "new_knowledge_id"
|
||||
assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers):
|
||||
"""Test error when external knowledge id is missing."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
DatasetUpdateTestDataFactory.create_external_binding(
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": str(uuid4())}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge id is required" in str(context.value)
|
||||
db.session.rollback()
|
||||
|
||||
def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers):
|
||||
"""Test error when external knowledge api id is missing."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
DatasetUpdateTestDataFactory.create_external_binding(
|
||||
tenant_id=tenant.id,
|
||||
dataset_id=dataset.id,
|
||||
created_by=user.id,
|
||||
)
|
||||
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge api id is required" in str(context.value)
|
||||
db.session.rollback()
|
||||
|
||||
def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="external",
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": str(uuid4()),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "External knowledge binding not found" in str(context.value)
|
||||
db.session.rollback()
|
||||
|
||||
# ==================== Internal Dataset Basic Tests ====================
|
||||
|
||||
def test_update_internal_dataset_basic_success(self, db_session_with_containers):
|
||||
"""Test successful update of internal dataset with basic fields."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=existing_binding_id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description == "new_description"
|
||||
assert dataset.indexing_technique == "high_quality"
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_filter_none_values(self, db_session_with_containers):
|
||||
"""Test that None values are filtered out except for description field."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=existing_binding_id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": None,
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": None,
|
||||
"embedding_model": None,
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.description is None
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
# ==================== Indexing Technique Switch Tests ====================
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers):
|
||||
"""Test updating internal dataset indexing technique to economy."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=existing_binding_id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "economy",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task:
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
mock_task.delay.assert_called_once_with(dataset.id, "remove")
|
||||
|
||||
db.session.refresh(dataset)
|
||||
assert dataset.indexing_technique == "economy"
|
||||
assert dataset.embedding_model is None
|
||||
assert dataset.embedding_model_provider is None
|
||||
assert dataset.collection_binding_id is None
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers):
|
||||
"""Test updating internal dataset indexing technique to high_quality."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="economy",
|
||||
)
|
||||
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = "text-embedding-ada-002"
|
||||
embedding_model.provider = "openai"
|
||||
|
||||
binding = Mock()
|
||||
binding.id = str(uuid4())
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", user),
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch(
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
):
|
||||
mock_model_manager.return_value.get_model_instance.return_value = embedding_model
|
||||
mock_get_binding.return_value = binding
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
mock_model_manager.return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002")
|
||||
mock_task.delay.assert_called_once_with(dataset.id, "add")
|
||||
|
||||
db.session.refresh(dataset)
|
||||
assert dataset.indexing_technique == "high_quality"
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.collection_binding_id == binding.id
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
# ==================== Embedding Model Update Tests ====================
|
||||
|
||||
def test_update_internal_dataset_keep_existing_embedding_model_when_indexing_technique_unchanged(
|
||||
self, db_session_with_containers
|
||||
):
|
||||
"""Test preserving embedding settings when indexing technique remains unchanged."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=existing_binding_id,
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
db.session.refresh(dataset)
|
||||
|
||||
assert dataset.name == "new_name"
|
||||
assert dataset.indexing_technique == "high_quality"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.embedding_model == "text-embedding-ada-002"
|
||||
assert dataset.collection_binding_id == existing_binding_id
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers):
|
||||
"""Test updating internal dataset with new embedding model."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
existing_binding_id = str(uuid4())
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id=existing_binding_id,
|
||||
)
|
||||
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = "text-embedding-3-small"
|
||||
embedding_model.provider = "openai"
|
||||
|
||||
binding = Mock()
|
||||
binding.id = str(uuid4())
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", user),
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch(
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
|
||||
):
|
||||
mock_model_manager.return_value.get_model_instance.return_value = embedding_model
|
||||
mock_get_binding.return_value = binding
|
||||
|
||||
result = DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
mock_model_manager.return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=tenant.id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
mock_get_binding.assert_called_once_with("openai", "text-embedding-3-small")
|
||||
mock_task.delay.assert_called_once_with(dataset.id, "update")
|
||||
mock_regenerate_task.delay.assert_called_once_with(
|
||||
dataset.id,
|
||||
regenerate_reason="embedding_model_changed",
|
||||
regenerate_vectors_only=True,
|
||||
)
|
||||
|
||||
db.session.refresh(dataset)
|
||||
assert dataset.embedding_model == "text-embedding-3-small"
|
||||
assert dataset.embedding_model_provider == "openai"
|
||||
assert dataset.collection_binding_id == binding.id
|
||||
assert dataset.retrieval_model == "new_model"
|
||||
assert result.id == dataset.id
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_update_dataset_not_found_error(self, db_session_with_containers):
|
||||
"""Test error when dataset is not found."""
|
||||
user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset(str(uuid4()), update_data, user)
|
||||
|
||||
assert "Dataset not found" in str(context.value)
|
||||
|
||||
def test_update_dataset_permission_error(self, db_session_with_containers):
|
||||
"""Test error when user doesn't have permission."""
|
||||
owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER)
|
||||
outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL)
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=owner.id,
|
||||
provider="vendor",
|
||||
permission="only_me",
|
||||
)
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset(dataset.id, update_data, outsider)
|
||||
|
||||
def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers):
|
||||
"""Test error when embedding model is not available."""
|
||||
user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant()
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset(
|
||||
tenant_id=tenant.id,
|
||||
created_by=user.id,
|
||||
provider="vendor",
|
||||
indexing_technique="economy",
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "invalid_provider",
|
||||
"embedding_model": "invalid_model",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
with (
|
||||
patch("services.dataset_service.current_user", user),
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
):
|
||||
mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available")
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset(dataset.id, update_data, user)
|
||||
|
||||
assert "No Embedding Model available".lower() in str(context.value).lower()
|
||||
@@ -1,661 +0,0 @@
|
||||
import datetime
|
||||
from typing import Any
|
||||
|
||||
# Mock redis_client before importing dataset_service
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from models.account import Account
|
||||
from models.dataset import Dataset, ExternalKnowledgeBindings
|
||||
from services.dataset_service import DatasetService
|
||||
from services.errors.account import NoPermissionError
|
||||
|
||||
|
||||
class DatasetUpdateTestDataFactory:
|
||||
"""Factory class for creating test data and mock objects for dataset update tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_dataset_mock(
|
||||
dataset_id: str = "dataset-123",
|
||||
provider: str = "vendor",
|
||||
name: str = "old_name",
|
||||
description: str = "old_description",
|
||||
indexing_technique: str = "high_quality",
|
||||
retrieval_model: str = "old_model",
|
||||
embedding_model_provider: str | None = None,
|
||||
embedding_model: str | None = None,
|
||||
collection_binding_id: str | None = None,
|
||||
**kwargs,
|
||||
) -> Mock:
|
||||
"""Create a mock dataset with specified attributes."""
|
||||
dataset = Mock(spec=Dataset)
|
||||
dataset.id = dataset_id
|
||||
dataset.provider = provider
|
||||
dataset.name = name
|
||||
dataset.description = description
|
||||
dataset.indexing_technique = indexing_technique
|
||||
dataset.retrieval_model = retrieval_model
|
||||
dataset.embedding_model_provider = embedding_model_provider
|
||||
dataset.embedding_model = embedding_model
|
||||
dataset.collection_binding_id = collection_binding_id
|
||||
for key, value in kwargs.items():
|
||||
setattr(dataset, key, value)
|
||||
return dataset
|
||||
|
||||
@staticmethod
|
||||
def create_user_mock(user_id: str = "user-789") -> Mock:
|
||||
"""Create a mock user."""
|
||||
user = Mock()
|
||||
user.id = user_id
|
||||
return user
|
||||
|
||||
@staticmethod
|
||||
def create_external_binding_mock(
|
||||
external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id"
|
||||
) -> Mock:
|
||||
"""Create a mock external knowledge binding."""
|
||||
binding = Mock(spec=ExternalKnowledgeBindings)
|
||||
binding.external_knowledge_id = external_knowledge_id
|
||||
binding.external_knowledge_api_id = external_knowledge_api_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock:
|
||||
"""Create a mock embedding model."""
|
||||
embedding_model = Mock()
|
||||
embedding_model.model = model
|
||||
embedding_model.provider = provider
|
||||
return embedding_model
|
||||
|
||||
@staticmethod
|
||||
def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock:
|
||||
"""Create a mock collection binding."""
|
||||
binding = Mock()
|
||||
binding.id = binding_id
|
||||
return binding
|
||||
|
||||
@staticmethod
|
||||
def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock:
|
||||
"""Create a mock current user."""
|
||||
current_user = create_autospec(Account, instance=True)
|
||||
current_user.current_tenant_id = tenant_id
|
||||
return current_user
|
||||
|
||||
|
||||
class TestDatasetServiceUpdateDataset:
|
||||
"""
|
||||
Comprehensive unit tests for DatasetService.update_dataset method.
|
||||
|
||||
This test suite covers all supported scenarios including:
|
||||
- External dataset updates
|
||||
- Internal dataset updates with different indexing techniques
|
||||
- Embedding model updates
|
||||
- Permission checks
|
||||
- Error conditions and edge cases
|
||||
"""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_dataset_service_dependencies(self):
|
||||
"""Common mock setup for dataset service dependencies."""
|
||||
with (
|
||||
patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset,
|
||||
patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm,
|
||||
patch("extensions.ext_database.db.session") as mock_db,
|
||||
patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now,
|
||||
patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name,
|
||||
):
|
||||
current_time = datetime.datetime(2023, 1, 1, 12, 0, 0)
|
||||
mock_naive_utc_now.return_value = current_time
|
||||
|
||||
yield {
|
||||
"get_dataset": mock_get_dataset,
|
||||
"check_permission": mock_check_perm,
|
||||
"db_session": mock_db,
|
||||
"naive_utc_now": mock_naive_utc_now,
|
||||
"current_time": current_time,
|
||||
"has_dataset_same_name": has_dataset_same_name,
|
||||
}
|
||||
|
||||
@pytest.fixture
|
||||
def mock_external_provider_dependencies(self):
|
||||
"""Mock setup for external provider tests."""
|
||||
with patch("services.dataset_service.Session") as mock_session:
|
||||
from extensions.ext_database import db
|
||||
|
||||
with patch.object(db.__class__, "engine", new_callable=Mock):
|
||||
session_mock = Mock()
|
||||
mock_session.return_value.__enter__.return_value = session_mock
|
||||
yield session_mock
|
||||
|
||||
@pytest.fixture
|
||||
def mock_internal_provider_dependencies(self):
|
||||
"""Mock setup for internal provider tests."""
|
||||
with (
|
||||
patch("services.dataset_service.ModelManager") as mock_model_manager,
|
||||
patch(
|
||||
"services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding"
|
||||
) as mock_get_binding,
|
||||
patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task,
|
||||
patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task,
|
||||
patch(
|
||||
"services.dataset_service.current_user", create_autospec(Account, instance=True)
|
||||
) as mock_current_user,
|
||||
):
|
||||
mock_current_user.current_tenant_id = "tenant-123"
|
||||
yield {
|
||||
"model_manager": mock_model_manager,
|
||||
"get_binding": mock_get_binding,
|
||||
"task": mock_task,
|
||||
"regenerate_task": mock_regenerate_task,
|
||||
"current_user": mock_current_user,
|
||||
}
|
||||
|
||||
def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]):
|
||||
"""Helper method to verify database update calls."""
|
||||
mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]):
|
||||
"""Helper method to verify external dataset updates."""
|
||||
assert mock_dataset.name == update_data.get("name", mock_dataset.name)
|
||||
assert mock_dataset.description == update_data.get("description", mock_dataset.description)
|
||||
assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model)
|
||||
|
||||
if "external_knowledge_id" in update_data:
|
||||
assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"]
|
||||
if "external_knowledge_api_id" in update_data:
|
||||
assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"]
|
||||
|
||||
# ==================== External Dataset Tests ====================
|
||||
|
||||
def test_update_external_dataset_success(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test successful update of external dataset."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="external", name="old_name", description="old_description", retrieval_model="old_model"
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
binding = DatasetUpdateTestDataFactory.create_external_binding_mock()
|
||||
|
||||
# Mock external knowledge binding query
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"external_retrieval_model": "new_model",
|
||||
"permission": "only_me",
|
||||
"external_knowledge_id": "new_knowledge_id",
|
||||
"external_knowledge_api_id": "new_api_id",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify dataset and binding updates
|
||||
self._assert_external_dataset_update(dataset, binding, update_data)
|
||||
|
||||
# Verify database operations
|
||||
mock_db = mock_dataset_service_dependencies["db_session"]
|
||||
mock_db.add.assert_any_call(dataset)
|
||||
mock_db.add.assert_any_call(binding)
|
||||
mock_db.commit.assert_called_once()
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when external knowledge api id is missing."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge api id is required" in str(context.value)
|
||||
|
||||
def test_update_external_dataset_binding_not_found_error(
|
||||
self, mock_dataset_service_dependencies, mock_external_provider_dependencies
|
||||
):
|
||||
"""Test error when external knowledge binding is not found."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock external knowledge binding query returning None
|
||||
mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"external_knowledge_id": "knowledge_id",
|
||||
"external_knowledge_api_id": "api_id",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "External knowledge binding not found" in str(context.value)
|
||||
|
||||
# ==================== Internal Dataset Basic Tests ====================
|
||||
|
||||
def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies):
|
||||
"""Test successful update of internal dataset with basic fields."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify permission check was called
|
||||
mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user)
|
||||
|
||||
# Verify database update was called with correct filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": "new_description",
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies):
|
||||
"""Test that None values are filtered out except for description field."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Should be included
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"embedding_model_provider": None, # Should be filtered out
|
||||
"embedding_model": None, # Should be filtered out
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with filtered data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"description": None, # Description should be included even if None
|
||||
"indexing_technique": "high_quality",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
actual_call_args = mock_dataset_service_dependencies[
|
||||
"db_session"
|
||||
].query.return_value.filter_by.return_value.update.call_args[0][0]
|
||||
# Remove timestamp for comparison as it's dynamic
|
||||
del actual_call_args["updated_at"]
|
||||
del expected_filtered_data["updated_at"]
|
||||
|
||||
assert actual_call_args == expected_filtered_data
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Indexing Technique Switch Tests ====================
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_economy(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to economy."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with embedding model fields cleared
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "economy",
|
||||
"embedding_model": None,
|
||||
"embedding_model_provider": None,
|
||||
"collection_binding_id": None,
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_indexing_technique_to_high_quality(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset indexing technique to high_quality."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock()
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock()
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-ada-002",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-456",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add")
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Embedding Model Update Tests ====================
|
||||
|
||||
def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with existing embedding model preserved
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_embedding_model_update(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test updating internal dataset with new embedding model."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock embedding model
|
||||
embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small")
|
||||
mock_internal_provider_dependencies[
|
||||
"model_manager"
|
||||
].return_value.get_model_instance.return_value = embedding_model
|
||||
|
||||
# Mock collection binding
|
||||
binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789")
|
||||
mock_internal_provider_dependencies["get_binding"].return_value = binding
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify embedding model was validated
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with(
|
||||
tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id,
|
||||
provider="openai",
|
||||
model_type=ModelType.TEXT_EMBEDDING,
|
||||
model="text-embedding-3-small",
|
||||
)
|
||||
|
||||
# Verify collection binding was retrieved
|
||||
mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small")
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model": "text-embedding-3-small",
|
||||
"embedding_model_provider": "openai",
|
||||
"collection_binding_id": "binding-789",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify vector index task was triggered
|
||||
mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update")
|
||||
|
||||
# Verify regenerate summary index task was triggered (when embedding_model changes)
|
||||
mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with(
|
||||
"dataset-123",
|
||||
regenerate_reason="embedding_model_changed",
|
||||
regenerate_vectors_only=True,
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies):
|
||||
"""Test updating internal dataset without changing indexing technique."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(
|
||||
provider="vendor",
|
||||
indexing_technique="high_quality",
|
||||
embedding_model_provider="openai",
|
||||
embedding_model="text-embedding-ada-002",
|
||||
collection_binding_id="binding-123",
|
||||
)
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
update_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality", # Same as current
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
result = DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
# Verify database update was called with correct data
|
||||
expected_filtered_data = {
|
||||
"name": "new_name",
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "openai",
|
||||
"embedding_model": "text-embedding-ada-002",
|
||||
"collection_binding_id": "binding-123",
|
||||
"retrieval_model": "new_model",
|
||||
"updated_by": user.id,
|
||||
"updated_at": mock_dataset_service_dependencies["current_time"],
|
||||
}
|
||||
|
||||
self._assert_database_update_called(
|
||||
mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data
|
||||
)
|
||||
|
||||
# Verify return value
|
||||
assert result == dataset
|
||||
|
||||
# ==================== Error Handling Tests ====================
|
||||
|
||||
def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when dataset is not found."""
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = None
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
update_data = {"name": "new_name"}
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(ValueError) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "Dataset not found" in str(context.value)
|
||||
|
||||
def test_update_dataset_permission_error(self, mock_dataset_service_dependencies):
|
||||
"""Test error when user doesn't have permission."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock()
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission")
|
||||
|
||||
update_data = {"name": "new_name"}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(NoPermissionError):
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
def test_update_internal_dataset_embedding_model_error(
|
||||
self, mock_dataset_service_dependencies, mock_internal_provider_dependencies
|
||||
):
|
||||
"""Test error when embedding model is not available."""
|
||||
dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy")
|
||||
mock_dataset_service_dependencies["get_dataset"].return_value = dataset
|
||||
|
||||
user = DatasetUpdateTestDataFactory.create_user_mock()
|
||||
|
||||
# Mock model manager to raise error
|
||||
mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception(
|
||||
"No Embedding Model available"
|
||||
)
|
||||
|
||||
update_data = {
|
||||
"indexing_technique": "high_quality",
|
||||
"embedding_model_provider": "invalid_provider",
|
||||
"embedding_model": "invalid_model",
|
||||
"retrieval_model": "new_model",
|
||||
}
|
||||
|
||||
mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False
|
||||
|
||||
with pytest.raises(Exception) as context:
|
||||
DatasetService.update_dataset("dataset-123", update_data, user)
|
||||
|
||||
assert "No Embedding Model available".lower() in str(context.value).lower()
|
||||
Reference in New Issue
Block a user