test: migrate Dataset/Document property tests to testcontainers (#32487)

Co-authored-by: KinomotoMio <200703522+KinomotoMio@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
木之本澪
2026-02-24 00:23:48 +08:00
committed by GitHub
parent f76ee7cfa4
commit 737575d637
2 changed files with 272 additions and 151 deletions

View File

@@ -0,0 +1,271 @@
"""
Integration tests for Dataset and Document model properties using testcontainers.
These tests validate database-backed model properties (total_documents, word_count, etc.)
without mocking SQLAlchemy queries, ensuring real query behavior against PostgreSQL.
"""
from collections.abc import Generator
from uuid import uuid4
import pytest
from sqlalchemy.orm import Session
from models.dataset import Dataset, Document, DocumentSegment
class TestDatasetDocumentProperties:
    """Integration tests for Dataset and Document model properties.

    Validates database-backed model properties (total_documents, word_count,
    segment_count, ...) without mocking SQLAlchemy queries, ensuring real
    query behavior against PostgreSQL provided by testcontainers.
    """

    @pytest.fixture(autouse=True)
    def _auto_rollback(self, db_session_with_containers: Session) -> Generator[None, None, None]:
        """Automatically roll back session changes after each test."""
        yield
        db_session_with_containers.rollback()

    # ----------------------------------------------------------------- helpers

    @staticmethod
    def _make_dataset(session: Session, tenant_id: str, created_by: str) -> Dataset:
        """Persist (add + flush) and return a minimal Dataset for *tenant_id*."""
        dataset = Dataset(
            tenant_id=tenant_id,
            name="Test Dataset",
            data_source_type="upload_file",
            created_by=created_by,
        )
        session.add(dataset)
        # Flush so dataset.id is assigned before documents reference it.
        session.flush()
        return dataset

    @staticmethod
    def _make_document(
        session: Session,
        *,
        tenant_id: str,
        dataset_id: str,
        position: int,
        name: str,
        created_by: str,
        **extra,
    ) -> Document:
        """Add a Document with common defaults.

        *extra* supplies optional columns (word_count, indexing_status,
        enabled, archived, ...). The document is added but not flushed.
        """
        doc = Document(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            position=position,
            data_source_type="upload_file",
            batch="batch_001",
            name=name,
            created_from="web",
            created_by=created_by,
            **extra,
        )
        session.add(doc)
        return doc

    @staticmethod
    def _make_segment(
        session: Session,
        *,
        tenant_id: str,
        dataset_id: str,
        document_id: str,
        position: int,
        content: str,
        created_by: str,
        **extra,
    ) -> DocumentSegment:
        """Add a DocumentSegment with common defaults.

        *extra* supplies optional columns (status, enabled, hit_count, ...).
        The segment is added but not flushed.
        """
        seg = DocumentSegment(
            tenant_id=tenant_id,
            dataset_id=dataset_id,
            document_id=document_id,
            position=position,
            content=content,
            word_count=100,
            tokens=50,
            created_by=created_by,
            **extra,
        )
        session.add(seg)
        return seg

    # ------------------------------------------------------------------- tests

    def test_dataset_with_documents_relationship(self, db_session_with_containers: Session) -> None:
        """Test dataset can track its documents via total_documents."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        for i in range(3):
            self._make_document(
                db_session_with_containers,
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                position=i + 1,
                name=f"doc_{i}.pdf",
                created_by=created_by,
            )
        db_session_with_containers.flush()

        assert dataset.total_documents == 3

    def test_dataset_available_documents_count(self, db_session_with_containers: Session) -> None:
        """Test dataset counts only completed, enabled, non-archived documents."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        # Only this document is "available": completed + enabled + not archived.
        self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=1,
            name="available.pdf",
            created_by=created_by,
            indexing_status="completed",
            enabled=True,
            archived=False,
        )
        # Still indexing -> excluded.
        self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=2,
            name="pending.pdf",
            created_by=created_by,
            indexing_status="waiting",
            enabled=True,
            archived=False,
        )
        # Disabled -> excluded.
        self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=3,
            name="disabled.pdf",
            created_by=created_by,
            indexing_status="completed",
            enabled=False,
            archived=False,
        )
        db_session_with_containers.flush()

        assert dataset.total_available_documents == 1

    def test_dataset_word_count_aggregation(self, db_session_with_containers: Session) -> None:
        """Test dataset aggregates word_count across its documents."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        for i, wc in enumerate([2000, 3000]):
            self._make_document(
                db_session_with_containers,
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                position=i + 1,
                name=f"doc_{i}.pdf",
                created_by=created_by,
                word_count=wc,
            )
        db_session_with_containers.flush()

        assert dataset.word_count == 5000

    def test_dataset_available_segment_count(self, db_session_with_containers: Session) -> None:
        """Test Dataset.available_segment_count counts completed and enabled segments."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        doc = self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=1,
            name="doc.pdf",
            created_by=created_by,
        )
        db_session_with_containers.flush()

        # Two completed + enabled segments count toward the total.
        for i in range(2):
            self._make_segment(
                db_session_with_containers,
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                document_id=doc.id,
                position=i + 1,
                content=f"segment {i}",
                created_by=created_by,
                status="completed",
                enabled=True,
            )
        # A "waiting" segment must be excluded.
        self._make_segment(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            document_id=doc.id,
            position=3,
            content="waiting segment",
            created_by=created_by,
            status="waiting",
            enabled=True,
        )
        db_session_with_containers.flush()

        assert dataset.available_segment_count == 2

    def test_document_segment_count_property(self, db_session_with_containers: Session) -> None:
        """Test document can count its segments."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        doc = self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=1,
            name="doc.pdf",
            created_by=created_by,
        )
        db_session_with_containers.flush()

        for i in range(3):
            self._make_segment(
                db_session_with_containers,
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                document_id=doc.id,
                position=i + 1,
                content=f"segment {i}",
                created_by=created_by,
            )
        db_session_with_containers.flush()

        assert doc.segment_count == 3

    def test_document_hit_count_aggregation(self, db_session_with_containers: Session) -> None:
        """Test document aggregates hit_count across its segments."""
        tenant_id = str(uuid4())
        created_by = str(uuid4())
        dataset = self._make_dataset(db_session_with_containers, tenant_id, created_by)

        doc = self._make_document(
            db_session_with_containers,
            tenant_id=tenant_id,
            dataset_id=dataset.id,
            position=1,
            name="doc.pdf",
            created_by=created_by,
        )
        db_session_with_containers.flush()

        for i, hits in enumerate([10, 15]):
            self._make_segment(
                db_session_with_containers,
                tenant_id=tenant_id,
                dataset_id=dataset.id,
                document_id=doc.id,
                position=i + 1,
                content=f"segment {i}",
                created_by=created_by,
                hit_count=hits,
            )
        db_session_with_containers.flush()

        assert doc.hit_count == 25

View File

@@ -12,7 +12,7 @@ This test suite covers:
import json
import pickle
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from unittest.mock import patch
from uuid import uuid4
from models.dataset import (
@@ -954,156 +954,6 @@ class TestChildChunk:
assert child_chunk.index_node_hash == index_node_hash
class TestDatasetDocumentCascadeDeletes:
    """Test suite for Dataset-Document cascade delete operations."""

    @staticmethod
    def _build_dataset() -> Dataset:
        """Return a transient Dataset with random tenant/creator and a fixed id."""
        dataset = Dataset(
            tenant_id=str(uuid4()),
            name="Test Dataset",
            data_source_type="upload_file",
            created_by=str(uuid4()),
        )
        dataset.id = str(uuid4())
        return dataset

    @staticmethod
    def _build_document() -> Document:
        """Return a transient Document with random tenant/creator and a fixed id."""
        document = Document(
            tenant_id=str(uuid4()),
            dataset_id=str(uuid4()),
            position=1,
            data_source_type="upload_file",
            batch="batch_001",
            name="test.pdf",
            created_from="web",
            created_by=str(uuid4()),
        )
        document.id = str(uuid4())
        return document

    def test_dataset_with_documents_relationship(self):
        """Test dataset can track its documents."""
        dataset = self._build_dataset()

        # Stub the count query issued by the property.
        stub_query = MagicMock()
        stub_query.where.return_value.scalar.return_value = 3

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert dataset.total_documents == 3

    def test_dataset_available_documents_count(self):
        """Test dataset can count available documents."""
        dataset = self._build_dataset()

        stub_query = MagicMock()
        stub_query.where.return_value.scalar.return_value = 2

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert dataset.total_available_documents == 2

    def test_dataset_word_count_aggregation(self):
        """Test dataset can aggregate word count from documents."""
        dataset = self._build_dataset()

        # This property goes through with_entities() before filtering.
        stub_query = MagicMock()
        stub_query.with_entities.return_value.where.return_value.scalar.return_value = 5000

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert dataset.word_count == 5000

    def test_dataset_available_segment_count(self):
        """Test dataset can count available segments."""
        dataset = self._build_dataset()

        stub_query = MagicMock()
        stub_query.where.return_value.scalar.return_value = 15

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert dataset.available_segment_count == 15

    def test_document_segment_count_property(self):
        """Test document can count its segments."""
        document = self._build_document()

        # segment_count uses Query.count() rather than scalar().
        stub_query = MagicMock()
        stub_query.where.return_value.count.return_value = 10

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert document.segment_count == 10

    def test_document_hit_count_aggregation(self):
        """Test document can aggregate hit count from segments."""
        document = self._build_document()

        stub_query = MagicMock()
        stub_query.with_entities.return_value.where.return_value.scalar.return_value = 25

        with patch("models.dataset.db.session.query", return_value=stub_query):
            assert document.hit_count == 25
class TestDocumentSegmentNavigation:
"""Test suite for DocumentSegment navigation properties."""