From 5cb1b53b47cd44b442aa4e018cea15472f1a7fe9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=E6=9C=A8=E4=B9=8B=E6=9C=AC=E6=BE=AA?= Date: Fri, 27 Feb 2026 06:10:15 +0800 Subject: [PATCH] test: migrate dataset service update-dataset SQL tests to testcontainers (#32533) Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com> Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com> Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com> Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> --- .../test_dataset_service_update_dataset.py | 529 ++++++++++++++ .../test_dataset_service_update_dataset.py | 661 ------------------ 2 files changed, 529 insertions(+), 661 deletions(-) create mode 100644 api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py delete mode 100644 api/tests/unit_tests/services/test_dataset_service_update_dataset.py diff --git a/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py new file mode 100644 index 0000000000..608fc76bd2 --- /dev/null +++ b/api/tests/test_containers_integration_tests/services/test_dataset_service_update_dataset.py @@ -0,0 +1,529 @@ +from unittest.mock import Mock, patch +from uuid import uuid4 + +import pytest + +from core.model_runtime.entities.model_entities import ModelType +from extensions.ext_database import db +from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole +from models.dataset import Dataset, ExternalKnowledgeBindings +from services.dataset_service import DatasetService +from services.errors.account import NoPermissionError + + +class DatasetUpdateTestDataFactory: + """Factory class for creating real test data for dataset update integration tests.""" + + @staticmethod + def create_account_with_tenant(role: TenantAccountRole = TenantAccountRole.OWNER) -> tuple[Account, Tenant]: + """Create a real account and tenant with the given role.""" + account = Account( + email=f"{uuid4()}@example.com", + name=f"user-{uuid4()}", + interface_language="en-US", + status="active", + ) + db.session.add(account) + db.session.commit() + + tenant = Tenant(name=f"tenant-{account.id}", status="normal") + db.session.add(tenant) + db.session.commit() + + join = TenantAccountJoin( + tenant_id=tenant.id, + account_id=account.id, + role=role, + current=True, + ) + db.session.add(join) + db.session.commit() + + account.current_tenant = tenant + return account, tenant + + @staticmethod + def create_dataset( + tenant_id: str, + created_by: str, + provider: str = "vendor", + name: str = "old_name", + description: str = "old_description", + indexing_technique: str = "high_quality", + retrieval_model: str = "old_model", + permission: str = "only_me", + embedding_model_provider: str | None = None, + embedding_model: str | None = None, + collection_binding_id: str | None = None, + ) -> Dataset: + """Create a real dataset.""" + dataset = Dataset( + tenant_id=tenant_id, + name=name, + description=description, + data_source_type="upload_file", + indexing_technique=indexing_technique, + created_by=created_by, + provider=provider, + retrieval_model=retrieval_model, + permission=permission, + embedding_model_provider=embedding_model_provider, + embedding_model=embedding_model, + collection_binding_id=collection_binding_id, + ) + db.session.add(dataset) + db.session.commit() + return dataset + + @staticmethod + def create_external_binding( + tenant_id: str, + dataset_id: str, + created_by: str, + external_knowledge_id: str = "old_knowledge_id", + external_knowledge_api_id: str | None = None, + ) -> ExternalKnowledgeBindings: + """Create a real external knowledge binding.""" + if external_knowledge_api_id is None: + external_knowledge_api_id = str(uuid4()) + binding = ExternalKnowledgeBindings( + tenant_id=tenant_id, + dataset_id=dataset_id, + created_by=created_by, + external_knowledge_id=external_knowledge_id, + external_knowledge_api_id=external_knowledge_api_id, + ) + db.session.add(binding) + db.session.commit() + return binding + + +class TestDatasetServiceUpdateDataset: + """ + Comprehensive integration tests for DatasetService.update_dataset method. + + This test suite covers all supported scenarios including: + - External dataset updates + - Internal dataset updates with different indexing techniques + - Embedding model updates + - Permission checks + - Error conditions and edge cases + """ + + # ==================== External Dataset Tests ==================== + + def test_update_external_dataset_success(self, db_session_with_containers): + """Test successful update of external dataset.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="external", + name="old_name", + description="old_description", + retrieval_model="old_model", + ) + binding = DatasetUpdateTestDataFactory.create_external_binding( + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=user.id, + ) + binding_id = binding.id + db.session.expunge(binding) + + update_data = { + "name": "new_name", + "description": "new_description", + "external_retrieval_model": "new_model", + "permission": "only_me", + "external_knowledge_id": "new_knowledge_id", + "external_knowledge_api_id": str(uuid4()), + } + + result = DatasetService.update_dataset(dataset.id, update_data, user) + + db.session.refresh(dataset) + updated_binding = db.session.query(ExternalKnowledgeBindings).filter_by(id=binding_id).first() + + assert dataset.name == "new_name" + assert dataset.description == "new_description" + assert dataset.retrieval_model == "new_model" + assert updated_binding is not None + assert updated_binding.external_knowledge_id == "new_knowledge_id" + assert updated_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] + assert result.id == dataset.id + + def test_update_external_dataset_missing_knowledge_id_error(self, db_session_with_containers): + """Test error when external knowledge id is missing.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="external", + ) + DatasetUpdateTestDataFactory.create_external_binding( + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=user.id, + ) + + update_data = {"name": "new_name", "external_knowledge_api_id": str(uuid4())} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset(dataset.id, update_data, user) + + assert "External knowledge id is required" in str(context.value) + db.session.rollback() + + def test_update_external_dataset_missing_api_id_error(self, db_session_with_containers): + """Test error when external knowledge api id is missing.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="external", + ) + DatasetUpdateTestDataFactory.create_external_binding( + tenant_id=tenant.id, + dataset_id=dataset.id, + created_by=user.id, + ) + + update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset(dataset.id, update_data, user) + + assert "External knowledge api id is required" in str(context.value) + db.session.rollback() + + def test_update_external_dataset_binding_not_found_error(self, db_session_with_containers): + """Test error when external knowledge binding is not found.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="external", + ) + + update_data = { + "name": "new_name", + "external_knowledge_id": "knowledge_id", + "external_knowledge_api_id": str(uuid4()), + } + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset(dataset.id, update_data, user) + + assert "External knowledge binding not found" in str(context.value) + db.session.rollback() + + # ==================== Internal Dataset Basic Tests ==================== + + def test_update_internal_dataset_basic_success(self, db_session_with_containers): + """Test successful update of internal dataset with basic fields.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + existing_binding_id = str(uuid4()) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=existing_binding_id, + ) + + update_data = { + "name": "new_name", + "description": "new_description", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + } + + result = DatasetService.update_dataset(dataset.id, update_data, user) + db.session.refresh(dataset) + + assert dataset.name == "new_name" + assert dataset.description == "new_description" + assert dataset.indexing_technique == "high_quality" + assert dataset.retrieval_model == "new_model" + assert dataset.embedding_model_provider == "openai" + assert dataset.embedding_model == "text-embedding-ada-002" + assert result.id == dataset.id + + def test_update_internal_dataset_filter_none_values(self, db_session_with_containers): + """Test that None values are filtered out except for description field.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + existing_binding_id = str(uuid4()) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=existing_binding_id, + ) + + update_data = { + "name": "new_name", + "description": None, + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + "embedding_model_provider": None, + "embedding_model": None, + } + + result = DatasetService.update_dataset(dataset.id, update_data, user) + db.session.refresh(dataset) + + assert dataset.name == "new_name" + assert dataset.description is None + assert dataset.embedding_model_provider == "openai" + assert dataset.embedding_model == "text-embedding-ada-002" + assert dataset.retrieval_model == "new_model" + assert result.id == dataset.id + + # ==================== Indexing Technique Switch Tests ==================== + + def test_update_internal_dataset_indexing_technique_to_economy(self, db_session_with_containers): + """Test updating internal dataset indexing technique to economy.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + existing_binding_id = str(uuid4()) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=existing_binding_id, + ) + + update_data = { + "indexing_technique": "economy", + "retrieval_model": "new_model", + } + + with patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task: + result = DatasetService.update_dataset(dataset.id, update_data, user) + mock_task.delay.assert_called_once_with(dataset.id, "remove") + + db.session.refresh(dataset) + assert dataset.indexing_technique == "economy" + assert dataset.embedding_model is None + assert dataset.embedding_model_provider is None + assert dataset.collection_binding_id is None + assert dataset.retrieval_model == "new_model" + assert result.id == dataset.id + + def test_update_internal_dataset_indexing_technique_to_high_quality(self, db_session_with_containers): + """Test updating internal dataset indexing technique to high_quality.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="economy", + ) + + embedding_model = Mock() + embedding_model.model = "text-embedding-ada-002" + embedding_model.provider = "openai" + + binding = Mock() + binding.id = str(uuid4()) + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-ada-002", + "retrieval_model": "new_model", + } + + with ( + patch("services.dataset_service.current_user", user), + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_get_binding, + patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + ): + mock_model_manager.return_value.get_model_instance.return_value = embedding_model + mock_get_binding.return_value = binding + + result = DatasetService.update_dataset(dataset.id, update_data, user) + + mock_model_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id=tenant.id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-ada-002", + ) + mock_get_binding.assert_called_once_with("openai", "text-embedding-ada-002") + mock_task.delay.assert_called_once_with(dataset.id, "add") + + db.session.refresh(dataset) + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model == "text-embedding-ada-002" + assert dataset.embedding_model_provider == "openai" + assert dataset.collection_binding_id == binding.id + assert dataset.retrieval_model == "new_model" + assert result.id == dataset.id + + # ==================== Embedding Model Update Tests ==================== + + def test_update_internal_dataset_keep_existing_embedding_model_when_indexing_technique_unchanged( + self, db_session_with_containers + ): + """Test preserving embedding settings when indexing technique remains unchanged.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + existing_binding_id = str(uuid4()) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=existing_binding_id, + ) + + update_data = { + "name": "new_name", + "indexing_technique": "high_quality", + "retrieval_model": "new_model", + } + + result = DatasetService.update_dataset(dataset.id, update_data, user) + db.session.refresh(dataset) + + assert dataset.name == "new_name" + assert dataset.indexing_technique == "high_quality" + assert dataset.embedding_model_provider == "openai" + assert dataset.embedding_model == "text-embedding-ada-002" + assert dataset.collection_binding_id == existing_binding_id + assert dataset.retrieval_model == "new_model" + assert result.id == dataset.id + + def test_update_internal_dataset_embedding_model_update(self, db_session_with_containers): + """Test updating internal dataset with new embedding model.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + existing_binding_id = str(uuid4()) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="high_quality", + embedding_model_provider="openai", + embedding_model="text-embedding-ada-002", + collection_binding_id=existing_binding_id, + ) + + embedding_model = Mock() + embedding_model.model = "text-embedding-3-small" + embedding_model.provider = "openai" + + binding = Mock() + binding.id = str(uuid4()) + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "openai", + "embedding_model": "text-embedding-3-small", + "retrieval_model": "new_model", + } + + with ( + patch("services.dataset_service.current_user", user), + patch("services.dataset_service.ModelManager") as mock_model_manager, + patch( + "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" + ) as mock_get_binding, + patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, + patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, + ): + mock_model_manager.return_value.get_model_instance.return_value = embedding_model + mock_get_binding.return_value = binding + + result = DatasetService.update_dataset(dataset.id, update_data, user) + + mock_model_manager.return_value.get_model_instance.assert_called_once_with( + tenant_id=tenant.id, + provider="openai", + model_type=ModelType.TEXT_EMBEDDING, + model="text-embedding-3-small", + ) + mock_get_binding.assert_called_once_with("openai", "text-embedding-3-small") + mock_task.delay.assert_called_once_with(dataset.id, "update") + mock_regenerate_task.delay.assert_called_once_with( + dataset.id, + regenerate_reason="embedding_model_changed", + regenerate_vectors_only=True, + ) + + db.session.refresh(dataset) + assert dataset.embedding_model == "text-embedding-3-small" + assert dataset.embedding_model_provider == "openai" + assert dataset.collection_binding_id == binding.id + assert dataset.retrieval_model == "new_model" + assert result.id == dataset.id + + # ==================== Error Handling Tests ==================== + + def test_update_dataset_not_found_error(self, db_session_with_containers): + """Test error when dataset is not found.""" + user, _ = DatasetUpdateTestDataFactory.create_account_with_tenant() + update_data = {"name": "new_name"} + + with pytest.raises(ValueError) as context: + DatasetService.update_dataset(str(uuid4()), update_data, user) + + assert "Dataset not found" in str(context.value) + + def test_update_dataset_permission_error(self, db_session_with_containers): + """Test error when user doesn't have permission.""" + owner, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.OWNER) + outsider, _ = DatasetUpdateTestDataFactory.create_account_with_tenant(role=TenantAccountRole.NORMAL) + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=owner.id, + provider="vendor", + permission="only_me", + ) + + update_data = {"name": "new_name"} + + with pytest.raises(NoPermissionError): + DatasetService.update_dataset(dataset.id, update_data, outsider) + + def test_update_internal_dataset_embedding_model_error(self, db_session_with_containers): + """Test error when embedding model is not available.""" + user, tenant = DatasetUpdateTestDataFactory.create_account_with_tenant() + dataset = DatasetUpdateTestDataFactory.create_dataset( + tenant_id=tenant.id, + created_by=user.id, + provider="vendor", + indexing_technique="economy", + ) + + update_data = { + "indexing_technique": "high_quality", + "embedding_model_provider": "invalid_provider", + "embedding_model": "invalid_model", + "retrieval_model": "new_model", + } + + with ( + patch("services.dataset_service.current_user", user), + patch("services.dataset_service.ModelManager") as mock_model_manager, + ): + mock_model_manager.return_value.get_model_instance.side_effect = Exception("No Embedding Model available") + + with pytest.raises(Exception) as context: + DatasetService.update_dataset(dataset.id, update_data, user) + + assert "No Embedding Model available".lower() in str(context.value).lower() diff --git a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py b/api/tests/unit_tests/services/test_dataset_service_update_dataset.py deleted file mode 100644 index 08818945e3..0000000000 --- a/api/tests/unit_tests/services/test_dataset_service_update_dataset.py +++ /dev/null @@ -1,661 +0,0 @@ -import datetime -from typing import Any - -# Mock redis_client before importing dataset_service -from unittest.mock import Mock, create_autospec, patch - -import pytest - -from core.model_runtime.entities.model_entities import ModelType -from models.account import Account -from models.dataset import Dataset, ExternalKnowledgeBindings -from services.dataset_service import DatasetService -from services.errors.account import NoPermissionError - - -class DatasetUpdateTestDataFactory: - """Factory class for creating test data and mock objects for dataset update tests.""" - - @staticmethod - def create_dataset_mock( - dataset_id: str = "dataset-123", - provider: str = "vendor", - name: str = "old_name", - description: str = "old_description", - indexing_technique: str = "high_quality", - retrieval_model: str = "old_model", - embedding_model_provider: str | None = None, - embedding_model: str | None = None, - collection_binding_id: str | None = None, - **kwargs, - ) -> Mock: - """Create a mock dataset with specified attributes.""" - dataset = Mock(spec=Dataset) - dataset.id = dataset_id - dataset.provider = provider - dataset.name = name - dataset.description = description - dataset.indexing_technique = indexing_technique - dataset.retrieval_model = retrieval_model - dataset.embedding_model_provider = embedding_model_provider - dataset.embedding_model = embedding_model - dataset.collection_binding_id = collection_binding_id - for key, value in kwargs.items(): - setattr(dataset, key, value) - return dataset - - @staticmethod - def create_user_mock(user_id: str = "user-789") -> Mock: - """Create a mock user.""" - user = Mock() - user.id = user_id - return user - - @staticmethod - def create_external_binding_mock( - external_knowledge_id: str = "old_knowledge_id", external_knowledge_api_id: str = "old_api_id" - ) -> Mock: - """Create a mock external knowledge binding.""" - binding = Mock(spec=ExternalKnowledgeBindings) - binding.external_knowledge_id = external_knowledge_id - binding.external_knowledge_api_id = external_knowledge_api_id - return binding - - @staticmethod - def create_embedding_model_mock(model: str = "text-embedding-ada-002", provider: str = "openai") -> Mock: - """Create a mock embedding model.""" - embedding_model = Mock() - embedding_model.model = model - embedding_model.provider = provider - return embedding_model - - @staticmethod - def create_collection_binding_mock(binding_id: str = "binding-456") -> Mock: - """Create a mock collection binding.""" - binding = Mock() - binding.id = binding_id - return binding - - @staticmethod - def create_current_user_mock(tenant_id: str = "tenant-123") -> Mock: - """Create a mock current user.""" - current_user = create_autospec(Account, instance=True) - current_user.current_tenant_id = tenant_id - return current_user - - -class TestDatasetServiceUpdateDataset: - """ - Comprehensive unit tests for DatasetService.update_dataset method. - - This test suite covers all supported scenarios including: - - External dataset updates - - Internal dataset updates with different indexing techniques - - Embedding model updates - - Permission checks - - Error conditions and edge cases - """ - - @pytest.fixture - def mock_dataset_service_dependencies(self): - """Common mock setup for dataset service dependencies.""" - with ( - patch("services.dataset_service.DatasetService.get_dataset") as mock_get_dataset, - patch("services.dataset_service.DatasetService.check_dataset_permission") as mock_check_perm, - patch("extensions.ext_database.db.session") as mock_db, - patch("services.dataset_service.naive_utc_now") as mock_naive_utc_now, - patch("services.dataset_service.DatasetService._has_dataset_same_name") as has_dataset_same_name, - ): - current_time = datetime.datetime(2023, 1, 1, 12, 0, 0) - mock_naive_utc_now.return_value = current_time - - yield { - "get_dataset": mock_get_dataset, - "check_permission": mock_check_perm, - "db_session": mock_db, - "naive_utc_now": mock_naive_utc_now, - "current_time": current_time, - "has_dataset_same_name": has_dataset_same_name, - } - - @pytest.fixture - def mock_external_provider_dependencies(self): - """Mock setup for external provider tests.""" - with patch("services.dataset_service.Session") as mock_session: - from extensions.ext_database import db - - with patch.object(db.__class__, "engine", new_callable=Mock): - session_mock = Mock() - mock_session.return_value.__enter__.return_value = session_mock - yield session_mock - - @pytest.fixture - def mock_internal_provider_dependencies(self): - """Mock setup for internal provider tests.""" - with ( - patch("services.dataset_service.ModelManager") as mock_model_manager, - patch( - "services.dataset_service.DatasetCollectionBindingService.get_dataset_collection_binding" - ) as mock_get_binding, - patch("services.dataset_service.deal_dataset_vector_index_task") as mock_task, - patch("services.dataset_service.regenerate_summary_index_task") as mock_regenerate_task, - patch( - "services.dataset_service.current_user", create_autospec(Account, instance=True) - ) as mock_current_user, - ): - mock_current_user.current_tenant_id = "tenant-123" - yield { - "model_manager": mock_model_manager, - "get_binding": mock_get_binding, - "task": mock_task, - "regenerate_task": mock_regenerate_task, - "current_user": mock_current_user, - } - - def _assert_database_update_called(self, mock_db, dataset_id: str, expected_updates: dict[str, Any]): - """Helper method to verify database update calls.""" - mock_db.query.return_value.filter_by.return_value.update.assert_called_once_with(expected_updates) - mock_db.commit.assert_called_once() - - def _assert_external_dataset_update(self, mock_dataset, mock_binding, update_data: dict[str, Any]): - """Helper method to verify external dataset updates.""" - assert mock_dataset.name == update_data.get("name", mock_dataset.name) - assert mock_dataset.description == update_data.get("description", mock_dataset.description) - assert mock_dataset.retrieval_model == update_data.get("external_retrieval_model", mock_dataset.retrieval_model) - - if "external_knowledge_id" in update_data: - assert mock_binding.external_knowledge_id == update_data["external_knowledge_id"] - if "external_knowledge_api_id" in update_data: - assert mock_binding.external_knowledge_api_id == update_data["external_knowledge_api_id"] - - # ==================== External Dataset Tests ==================== - - def test_update_external_dataset_success( - self, mock_dataset_service_dependencies, mock_external_provider_dependencies - ): - """Test successful update of external dataset.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock( - provider="external", name="old_name", description="old_description", retrieval_model="old_model" - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - binding = DatasetUpdateTestDataFactory.create_external_binding_mock() - - # Mock external knowledge binding query - mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = binding - - update_data = { - "name": "new_name", - "description": "new_description", - "external_retrieval_model": "new_model", - "permission": "only_me", - "external_knowledge_id": "new_knowledge_id", - "external_knowledge_api_id": "new_api_id", - } - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - result = DatasetService.update_dataset("dataset-123", update_data, user) - - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - - # Verify dataset and binding updates - self._assert_external_dataset_update(dataset, binding, update_data) - - # Verify database operations - mock_db = mock_dataset_service_dependencies["db_session"] - mock_db.add.assert_any_call(dataset) - mock_db.add.assert_any_call(binding) - mock_db.commit.assert_called_once() - - # Verify return value - assert result == dataset - - def test_update_external_dataset_missing_knowledge_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge id is missing.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - update_data = {"name": "new_name", "external_knowledge_api_id": "api_id"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "External knowledge id is required" in str(context.value) - - def test_update_external_dataset_missing_api_id_error(self, mock_dataset_service_dependencies): - """Test error when external knowledge api id is missing.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - update_data = {"name": "new_name", "external_knowledge_id": "knowledge_id"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "External knowledge api id is required" in str(context.value) - - def test_update_external_dataset_binding_not_found_error( - self, mock_dataset_service_dependencies, mock_external_provider_dependencies - ): - """Test error when external knowledge binding is not found.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="external") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - # Mock external knowledge binding query returning None - mock_external_provider_dependencies.query.return_value.filter_by.return_value.first.return_value = None - - update_data = { - "name": "new_name", - "external_knowledge_id": "knowledge_id", - "external_knowledge_api_id": "api_id", - } - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "External knowledge binding not found" in str(context.value) - - # ==================== Internal Dataset Basic Tests ==================== - - def test_update_internal_dataset_basic_success(self, mock_dataset_service_dependencies): - """Test successful update of internal dataset with basic fields.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - collection_binding_id="binding-123", - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - update_data = { - "name": "new_name", - "description": "new_description", - "indexing_technique": "high_quality", - "retrieval_model": "new_model", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - } - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify permission check was called - mock_dataset_service_dependencies["check_permission"].assert_called_once_with(dataset, user) - - # Verify database update was called with correct filtered data - expected_filtered_data = { - "name": "new_name", - "description": "new_description", - "indexing_technique": "high_quality", - "retrieval_model": "new_model", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify return value - assert result == dataset - - def test_update_internal_dataset_filter_none_values(self, mock_dataset_service_dependencies): - """Test that None values are filtered out except for description field.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - update_data = { - "name": "new_name", - "description": None, # Should be included - "indexing_technique": "high_quality", - "retrieval_model": "new_model", - "embedding_model_provider": None, # Should be filtered out - "embedding_model": None, # Should be filtered out - } - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify database update was called with filtered data - expected_filtered_data = { - "name": "new_name", - "description": None, # Description should be included even if None - "indexing_technique": "high_quality", - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - actual_call_args = mock_dataset_service_dependencies[ - "db_session" - ].query.return_value.filter_by.return_value.update.call_args[0][0] - # Remove timestamp for comparison as it's dynamic - del actual_call_args["updated_at"] - del expected_filtered_data["updated_at"] - - assert actual_call_args == expected_filtered_data - - # Verify return value - assert result == dataset - - # ==================== Indexing Technique Switch Tests ==================== - - def test_update_internal_dataset_indexing_technique_to_economy( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test updating internal dataset indexing technique to economy.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="high_quality") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - update_data = {"indexing_technique": "economy", "retrieval_model": "new_model"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify database update was called with embedding model fields cleared - expected_filtered_data = { - "indexing_technique": "economy", - "embedding_model": None, - "embedding_model_provider": None, - "collection_binding_id": None, - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify return value - assert result == dataset - - def test_update_internal_dataset_indexing_technique_to_high_quality( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test updating internal dataset indexing technique to high_quality.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - # Mock embedding model - embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock() - mock_internal_provider_dependencies[ - "model_manager" - ].return_value.get_model_instance.return_value = embedding_model - - # Mock collection binding - binding = DatasetUpdateTestDataFactory.create_collection_binding_mock() - mock_internal_provider_dependencies["get_binding"].return_value = binding - - update_data = { - "indexing_technique": "high_quality", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - "retrieval_model": "new_model", - } - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify embedding model was validated - mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with( - tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id, - provider="openai", - model_type=ModelType.TEXT_EMBEDDING, - model="text-embedding-ada-002", - ) - - # Verify collection binding was retrieved - mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-ada-002") - - # Verify database update was called with correct data - expected_filtered_data = { - "indexing_technique": "high_quality", - "embedding_model": "text-embedding-ada-002", - "embedding_model_provider": "openai", - "collection_binding_id": "binding-456", - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify vector index task was triggered - mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "add") - - # Verify return value - assert result == dataset - - # ==================== Embedding Model Update Tests ==================== - - def test_update_internal_dataset_keep_existing_embedding_model(self, mock_dataset_service_dependencies): - """Test updating internal dataset without changing embedding model.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - collection_binding_id="binding-123", - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - update_data = {"name": "new_name", "indexing_technique": "high_quality", "retrieval_model": "new_model"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify database update was called with existing embedding model preserved - expected_filtered_data = { - "name": "new_name", - "indexing_technique": "high_quality", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - "collection_binding_id": "binding-123", - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify return value - assert result == dataset - - def test_update_internal_dataset_embedding_model_update( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test updating internal dataset with new embedding model.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - # Mock embedding model - embedding_model = DatasetUpdateTestDataFactory.create_embedding_model_mock("text-embedding-3-small") - mock_internal_provider_dependencies[ - "model_manager" - ].return_value.get_model_instance.return_value = embedding_model - - # Mock collection binding - binding = DatasetUpdateTestDataFactory.create_collection_binding_mock("binding-789") - mock_internal_provider_dependencies["get_binding"].return_value = binding - - update_data = { - "indexing_technique": "high_quality", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-3-small", - "retrieval_model": "new_model", - } - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify embedding model was validated - mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.assert_called_once_with( - tenant_id=mock_internal_provider_dependencies["current_user"].current_tenant_id, - provider="openai", - model_type=ModelType.TEXT_EMBEDDING, - model="text-embedding-3-small", - ) - - # Verify collection binding was retrieved - mock_internal_provider_dependencies["get_binding"].assert_called_once_with("openai", "text-embedding-3-small") - - # Verify database update was called with correct data - expected_filtered_data = { - "indexing_technique": "high_quality", - "embedding_model": "text-embedding-3-small", - "embedding_model_provider": "openai", - "collection_binding_id": "binding-789", - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify vector index task was triggered - mock_internal_provider_dependencies["task"].delay.assert_called_once_with("dataset-123", "update") - - # Verify regenerate summary index task was triggered (when embedding_model changes) - mock_internal_provider_dependencies["regenerate_task"].delay.assert_called_once_with( - "dataset-123", - regenerate_reason="embedding_model_changed", - regenerate_vectors_only=True, - ) - - # Verify return value - assert result == dataset - - def test_update_internal_dataset_no_indexing_technique_change(self, mock_dataset_service_dependencies): - """Test updating internal dataset without changing indexing technique.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock( - provider="vendor", - indexing_technique="high_quality", - embedding_model_provider="openai", - embedding_model="text-embedding-ada-002", - collection_binding_id="binding-123", - ) - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - update_data = { - "name": "new_name", - "indexing_technique": "high_quality", # Same as current - "retrieval_model": "new_model", - } - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - result = DatasetService.update_dataset("dataset-123", update_data, user) - - # Verify database update was called with correct data - expected_filtered_data = { - "name": "new_name", - "indexing_technique": "high_quality", - "embedding_model_provider": "openai", - "embedding_model": "text-embedding-ada-002", - "collection_binding_id": "binding-123", - "retrieval_model": "new_model", - "updated_by": user.id, - "updated_at": mock_dataset_service_dependencies["current_time"], - } - - self._assert_database_update_called( - mock_dataset_service_dependencies["db_session"], "dataset-123", expected_filtered_data - ) - - # Verify return value - assert result == dataset - - # ==================== Error Handling Tests ==================== - - def test_update_dataset_not_found_error(self, mock_dataset_service_dependencies): - """Test error when dataset is not found.""" - mock_dataset_service_dependencies["get_dataset"].return_value = None - - user = DatasetUpdateTestDataFactory.create_user_mock() - update_data = {"name": "new_name"} - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(ValueError) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "Dataset not found" in str(context.value) - - def test_update_dataset_permission_error(self, mock_dataset_service_dependencies): - """Test error when user doesn't have permission.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock() - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - mock_dataset_service_dependencies["check_permission"].side_effect = NoPermissionError("No permission") - - update_data = {"name": "new_name"} - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(NoPermissionError): - DatasetService.update_dataset("dataset-123", update_data, user) - - def test_update_internal_dataset_embedding_model_error( - self, mock_dataset_service_dependencies, mock_internal_provider_dependencies - ): - """Test error when embedding model is not available.""" - dataset = DatasetUpdateTestDataFactory.create_dataset_mock(provider="vendor", indexing_technique="economy") - mock_dataset_service_dependencies["get_dataset"].return_value = dataset - - user = DatasetUpdateTestDataFactory.create_user_mock() - - # Mock model manager to raise error - mock_internal_provider_dependencies["model_manager"].return_value.get_model_instance.side_effect = Exception( - "No Embedding Model available" - ) - - update_data = { - "indexing_technique": "high_quality", - "embedding_model_provider": "invalid_provider", - "embedding_model": "invalid_model", - "retrieval_model": "new_model", - } - - mock_dataset_service_dependencies["has_dataset_same_name"].return_value = False - - with pytest.raises(Exception) as context: - DatasetService.update_dataset("dataset-123", update_data, user) - - assert "No Embedding Model available".lower() in str(context.value).lower()