Move the Dataset/Document aggregate-property tests from mocked unit tests to
testcontainers-backed integration tests. (Text before the first "diff --git"
header is ignored by git apply and patch.)

diff --git a/api/tests/test_containers_integration_tests/models/test_dataset_models.py b/api/tests/test_containers_integration_tests/models/test_dataset_models.py
new file mode 100644
index 0000000000..d2c3e1e58e
--- /dev/null
+++ b/api/tests/test_containers_integration_tests/models/test_dataset_models.py
@@ -0,0 +1,271 @@
+"""
+Integration tests for Dataset and Document model properties using testcontainers.
+
+These tests validate database-backed model properties (total_documents, word_count, etc.)
+without mocking SQLAlchemy queries, ensuring real query behavior against PostgreSQL.
+"""
+
+from collections.abc import Generator
+from uuid import uuid4
+
+import pytest
+from sqlalchemy.orm import Session
+
+from models.dataset import Dataset, Document, DocumentSegment
+
+
+class TestDatasetDocumentProperties:
+    """Integration tests for Dataset and Document model properties."""
+
+    @pytest.fixture(autouse=True)
+    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
+        """Automatically rollback session changes after each test."""
+        yield
+        db_session_with_containers.rollback()
+
+    def test_dataset_with_documents_relationship(self, db_session_with_containers: Session) -> None:
+        """Test dataset can track its documents."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        for i in range(3):
+            doc = Document(
+                tenant_id=tenant_id,
+                dataset_id=dataset.id,
+                position=i + 1,
+                data_source_type="upload_file",
+                batch="batch_001",
+                name=f"doc_{i}.pdf",
+                created_from="web",
+                created_by=created_by,
+            )
+            db_session_with_containers.add(doc)
+        db_session_with_containers.flush()
+
+        assert dataset.total_documents == 3
+
+    def test_dataset_available_documents_count(self, db_session_with_containers: Session) -> None:
+        """Test dataset can count available documents."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        doc_available = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="available.pdf",
+            created_from="web",
+            created_by=created_by,
+            indexing_status="completed",
+            enabled=True,
+            archived=False,
+        )
+        doc_pending = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=2,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="pending.pdf",
+            created_from="web",
+            created_by=created_by,
+            indexing_status="waiting",
+            enabled=True,
+            archived=False,
+        )
+        doc_disabled = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=3,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="disabled.pdf",
+            created_from="web",
+            created_by=created_by,
+            indexing_status="completed",
+            enabled=False,
+            archived=False,
+        )
+        db_session_with_containers.add_all([doc_available, doc_pending, doc_disabled])
+        db_session_with_containers.flush()
+
+        assert dataset.total_available_documents == 1
+
+    def test_dataset_word_count_aggregation(self, db_session_with_containers: Session) -> None:
+        """Test dataset can aggregate word count from documents."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        for i, wc in enumerate([2000, 3000]):
+            doc = Document(
+                tenant_id=tenant_id,
+                dataset_id=dataset.id,
+                position=i + 1,
+                data_source_type="upload_file",
+                batch="batch_001",
+                name=f"doc_{i}.pdf",
+                created_from="web",
+                created_by=created_by,
+                word_count=wc,
+            )
+            db_session_with_containers.add(doc)
+        db_session_with_containers.flush()
+
+        assert dataset.word_count == 5000
+
+    def test_dataset_available_segment_count(self, db_session_with_containers: Session) -> None:
+        """Test Dataset.available_segment_count counts completed and enabled segments."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        doc = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="doc.pdf",
+            created_from="web",
+            created_by=created_by,
+        )
+        db_session_with_containers.add(doc)
+        db_session_with_containers.flush()
+
+        for i in range(2):
+            seg = DocumentSegment(
+                tenant_id=tenant_id,
+                dataset_id=dataset.id,
+                document_id=doc.id,
+                position=i + 1,
+                content=f"segment {i}",
+                word_count=100,
+                tokens=50,
+                status="completed",
+                enabled=True,
+                created_by=created_by,
+            )
+            db_session_with_containers.add(seg)
+
+        seg_waiting = DocumentSegment(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            document_id=doc.id,
+            position=3,
+            content="waiting segment",
+            word_count=100,
+            tokens=50,
+            status="waiting",
+            enabled=True,
+            created_by=created_by,
+        )
+        db_session_with_containers.add(seg_waiting)
+        db_session_with_containers.flush()
+
+        assert dataset.available_segment_count == 2
+
+    def test_document_segment_count_property(self, db_session_with_containers: Session) -> None:
+        """Test document can count its segments."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        doc = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="doc.pdf",
+            created_from="web",
+            created_by=created_by,
+        )
+        db_session_with_containers.add(doc)
+        db_session_with_containers.flush()
+
+        for i in range(3):
+            seg = DocumentSegment(
+                tenant_id=tenant_id,
+                dataset_id=dataset.id,
+                document_id=doc.id,
+                position=i + 1,
+                content=f"segment {i}",
+                word_count=100,
+                tokens=50,
+                created_by=created_by,
+            )
+            db_session_with_containers.add(seg)
+        db_session_with_containers.flush()
+
+        assert doc.segment_count == 3
+
+    def test_document_hit_count_aggregation(self, db_session_with_containers: Session) -> None:
+        """Test document can aggregate hit count from segments."""
+        tenant_id = str(uuid4())
+        created_by = str(uuid4())
+
+        dataset = Dataset(
+            tenant_id=tenant_id, name="Test Dataset", data_source_type="upload_file", created_by=created_by
+        )
+        db_session_with_containers.add(dataset)
+        db_session_with_containers.flush()
+
+        doc = Document(
+            tenant_id=tenant_id,
+            dataset_id=dataset.id,
+            position=1,
+            data_source_type="upload_file",
+            batch="batch_001",
+            name="doc.pdf",
+            created_from="web",
+            created_by=created_by,
+        )
+        db_session_with_containers.add(doc)
+        db_session_with_containers.flush()
+
+        for i, hits in enumerate([10, 15]):
+            seg = DocumentSegment(
+                tenant_id=tenant_id,
+                dataset_id=dataset.id,
+                document_id=doc.id,
+                position=i + 1,
+                content=f"segment {i}",
+                word_count=100,
+                tokens=50,
+                hit_count=hits,
+                created_by=created_by,
+            )
+            db_session_with_containers.add(seg)
+        db_session_with_containers.flush()
+
+        assert doc.hit_count == 25
diff --git a/api/tests/unit_tests/models/test_dataset_models.py b/api/tests/unit_tests/models/test_dataset_models.py
index 2322c556e2..c0e912fa1e 100644
--- a/api/tests/unit_tests/models/test_dataset_models.py
+++ b/api/tests/unit_tests/models/test_dataset_models.py
@@ -12,7 +12,7 @@ This test suite covers:
 import json
 import pickle
 from datetime import UTC, datetime
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 from uuid import uuid4
 
 from models.dataset import (
@@ -954,156 +954,6 @@ class TestChildChunk:
         assert child_chunk.index_node_hash == index_node_hash
 
 
-class TestDatasetDocumentCascadeDeletes:
-    """Test suite for Dataset-Document cascade delete operations."""
-
-    def test_dataset_with_documents_relationship(self):
-        """Test dataset can track its documents."""
-        # Arrange
-        dataset_id = str(uuid4())
-        dataset = Dataset(
-            tenant_id=str(uuid4()),
-            name="Test Dataset",
-            data_source_type="upload_file",
-            created_by=str(uuid4()),
-        )
-        dataset.id = dataset_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.where.return_value.scalar.return_value = 3
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            total_docs = dataset.total_documents
-
-            # Assert
-            assert total_docs == 3
-
-    def test_dataset_available_documents_count(self):
-        """Test dataset can count available documents."""
-        # Arrange
-        dataset_id = str(uuid4())
-        dataset = Dataset(
-            tenant_id=str(uuid4()),
-            name="Test Dataset",
-            data_source_type="upload_file",
-            created_by=str(uuid4()),
-        )
-        dataset.id = dataset_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.where.return_value.scalar.return_value = 2
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            available_docs = dataset.total_available_documents
-
-            # Assert
-            assert available_docs == 2
-
-    def test_dataset_word_count_aggregation(self):
-        """Test dataset can aggregate word count from documents."""
-        # Arrange
-        dataset_id = str(uuid4())
-        dataset = Dataset(
-            tenant_id=str(uuid4()),
-            name="Test Dataset",
-            data_source_type="upload_file",
-            created_by=str(uuid4()),
-        )
-        dataset.id = dataset_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.with_entities.return_value.where.return_value.scalar.return_value = 5000
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            total_words = dataset.word_count
-
-            # Assert
-            assert total_words == 5000
-
-    def test_dataset_available_segment_count(self):
-        """Test dataset can count available segments."""
-        # Arrange
-        dataset_id = str(uuid4())
-        dataset = Dataset(
-            tenant_id=str(uuid4()),
-            name="Test Dataset",
-            data_source_type="upload_file",
-            created_by=str(uuid4()),
-        )
-        dataset.id = dataset_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.where.return_value.scalar.return_value = 15
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            segment_count = dataset.available_segment_count
-
-            # Assert
-            assert segment_count == 15
-
-    def test_document_segment_count_property(self):
-        """Test document can count its segments."""
-        # Arrange
-        document_id = str(uuid4())
-        document = Document(
-            tenant_id=str(uuid4()),
-            dataset_id=str(uuid4()),
-            position=1,
-            data_source_type="upload_file",
-            batch="batch_001",
-            name="test.pdf",
-            created_from="web",
-            created_by=str(uuid4()),
-        )
-        document.id = document_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.where.return_value.count.return_value = 10
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            segment_count = document.segment_count
-
-            # Assert
-            assert segment_count == 10
-
-    def test_document_hit_count_aggregation(self):
-        """Test document can aggregate hit count from segments."""
-        # Arrange
-        document_id = str(uuid4())
-        document = Document(
-            tenant_id=str(uuid4()),
-            dataset_id=str(uuid4()),
-            position=1,
-            data_source_type="upload_file",
-            batch="batch_001",
-            name="test.pdf",
-            created_from="web",
-            created_by=str(uuid4()),
-        )
-        document.id = document_id
-
-        # Mock the database session query
-        mock_query = MagicMock()
-        mock_query.with_entities.return_value.where.return_value.scalar.return_value = 25
-
-        with patch("models.dataset.db.session.query", return_value=mock_query):
-            # Act
-            hit_count = document.hit_count
-
-            # Assert
-            assert hit_count == 25
-
-
 class TestDocumentSegmentNavigation:
     """Test suite for DocumentSegment navigation properties."""
 