Compare commits

...

2 Commits

Author SHA1 Message Date
Renzo
c2428361c4 refactor: select in dataset_service (DocumentService class) (#34528)
Some checks are pending
autofix.ci / autofix (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/amd64, ubuntu-latest, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, {{defaultContext}}:api, Dockerfile, DIFY_API_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/amd64, ubuntu-latest, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, {{defaultContext}}, web/Dockerfile, DIFY_WEB_IMAGE_NAME, linux/arm64, ubuntu-24.04-arm, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
Main CI Pipeline / Skip Duplicate Checks (push) Waiting to run
Main CI Pipeline / Check Changed Files (push) Blocked by required conditions
Main CI Pipeline / Run API Tests (push) Blocked by required conditions
Main CI Pipeline / Skip API Tests (push) Blocked by required conditions
Main CI Pipeline / API Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Tests (push) Blocked by required conditions
Main CI Pipeline / Skip Web Tests (push) Blocked by required conditions
Main CI Pipeline / Web Tests (push) Blocked by required conditions
Main CI Pipeline / Run Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Skip Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Web Full-Stack E2E (push) Blocked by required conditions
Main CI Pipeline / Style Check (push) Blocked by required conditions
Main CI Pipeline / Run VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Skip VDB Tests (push) Blocked by required conditions
Main CI Pipeline / VDB Tests (push) Blocked by required conditions
Main CI Pipeline / Run DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / Skip DB Migration Test (push) Blocked by required conditions
Main CI Pipeline / DB Migration Test (push) Blocked by required conditions
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 22:52:01 +00:00
Renzo
68e4d13f36 refactor: select in annotation_service (#34503)
Co-authored-by: Asuka Minato <i@asukaminato.eu.org>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-04-03 22:47:22 +00:00
4 changed files with 229 additions and 554 deletions

View File

@@ -6,7 +6,7 @@ import pandas as pd
logger = logging.getLogger(__name__)
from typing import TypedDict
from sqlalchemy import or_, select
from sqlalchemy import delete, or_, select, update
from werkzeug.datastructures import FileStorage
from werkzeug.exceptions import NotFound
@@ -51,10 +51,8 @@ class AppAnnotationService:
def up_insert_app_annotation_from_message(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -66,7 +64,9 @@ class AppAnnotationService:
if args.get("message_id"):
message_id = str(args["message_id"])
message = db.session.query(Message).where(Message.id == message_id, Message.app_id == app.id).first()
message = db.session.scalar(
select(Message).where(Message.id == message_id, Message.app_id == app.id).limit(1)
)
if not message:
raise NotFound("Message Not Exists.")
@@ -95,7 +95,9 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
assert current_tenant_id is not None
if annotation_setting:
add_annotation_to_index_task.delay(
@@ -151,10 +153,8 @@ class AppAnnotationService:
def get_annotation_list_by_app_id(cls, app_id: str, page: int, limit: int, keyword: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -193,20 +193,17 @@ class AppAnnotationService:
"""
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotations = (
db.session.query(MessageAnnotation)
annotations = db.session.scalars(
select(MessageAnnotation)
.where(MessageAnnotation.app_id == app_id)
.order_by(MessageAnnotation.created_at.desc())
.all()
)
).all()
# Sanitize CSV-injectable fields to prevent formula injection
for annotation in annotations:
@@ -223,10 +220,8 @@ class AppAnnotationService:
def insert_app_annotation_directly(cls, args: dict, app_id: str) -> MessageAnnotation:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -242,7 +237,9 @@ class AppAnnotationService:
db.session.add(annotation)
db.session.commit()
# if annotation reply is enabled , add annotation to index
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if annotation_setting:
add_annotation_to_index_task.delay(
annotation.id,
@@ -257,16 +254,14 @@ class AppAnnotationService:
def update_app_annotation_directly(cls, args: dict, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -280,8 +275,8 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , add annotation to index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if app_annotation_setting:
@@ -299,16 +294,14 @@ class AppAnnotationService:
def delete_app_annotation(cls, app_id: str, annotation_id: str):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -324,8 +317,8 @@ class AppAnnotationService:
db.session.commit()
# if annotation reply is enabled , delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if app_annotation_setting:
@@ -337,22 +330,19 @@ class AppAnnotationService:
def delete_app_annotations_in_batch(cls, app_id: str, annotation_ids: list[str]):
# get app info
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
# Fetch annotations and their settings in a single query
annotations_to_delete = (
db.session.query(MessageAnnotation, AppAnnotationSetting)
annotations_to_delete = db.session.execute(
select(MessageAnnotation, AppAnnotationSetting)
.outerjoin(AppAnnotationSetting, MessageAnnotation.app_id == AppAnnotationSetting.app_id)
.where(MessageAnnotation.id.in_(annotation_ids))
.all()
)
).all()
if not annotations_to_delete:
return {"deleted_count": 0}
@@ -361,9 +351,9 @@ class AppAnnotationService:
annotation_ids_to_delete = [annotation.id for annotation, _ in annotations_to_delete]
# Step 2: Bulk delete hit histories in a single query
db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete)
).delete(synchronize_session=False)
db.session.execute(
delete(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id.in_(annotation_ids_to_delete))
)
# Step 3: Trigger async tasks for search index deletion
for annotation, annotation_setting in annotations_to_delete:
@@ -373,11 +363,10 @@ class AppAnnotationService:
)
# Step 4: Bulk delete annotations in a single query
deleted_count = (
db.session.query(MessageAnnotation)
.where(MessageAnnotation.id.in_(annotation_ids_to_delete))
.delete(synchronize_session=False)
delete_result = db.session.execute(
delete(MessageAnnotation).where(MessageAnnotation.id.in_(annotation_ids_to_delete))
)
deleted_count = getattr(delete_result, "rowcount", 0)
db.session.commit()
return {"deleted_count": deleted_count}
@@ -398,10 +387,8 @@ class AppAnnotationService:
# get app info
current_user, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
@@ -522,16 +509,14 @@ class AppAnnotationService:
def get_annotation_hit_histories(cls, app_id: str, annotation_id: str, page, limit):
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
raise NotFound("Annotation not found")
@@ -551,7 +536,7 @@ class AppAnnotationService:
@classmethod
def get_annotation_by_id(cls, annotation_id: str) -> MessageAnnotation | None:
annotation = db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).first()
annotation = db.session.get(MessageAnnotation, annotation_id)
if not annotation:
return None
@@ -571,8 +556,10 @@ class AppAnnotationService:
score: float,
):
# add hit count to annotation
db.session.query(MessageAnnotation).where(MessageAnnotation.id == annotation_id).update(
{MessageAnnotation.hit_count: MessageAnnotation.hit_count + 1}, synchronize_session=False
db.session.execute(
update(MessageAnnotation)
.where(MessageAnnotation.id == annotation_id)
.values(hit_count=MessageAnnotation.hit_count + 1)
)
annotation_hit_history = AppAnnotationHitHistory(
@@ -593,16 +580,16 @@ class AppAnnotationService:
def get_app_annotation_setting_by_app_id(cls, app_id: str) -> AnnotationSettingDict | AnnotationSettingDisabledDict:
_, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation_setting = db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
if annotation_setting:
collection_binding_detail = annotation_setting.collection_binding_detail
if collection_binding_detail:
@@ -630,22 +617,20 @@ class AppAnnotationService:
) -> AnnotationSettingDict:
current_user, current_tenant_id = current_account_with_tenant()
# get app info
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
annotation_setting = (
db.session.query(AppAnnotationSetting)
annotation_setting = db.session.scalar(
select(AppAnnotationSetting)
.where(
AppAnnotationSetting.app_id == app_id,
AppAnnotationSetting.id == annotation_setting_id,
)
.first()
.limit(1)
)
if not annotation_setting:
raise NotFound("App annotation not found")
@@ -678,26 +663,26 @@ class AppAnnotationService:
@classmethod
def clear_all_annotations(cls, app_id: str):
_, current_tenant_id = current_account_with_tenant()
app = (
db.session.query(App)
.where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal")
.first()
app = db.session.scalar(
select(App).where(App.id == app_id, App.tenant_id == current_tenant_id, App.status == "normal").limit(1)
)
if not app:
raise NotFound("App not found")
# if annotation reply is enabled, delete annotation index
app_annotation_setting = (
db.session.query(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).first()
app_annotation_setting = db.session.scalar(
select(AppAnnotationSetting).where(AppAnnotationSetting.app_id == app_id).limit(1)
)
annotations_query = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
for annotation in annotations_query.yield_per(100):
annotation_hit_histories_query = db.session.query(AppAnnotationHitHistory).where(
AppAnnotationHitHistory.annotation_id == annotation.id
)
for annotation_hit_history in annotation_hit_histories_query.yield_per(100):
annotations_iter = db.session.scalars(
select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)
).yield_per(100)
for annotation in annotations_iter:
hit_histories_iter = db.session.scalars(
select(AppAnnotationHitHistory).where(AppAnnotationHitHistory.annotation_id == annotation.id)
).yield_per(100)
for annotation_hit_history in hit_histories_iter:
db.session.delete(annotation_hit_history)
# if annotation reply is enabled, delete annotation index

View File

@@ -1400,8 +1400,8 @@ class DocumentService:
@staticmethod
def get_document(dataset_id: str, document_id: str | None = None) -> Document | None:
if document_id:
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
document = db.session.scalar(
select(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).limit(1)
)
return document
else:
@@ -1630,7 +1630,7 @@ class DocumentService:
@staticmethod
def get_document_by_id(document_id: str) -> Document | None:
document = db.session.query(Document).where(Document.id == document_id).first()
document = db.session.get(Document, document_id)
return document
@@ -1695,7 +1695,7 @@ class DocumentService:
@staticmethod
def get_document_file_detail(file_id: str):
file_detail = db.session.query(UploadFile).where(UploadFile.id == file_id).one_or_none()
file_detail = db.session.get(UploadFile, file_id)
return file_detail
@staticmethod
@@ -1769,9 +1769,11 @@ class DocumentService:
document.name = name
db.session.add(document)
if document.data_source_info_dict and "upload_file_id" in document.data_source_info_dict:
db.session.query(UploadFile).where(
UploadFile.id == document.data_source_info_dict["upload_file_id"]
).update({UploadFile.name: name})
db.session.execute(
update(UploadFile)
.where(UploadFile.id == document.data_source_info_dict["upload_file_id"])
.values(name=name)
)
db.session.commit()
@@ -1858,8 +1860,8 @@ class DocumentService:
@staticmethod
def get_documents_position(dataset_id):
document = (
db.session.query(Document).filter_by(dataset_id=dataset_id).order_by(Document.position.desc()).first()
document = db.session.scalar(
select(Document).where(Document.dataset_id == dataset_id).order_by(Document.position.desc()).limit(1)
)
if document:
return document.position + 1
@@ -2016,28 +2018,28 @@ class DocumentService:
if not knowledge_config.data_source.info_list.file_info_list:
raise ValueError("File source info is required")
upload_file_list = knowledge_config.data_source.info_list.file_info_list.file_ids
files = (
db.session.query(UploadFile)
.where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
.all()
files = list(
db.session.scalars(
select(UploadFile).where(
UploadFile.tenant_id == dataset.tenant_id,
UploadFile.id.in_(upload_file_list),
)
).all()
)
if len(files) != len(set(upload_file_list)):
raise FileNotExistsError("One or more files not found.")
file_names = [file.name for file in files]
db_documents = (
db.session.query(Document)
.where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
.all()
db_documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.UPLOAD_FILE,
Document.enabled == True,
Document.name.in_(file_names),
)
).all()
)
documents_map = {document.name: document for document in db_documents}
for file in files:
@@ -2083,15 +2085,15 @@ class DocumentService:
raise ValueError("No notion info list found.")
exist_page_ids = []
exist_document = {}
documents = (
db.session.query(Document)
.filter_by(
dataset_id=dataset.id,
tenant_id=current_user.current_tenant_id,
data_source_type=DataSourceType.NOTION_IMPORT,
enabled=True,
)
.all()
documents = list(
db.session.scalars(
select(Document).where(
Document.dataset_id == dataset.id,
Document.tenant_id == current_user.current_tenant_id,
Document.data_source_type == DataSourceType.NOTION_IMPORT,
Document.enabled == True,
)
).all()
)
if documents:
for document in documents:
@@ -2522,14 +2524,15 @@ class DocumentService:
assert isinstance(current_user, Account)
documents_count = (
db.session.query(Document)
.where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
db.session.scalar(
select(func.count(Document.id)).where(
Document.completed_at.isnot(None),
Document.enabled == True,
Document.archived == False,
Document.tenant_id == current_user.current_tenant_id,
)
)
.count()
or 0
)
return documents_count
@@ -2579,10 +2582,10 @@ class DocumentService:
raise ValueError("No file info list found.")
upload_file_list = document_data.data_source.info_list.file_info_list.file_ids
for file_id in upload_file_list:
file = (
db.session.query(UploadFile)
file = db.session.scalar(
select(UploadFile)
.where(UploadFile.tenant_id == dataset.tenant_id, UploadFile.id == file_id)
.first()
.limit(1)
)
# raise error if file not found
@@ -2599,8 +2602,8 @@ class DocumentService:
notion_info_list = document_data.data_source.info_list.notion_info_list
for notion_info in notion_info_list:
workspace_id = notion_info.workspace_id
data_source_binding = (
db.session.query(DataSourceOauthBinding)
data_source_binding = db.session.scalar(
select(DataSourceOauthBinding)
.where(
sa.and_(
DataSourceOauthBinding.tenant_id == current_user.current_tenant_id,
@@ -2609,7 +2612,7 @@ class DocumentService:
DataSourceOauthBinding.source_info["workspace_id"] == f'"{workspace_id}"',
)
)
.first()
.limit(1)
)
if not data_source_binding:
raise ValueError("Data source binding not found.")
@@ -2654,8 +2657,10 @@ class DocumentService:
db.session.commit()
# update document segment
db.session.query(DocumentSegment).filter_by(document_id=document.id).update(
{DocumentSegment.status: SegmentStatus.RE_SEGMENT}
db.session.execute(
update(DocumentSegment)
.where(DocumentSegment.document_id == document.id)
.values(status=SegmentStatus.RE_SEGMENT)
)
db.session.commit()
# trigger async task

View File

@@ -79,10 +79,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -100,10 +97,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act & Assert
with pytest.raises(ValueError):
@@ -121,15 +115,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
message_query = MagicMock()
message_query.where.return_value = message_query
message_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, message_query]
mock_db.session.scalar.side_effect = [app, None]
# Act & Assert
with pytest.raises(NotFound):
@@ -152,19 +138,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.add_annotation_to_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
message_query = MagicMock()
message_query.where.return_value = message_query
message_query.first.return_value = message
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, message_query, setting_query]
mock_db.session.scalar.side_effect = [app, message, setting]
# Act
result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id)
@@ -202,19 +176,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls,
patch("services.annotation_service.add_annotation_to_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
message_query = MagicMock()
message_query.where.return_value = message_query
message_query.first.return_value = message
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, message_query, setting_query]
mock_db.session.scalar.side_effect = [app, message, None]
# Act
result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id)
@@ -245,10 +207,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.current_account_with_tenant", return_value=(current_user, tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act & Assert
with pytest.raises(ValueError):
@@ -270,15 +229,7 @@ class TestAppAnnotationServiceUpInsert:
patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls,
patch("services.annotation_service.add_annotation_to_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.up_insert_app_annotation_from_message(args, app.id)
@@ -406,10 +357,7 @@ class TestAppAnnotationServiceListAndExport:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -427,10 +375,7 @@ class TestAppAnnotationServiceListAndExport:
patch("services.annotation_service.db") as mock_db,
patch("libs.helper.escape_like_pattern", return_value="safe"),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
mock_db.paginate.return_value = pagination
# Act
@@ -451,10 +396,7 @@ class TestAppAnnotationServiceListAndExport:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
mock_db.paginate.return_value = pagination
# Act
@@ -481,16 +423,8 @@ class TestAppAnnotationServiceListAndExport:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.CSVSanitizer.sanitize_value", side_effect=lambda v: f"safe:{v}"),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.order_by.return_value = annotation_query
annotation_query.all.return_value = [annotation1, annotation2]
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.scalars.return_value.all.return_value = [annotation1, annotation2]
# Act
result = AppAnnotationService.export_annotation_list_by_app_id(app.id)
@@ -511,10 +445,7 @@ class TestAppAnnotationServiceListAndExport:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -534,10 +465,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -554,10 +482,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act & Assert
with pytest.raises(ValueError):
@@ -579,15 +504,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.MessageAnnotation", return_value=annotation_instance) as mock_cls,
patch("services.annotation_service.add_annotation_to_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.insert_app_annotation_directly(args, app.id)
@@ -621,15 +538,8 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -645,10 +555,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -666,15 +573,8 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = annotation
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = annotation
# Act & Assert
with pytest.raises(ValueError):
@@ -695,19 +595,8 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.update_annotation_to_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = annotation
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, annotation_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
mock_db.session.get.return_value = annotation
# Act
result = AppAnnotationService.update_app_annotation_directly(args, app.id, annotation.id)
@@ -740,22 +629,11 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.delete_annotation_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = annotation
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.scalar.side_effect = [app, setting]
mock_db.session.get.return_value = annotation
scalars_result = MagicMock()
scalars_result.all.return_value = [history1, history2]
mock_db.session.query.side_effect = [app_query, annotation_query, setting_query]
mock_db.session.scalars.return_value = scalars_result
# Act
@@ -782,10 +660,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -801,15 +676,8 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -825,16 +693,8 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotations_query = MagicMock()
annotations_query.outerjoin.return_value = annotations_query
annotations_query.where.return_value = annotations_query
annotations_query.all.return_value = []
mock_db.session.query.side_effect = [app_query, annotations_query]
mock_db.session.scalar.return_value = app
mock_db.session.execute.return_value.all.return_value = []
# Act
result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1"])
@@ -851,10 +711,7 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -874,24 +731,14 @@ class TestAppAnnotationServiceDirectManipulation:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.delete_annotation_index_task") as mock_task,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.scalar.return_value = app
annotations_query = MagicMock()
annotations_query.outerjoin.return_value = annotations_query
annotations_query.where.return_value = annotations_query
annotations_query.all.return_value = [(annotation1, setting), (annotation2, None)]
hit_history_query = MagicMock()
hit_history_query.where.return_value = hit_history_query
hit_history_query.delete.return_value = None
delete_query = MagicMock()
delete_query.where.return_value = delete_query
delete_query.delete.return_value = 2
mock_db.session.query.side_effect = [app_query, annotations_query, hit_history_query, delete_query]
# First execute().all() for multi-column query, subsequent execute() calls for deletes
execute_result_multi = MagicMock()
execute_result_multi.all.return_value = [(annotation1, setting), (annotation2, None)]
execute_result_delete = MagicMock()
execute_result_delete.rowcount = 2
mock_db.session.execute.side_effect = [execute_result_multi, MagicMock(), execute_result_delete]
# Act
result = AppAnnotationService.delete_app_annotations_in_batch(app.id, ["ann-1", "ann-2"])
@@ -915,10 +762,7 @@ class TestAppAnnotationServiceBatchImport:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -941,10 +785,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -968,10 +809,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -999,10 +837,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=2),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1028,10 +863,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=1, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1061,10 +893,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1090,10 +919,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1119,10 +945,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1148,10 +971,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1182,10 +1002,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1218,10 +1035,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
# Act
result = AppAnnotationService.batch_import_app_annotations(app.id, file)
@@ -1257,10 +1071,7 @@ class TestAppAnnotationServiceBatchImport:
new=SimpleNamespace(ANNOTATION_IMPORT_MAX_RECORDS=5, ANNOTATION_IMPORT_MIN_RECORDS=1),
),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = app
mock_redis.zadd.side_effect = RuntimeError("boom")
mock_redis.zrem.side_effect = RuntimeError("cleanup-failed")
@@ -1285,10 +1096,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -1306,15 +1114,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = annotation
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = annotation
mock_db.paginate.return_value = pagination
# Act
@@ -1334,15 +1135,8 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
annotation_query = MagicMock()
annotation_query.where.return_value = annotation_query
annotation_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, annotation_query]
mock_db.session.scalar.return_value = app
mock_db.session.get.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -1352,10 +1146,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
"""Test get_annotation_by_id returns None when not found."""
# Arrange
with patch("services.annotation_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
mock_db.session.get.return_value = None
# Act
result = AppAnnotationService.get_annotation_by_id("ann-1")
@@ -1368,10 +1159,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
# Arrange
annotation = _make_annotation("ann-1")
with patch("services.annotation_service.db") as mock_db:
query = MagicMock()
query.where.return_value = query
query.first.return_value = annotation
mock_db.session.query.return_value = query
mock_db.session.get.return_value = annotation
# Act
result = AppAnnotationService.get_annotation_by_id("ann-1")
@@ -1386,10 +1174,6 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.AppAnnotationHitHistory") as mock_history_cls,
):
query = MagicMock()
query.where.return_value = query
mock_db.session.query.return_value = query
# Act
AppAnnotationService.add_annotation_history(
annotation_id="ann-1",
@@ -1404,7 +1188,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
)
# Assert
query.update.assert_called_once()
mock_db.session.execute.assert_called_once()
mock_history_cls.assert_called_once()
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
@@ -1420,15 +1204,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
@@ -1448,10 +1224,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -1468,15 +1241,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
@@ -1495,15 +1260,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, None]
# Act
result = AppAnnotationService.get_app_annotation_setting_by_app_id(app.id)
@@ -1525,15 +1282,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.naive_utc_now", return_value="now"),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args)
@@ -1560,15 +1309,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.naive_utc_now", return_value="now"),
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = setting
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, setting]
# Act
result = AppAnnotationService.update_app_annotation_setting(app.id, setting.id, args)
@@ -1587,10 +1328,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = None
mock_db.session.query.return_value = app_query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):
@@ -1606,15 +1344,7 @@ class TestAppAnnotationServiceHitHistoryAndSettings:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
app_query = MagicMock()
app_query.where.return_value = app_query
app_query.first.return_value = app
setting_query = MagicMock()
setting_query.where.return_value = setting_query
setting_query.first.return_value = None
mock_db.session.query.side_effect = [app_query, setting_query]
mock_db.session.scalar.side_effect = [app, None]
# Act & Assert
with pytest.raises(NotFound):
@@ -1634,25 +1364,21 @@ class TestAppAnnotationServiceClearAll:
annotation2 = _make_annotation("ann-2")
history = MagicMock(spec=AppAnnotationHitHistory)
def query_side_effect(*args: object, **kwargs: object) -> MagicMock:
query = MagicMock()
query.where.return_value = query
if App in args:
query.first.return_value = app
elif AppAnnotationSetting in args:
query.first.return_value = setting
elif MessageAnnotation in args:
query.yield_per.return_value = [annotation1, annotation2]
elif AppAnnotationHitHistory in args:
query.yield_per.return_value = [history]
return query
with (
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
patch("services.annotation_service.delete_annotation_index_task") as mock_task,
):
mock_db.session.query.side_effect = query_side_effect
# scalar calls: app lookup, annotation_setting lookup
mock_db.session.scalar.side_effect = [app, setting]
# scalars calls: first for annotations iteration, then for each annotation's hit histories
annotations_scalars = MagicMock()
annotations_scalars.yield_per.return_value = [annotation1, annotation2]
histories_scalars_1 = MagicMock()
histories_scalars_1.yield_per.return_value = [history]
histories_scalars_2 = MagicMock()
histories_scalars_2.yield_per.return_value = []
mock_db.session.scalars.side_effect = [annotations_scalars, histories_scalars_1, histories_scalars_2]
# Act
result = AppAnnotationService.clear_all_annotations(app.id)
@@ -1675,10 +1401,7 @@ class TestAppAnnotationServiceClearAll:
patch("services.annotation_service.current_account_with_tenant", return_value=(_make_user(), tenant_id)),
patch("services.annotation_service.db") as mock_db,
):
query = MagicMock()
query.where.return_value = query
query.first.return_value = None
mock_db.session.query.return_value = query
mock_db.session.scalar.return_value = None
# Act & Assert
with pytest.raises(NotFound):

View File

@@ -90,13 +90,13 @@ class TestDocumentServiceQueryAndDownloadHelpers:
result = DocumentService.get_document("dataset-1", None)
assert result is None
mock_db.session.query.assert_not_called()
mock_db.session.scalar.assert_not_called()
def test_get_document_queries_by_dataset_and_document_id(self):
document = DatasetServiceUnitDataFactory.create_document_mock()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.first.return_value = document
mock_db.session.scalar.return_value = document
result = DocumentService.get_document("dataset-1", "doc-1")
@@ -435,7 +435,7 @@ class TestDocumentServiceQueryAndDownloadHelpers:
upload_file = DatasetServiceUnitDataFactory.create_upload_file_mock()
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.one_or_none.return_value = upload_file
mock_db.session.get.return_value = upload_file
result = DocumentService.get_document_file_detail(upload_file.id)
@@ -570,7 +570,7 @@ class TestDocumentServiceMutations:
assert document.name == "New Name"
assert document.doc_metadata[BuiltInField.document_name] == "New Name"
mock_db.session.add.assert_called_once_with(document)
mock_db.session.query.return_value.where.return_value.update.assert_called_once()
mock_db.session.execute.assert_called()
mock_db.session.commit.assert_called_once()
def test_recover_document_raises_when_document_is_not_paused(self):
@@ -624,9 +624,7 @@ class TestDocumentServiceMutations:
document = DatasetServiceUnitDataFactory.create_document_mock(position=7)
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = (
document
)
mock_db.session.scalar.return_value = document
result = DocumentService.get_documents_position("dataset-1")
@@ -634,7 +632,7 @@ class TestDocumentServiceMutations:
def test_get_documents_position_defaults_to_one_when_dataset_is_empty(self):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.filter_by.return_value.order_by.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
result = DocumentService.get_documents_position("dataset-1")
@@ -869,11 +867,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
patch("services.dataset_service.naive_utc_now", return_value="now"),
patch("services.dataset_service.document_indexing_update_task") as update_task,
):
upload_query = MagicMock()
upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt")
segment_query = MagicMock()
segment_query.filter_by.return_value.update.return_value = 3
mock_db.session.query.side_effect = [upload_query, segment_query]
mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt")
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
@@ -892,7 +886,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
assert document.created_from == "web"
assert document.doc_form == IndexStructureType.QA_INDEX
assert mock_db.session.commit.call_count == 3
segment_query.filter_by.return_value.update.assert_called_once()
mock_db.session.execute.assert_called()
update_task.delay.assert_called_once_with(document.dataset_id, document.id)
def test_update_document_with_dataset_id_notion_import_requires_binding(self, account_context):
@@ -920,9 +914,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
patch.object(DatasetService, "check_dataset_model_setting"),
patch("services.dataset_service.db") as mock_db,
):
binding_query = MagicMock()
binding_query.where.return_value.first.return_value = None
mock_db.session.query.return_value = binding_query
mock_db.session.scalar.return_value = None
with pytest.raises(ValueError, match="Data source binding not found"):
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
@@ -954,10 +946,6 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
patch("services.dataset_service.naive_utc_now", return_value="now"),
patch("services.dataset_service.document_indexing_update_task") as update_task,
):
segment_query = MagicMock()
segment_query.filter_by.return_value.update.return_value = 2
mock_db.session.query.return_value = segment_query
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
assert result is document
@@ -968,7 +956,7 @@ class TestDocumentServiceUpdateDocumentWithDatasetId:
)
assert document.name == ""
assert document.doc_form == IndexStructureType.PARENT_CHILD_INDEX
segment_query.filter_by.return_value.update.assert_called_once()
mock_db.session.execute.assert_called()
update_task.delay.assert_called_once_with("dataset-1", "doc-1")
@@ -1218,11 +1206,10 @@ class TestDocumentServiceSaveDocumentWithDatasetId:
patch("services.dataset_service.secrets.randbelow", return_value=23),
):
mock_redis.lock.return_value = _make_lock_context()
upload_query = MagicMock()
upload_query.where.return_value.all.return_value = [upload_file_a, upload_file_b]
existing_documents_query = MagicMock()
existing_documents_query.where.return_value.all.return_value = [duplicate_document]
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
mock_db.session.scalars.return_value.all.side_effect = [
[upload_file_a, upload_file_b],
[duplicate_document],
]
documents, batch = DocumentService.save_document_with_dataset_id(
dataset,
@@ -1302,9 +1289,7 @@ class TestDocumentServiceSaveDocumentWithDatasetId:
patch("services.dataset_service.DocumentIndexingTaskProxy") as document_proxy_cls,
):
mock_redis.lock.return_value = _make_lock_context()
notion_documents_query = MagicMock()
notion_documents_query.filter_by.return_value.all.return_value = [existing_keep, existing_remove]
mock_db.session.query.return_value = notion_documents_query
mock_db.session.scalars.return_value.all.return_value = [existing_keep, existing_remove]
documents, _ = DocumentService.save_document_with_dataset_id(
dataset,
@@ -1474,12 +1459,11 @@ class TestDocumentServiceTenantAndUpdateEdges:
def test_get_tenant_documents_count_returns_query_count(self, account_context):
with patch("services.dataset_service.db") as mock_db:
mock_db.session.query.return_value.where.return_value.count.return_value = 12
mock_db.session.scalar.return_value = 12
result = DocumentService.get_tenant_documents_count()
assert result == 12
mock_db.session.query.return_value.where.return_value.count.assert_called_once()
def test_update_document_with_dataset_id_uses_automatic_process_rule_payload(self, account_context):
dataset = SimpleNamespace(id="dataset-1", tenant_id="tenant-1")
@@ -1514,11 +1498,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
):
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
process_rule_cls.return_value = created_process_rule
upload_query = MagicMock()
upload_query.where.return_value.first.return_value = SimpleNamespace(id="file-1", name="upload.txt")
segment_query = MagicMock()
segment_query.filter_by.return_value.update.return_value = 1
mock_db.session.query.side_effect = [upload_query, segment_query]
mock_db.session.scalar.return_value = SimpleNamespace(id="file-1", name="upload.txt")
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
@@ -1567,7 +1547,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
patch.object(DatasetService, "check_dataset_model_setting"),
patch("services.dataset_service.db") as mock_db,
):
mock_db.session.query.return_value.where.return_value.first.return_value = None
mock_db.session.scalar.return_value = None
with pytest.raises(FileNotExistsError):
DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
@@ -1618,11 +1598,7 @@ class TestDocumentServiceTenantAndUpdateEdges:
patch("services.dataset_service.naive_utc_now", return_value="now"),
patch("services.dataset_service.document_indexing_update_task") as update_task,
):
binding_query = MagicMock()
binding_query.where.return_value.first.return_value = SimpleNamespace(id="binding-1")
segment_query = MagicMock()
segment_query.filter_by.return_value.update.return_value = 1
mock_db.session.query.side_effect = [binding_query, segment_query]
mock_db.session.scalar.return_value = SimpleNamespace(id="binding-1")
result = DocumentService.update_document_with_dataset_id(dataset, document_data, account_context)
@@ -1914,11 +1890,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
):
mock_redis.lock.return_value = _make_lock_context()
process_rule_cls.return_value = created_process_rule
upload_query = MagicMock()
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
existing_documents_query = MagicMock()
existing_documents_query.where.return_value.all.return_value = []
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
documents, batch = DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
@@ -1958,11 +1930,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
mock_redis.lock.return_value = _make_lock_context()
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
process_rule_cls.return_value = created_process_rule
upload_query = MagicMock()
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
existing_documents_query = MagicMock()
existing_documents_query.where.return_value.all.return_value = []
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
@@ -1996,11 +1964,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
mock_redis.lock.return_value = _make_lock_context()
process_rule_cls.AUTOMATIC_RULES = DatasetProcessRule.AUTOMATIC_RULES
process_rule_cls.return_value = created_process_rule
upload_query = MagicMock()
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
existing_documents_query = MagicMock()
existing_documents_query.where.return_value.all.return_value = []
mock_db.session.query.side_effect = [upload_query, existing_documents_query]
mock_db.session.scalars.return_value.all.side_effect = [[SimpleNamespace(id="file-1", name="file.txt")], []]
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)
@@ -2024,9 +1988,7 @@ class TestDocumentServiceSaveDocumentAdditionalBranches:
patch("services.dataset_service.secrets.randbelow", return_value=23),
):
mock_redis.lock.return_value = _make_lock_context()
upload_query = MagicMock()
upload_query.where.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
mock_db.session.query.return_value = upload_query
mock_db.session.scalars.return_value.all.return_value = [SimpleNamespace(id="file-1", name="file.txt")]
with pytest.raises(FileNotExistsError, match="One or more files not found"):
DocumentService.save_document_with_dataset_id(dataset, knowledge_config, account_context)