Merge remote-tracking branch 'origin/main' into feat/trigger

This commit is contained in:
yessenia
2025-09-25 17:14:24 +08:00
3013 changed files with 148826 additions and 44294 deletions

View File

@@ -38,7 +38,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.exception("Delete annotation index failed when annotation deleted.")
end_at = time.perf_counter()
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception as e:
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
@@ -39,7 +40,7 @@ def enable_annotation_reply_task(
db.session.close()
return
annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"
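
Most hunks in this commit apply the same mechanical migration from the legacy `Query` API to SQLAlchemy 2.0-style `select()` executed via `session.scalars()`. A minimal, self-contained sketch of the two equivalent forms, assuming SQLAlchemy 2.0 is installed (the model here is illustrative, not one of Dify's):

```python
from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MessageAnnotation(Base):
    __tablename__ = "message_annotations"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    session.add(MessageAnnotation(app_id="app-1"))
    session.commit()

    # legacy 1.x-style Query API
    legacy = session.query(MessageAnnotation).where(MessageAnnotation.app_id == "app-1").all()

    # 2.0-style: build a Select and unwrap single-column rows with scalars()
    modern = session.scalars(
        select(MessageAnnotation).where(MessageAnnotation.app_id == "app-1")
    ).all()

    assert [a.id for a in legacy] == [a.id for a in modern]
```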

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
"""
Clean document when document deleted.
:param document_ids: document ids
@@ -29,12 +30,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
start_at = time.perf_counter()
try:
if not doc_form:
raise ValueError("doc_form is required")
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
@@ -59,7 +64,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
db.session.commit()
if file_ids:
files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)

View File

@@ -79,7 +79,7 @@ def batch_create_segment_to_index_task(
# Skip the first row
df = pd.read_csv(file_path)
content = []
for index, row in df.iterrows():
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
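
The change from `for index, row` to `for _, row` only signals that the positional index is unused; a small illustrative sketch of the same QA-row extraction (column values are made up):

```python
import pandas as pd

df = pd.DataFrame({"question": ["q1", "q2"], "answer": ["a1", "a2"]})
content = []
for _, row in df.iterrows():  # underscore: the row index is intentionally unused
    content.append({"content": row.iloc[0], "answer": row.iloc[1]})

assert content[0] == {"content": "q1", "answer": "a1"}
```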

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -55,8 +56,8 @@ def clean_dataset_task(
index_struct=index_struct,
collection_binding_id=collection_binding_id,
)
documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
# Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
# This ensures all invalid doc_form values are properly handled
@@ -75,7 +76,7 @@ def clean_dataset_task(
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
except Exception as index_cleanup_error:
except Exception:
logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
# Continue with document and segment deletion even if vector cleanup fails
logger.info(
@@ -145,7 +146,7 @@ def clean_dataset_task(
try:
db.session.rollback()
logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
except Exception as rollback_error:
except Exception:
logger.exception("Failed to rollback database session")
logger.exception("Cleaned dataset when dataset deleted failed")

View File

@@ -1,9 +1,9 @@
import logging
import time
from typing import Optional
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]):
def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None):
"""
Clean document when document deleted.
:param document_id: document id
@@ -35,7 +35,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
if not dataset:
raise Exception("Document has no dataset")
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
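
Several signatures in this commit swap `typing.Optional[X]` for the PEP 604 union `X | None`. The two annotations are equivalent, and the union form needs no `typing` import on Python 3.10+; an illustrative sketch (not Dify code):

```python
from typing import Optional


def clean_old(document_id: str, file_id: Optional[str] = None) -> bool:
    return file_id is not None


def clean_new(document_id: str, file_id: str | None = None) -> bool:
    return file_id is not None


assert clean_old("doc-1") is False and clean_new("doc-1") is False
```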

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
document = db.session.query(Document).where(Document.id == document_id).first()
db.session.delete(document)
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

View File

@@ -1,6 +1,5 @@
import logging
import time
from typing import Optional
import click
from celery import shared_task
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None):
def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None):
"""
Async create segment to index
:param segment_id:

View File

@@ -0,0 +1,171 @@
import logging
import time
import click
from celery import shared_task # type: ignore
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument
@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
"""
Async deal dataset from index
:param dataset_id: dataset_id
:param action: action
Usage: deal_dataset_index_update_task.delay(dataset_id, action)
"""
logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
start_at = time.perf_counter()
try:
dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if action == "upgrade":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
for dataset_document in dataset_documents:
try:
# add from vector index
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
documents.append(document)
# save vector index
# clean keywords
index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
# add new index
if dataset_documents:
# update document status
dataset_documents_ids = [doc.id for doc in dataset_documents]
db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
{"indexing_status": "indexing"}, synchronize_session=False
)
db.session.commit()
# clean index
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
for dataset_document in dataset_documents:
# update from vector index
try:
segments = (
db.session.query(DocumentSegment)
.where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
.order_by(DocumentSegment.position.asc())
.all()
)
if segments:
documents = []
for segment in segments:
document = Document(
page_content=segment.content,
metadata={
"doc_id": segment.index_node_id,
"doc_hash": segment.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
child_chunks = segment.get_child_chunks()
if child_chunks:
child_documents = []
for child_chunk in child_chunks:
child_document = ChildDocument(
page_content=child_chunk.content,
metadata={
"doc_id": child_chunk.index_node_id,
"doc_hash": child_chunk.index_node_hash,
"document_id": segment.document_id,
"dataset_id": segment.dataset_id,
},
)
child_documents.append(child_document)
document.children = child_documents
documents.append(document)
# save vector index
index_processor.load(dataset, documents, with_keywords=False)
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "completed"}, synchronize_session=False
)
db.session.commit()
except Exception as e:
db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
{"indexing_status": "error", "error": str(e)}, synchronize_session=False
)
db.session.commit()
else:
# clean collection
index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)
end_at = time.perf_counter()
logging.info(
click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
)
except Exception:
logging.exception("Deal dataset vector index failed")
finally:
db.session.close()

View File

@@ -4,6 +4,7 @@ from typing import Literal
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
if action == "remove":
index_processor.clean(dataset, None, with_keywords=False)
elif action == "add":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
if dataset_documents:
dataset_documents_ids = [doc.id for doc in dataset_documents]
@@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
)
db.session.commit()
elif action == "update":
dataset_documents = (
db.session.query(DatasetDocument)
.where(
dataset_documents = db.session.scalars(
select(DatasetDocument).where(
DatasetDocument.dataset_id == dataset_id,
DatasetDocument.indexing_status == "completed",
DatasetDocument.enabled == True,
DatasetDocument.archived == False,
)
.all()
)
).all()
# add new index
if dataset_documents:
# update document status

View File

@@ -15,7 +15,7 @@ def delete_account_task(account_id):
account = db.session.query(Account).where(Account.id == account_id).first()
try:
BillingService.delete_account(account_id)
except Exception as e:
except Exception:
logger.exception("Failed to delete account %s from billing service.", account_id)
raise

View File

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="conversation")
def delete_conversation_related_data(conversation_id: str) -> None:
def delete_conversation_related_data(conversation_id: str):
"""
Delete conversation related data in the correct order from the database to respect foreign key constraints


View File

@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
def delete_segment_from_index_task(
index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
):
"""
Async Remove segment from index
:param index_node_ids:
@@ -26,6 +28,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
try:
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
return
dataset_document = db.session.query(Document).where(Document.id == document_id).first()
@@ -33,11 +36,19 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
return
if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
logging.info("Document not in valid state for index operations, skipping")
return
doc_form = dataset_document.doc_form
index_type = dataset_document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
# Proceed with index cleanup using the index_node_ids directly
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset,
index_node_ids,
with_keywords=True,
delete_child_chunks=True,
precomputed_child_node_ids=child_node_ids,
)
end_at = time.perf_counter()
logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()
if not segments:
db.session.close()

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.extractor.notion_extractor import NotionExtractor
@@ -46,6 +47,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
data_source_binding = (
db.session.query(DataSourceOauthBinding)
.where(
@@ -85,7 +87,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from configs import dify_config
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -27,73 +28,77 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
documents = []
start_at = time.perf_counter()
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
if 0 < vector_space.limit <= vector_space.size:
raise ValueError(
"Your total number of documents plus the number of uploads have over the limit of "
"your subscription."
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset is None:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
db.session.close()
return
# check document limit
features = FeatureService.get_features(dataset.tenant_id)
try:
if features.billing.enabled:
vector_space = features.vector_space
count = len(document_ids)
if features.billing.subscription.plan == "sandbox" and count > 1:
raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
if count > batch_upload_limit:
raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
current = int(getattr(vector_space, "size", 0) or 0)
limit = int(getattr(vector_space, "limit", 0) or 0)
if limit > 0 and (current + count) > limit:
raise ValueError(
"Your total number of documents plus the number of uploads have exceeded the limit of "
"your subscription."
)
except Exception as e:
for document_id in document_ids:
document = (
db.session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
except Exception as e:
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
db.session.add(document)
db.session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
return
finally:
db.session.close()
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
if document:
# clean old data
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
for segment in segments:
db.session.delete(segment)
db.session.commit()
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
db.session.add(document)
db.session.commit()
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
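
The reworked billing check in `duplicate_document_indexing_task` rejects a batch only when a positive vector-space limit would actually be exceeded by the new documents. A tiny illustrative sketch of that arithmetic (the helper function and values here are hypothetical):

```python
def exceeds_quota(size: int | None, limit: int | None, count: int) -> bool:
    current = int(size or 0)
    cap = int(limit or 0)
    # a cap of 0 (or None) means "no limit"; otherwise reject when the batch would overflow it
    return cap > 0 and (current + count) > cap


assert exceeds_quota(size=9, limit=10, count=2) is True     # 9 + 2 > 10 -> reject
assert exceeds_quota(size=9, limit=10, count=1) is False    # landing exactly on the limit is allowed
assert exceeds_quota(size=999, limit=0, count=50) is False  # unlimited plan
```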

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
# sync index processor
index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
segments = (
db.session.query(DocumentSegment)
.where(
segments = db.session.scalars(
select(DocumentSegment).where(
DocumentSegment.id.in_(segment_ids),
DocumentSegment.dataset_id == dataset_id,
DocumentSegment.document_id == document_id,
)
.all()
)
).all()
if not segments:
logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
db.session.close()

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_deletion_success_task(to: str, language: str = "en-US") -> None:
def send_deletion_success_task(to: str, language: str = "en-US"):
"""
Send account deletion success email with internationalization support.
@@ -46,7 +46,7 @@ def send_deletion_success_task(to: str, language: str = "en-US") -> None:
@shared_task(queue="mail")
def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None:
def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US"):
"""
Send account deletion verification code email with internationalization support.

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None:
def send_change_mail_task(language: str, to: str, code: str, phase: str):
"""
Send change email notification with internationalization support.
@@ -43,7 +43,7 @@ def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None
@shared_task(queue="mail")
def send_change_mail_completed_notification_task(language: str, to: str) -> None:
def send_change_mail_completed_notification_task(language: str, to: str):
"""
Send change email completed notification with internationalization support.

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_email_code_login_mail_task(language: str, to: str, code: str) -> None:
def send_email_code_login_mail_task(language: str, to: str, code: str):
"""
Send email code login email with internationalization support.

View File

@@ -1,17 +1,46 @@
import logging
import time
from collections.abc import Mapping
from typing import Any
import click
from celery import shared_task
from flask import render_template_string
from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment
from configs import dify_config
from configs.feature import TemplateMode
from extensions.ext_mail import mail
from libs.email_i18n import get_email_i18n_service
logger = logging.getLogger(__name__)
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
self._timeout_time = time.time() + timeout
super().__init__(*args, **kwargs)
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
if time.time() > self._timeout_time:
raise TimeoutError("Template rendering timeout")
return super().call(context, obj, *args, **kwargs)
def _render_template_with_strategy(body: str, substitutions: Mapping[str, str]) -> str:
mode = dify_config.MAIL_TEMPLATING_MODE
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
if mode == TemplateMode.UNSAFE:
return render_template_string(body, **substitutions)
if mode == TemplateMode.SANDBOX:
tmpl = SandboxedEnvironment(timeout=timeout).from_string(body)
return tmpl.render(substitutions)
if mode == TemplateMode.DISABLED:
return body
raise ValueError(f"Unsupported mail templating mode: {mode}")
@shared_task(queue="mail")
def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
if not mail.is_inited():
@@ -21,7 +50,7 @@ def send_inner_email_task(to: list[str], subject: str, body: str, substitutions:
start_at = time.perf_counter()
try:
html_content = render_template_string(body, **substitutions)
html_content = _render_template_with_strategy(body, substitutions)
email_service = get_email_i18n_service()
email_service.send_raw_email(to=to, subject=subject, html_content=html_content)
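
The new `_render_template_with_strategy` helper routes rendering through an `ImmutableSandboxedEnvironment` subclass that enforces a wall-clock deadline on template calls. A self-contained sketch of that pattern, assuming only `jinja2` is installed (class and variable names here are illustrative):

```python
import time
from typing import Any

from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment


class TimeBoxedSandbox(ImmutableSandboxedEnvironment):
    def __init__(self, timeout: float, *args: Any, **kwargs: Any):
        self._deadline = time.time() + timeout
        super().__init__(*args, **kwargs)

    def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
        # abort any callable invocation once the rendering deadline has passed
        if time.time() > self._deadline:
            raise TimeoutError("Template rendering timeout")
        return super().call(context, obj, *args, **kwargs)


env = TimeBoxedSandbox(timeout=5.0)
template = env.from_string("Hello {{ name }}, your code is {{ code }}.")
print(template.render({"name": "Ada", "code": "123456"}))
```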

View File

@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None:
def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
"""
Send invite member email with internationalization support.

View File

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None:
def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str):
"""
Send owner transfer confirmation email with internationalization support.
@@ -52,7 +52,7 @@ def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspac
@shared_task(queue="mail")
def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None:
def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str):
"""
Send old owner transfer notification email with internationalization support.
@@ -93,7 +93,7 @@ def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace:
@shared_task(queue="mail")
def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None:
def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str):
"""
Send new owner transfer notification email with internationalization support.

View File

@@ -0,0 +1,87 @@
import logging
import time
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_email_register_mail_task(language: str, to: str, code: str) -> None:
"""
Send email register email with internationalization support.
Args:
language: Language code for email localization
to: Recipient email address
code: Email register code
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start email register mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_REGISTER,
language_code=language,
to=to,
template_context={
"to": to,
"code": code,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send email register mail to %s failed", to)
@shared_task(queue="mail")
def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None:
"""
Send email register email with internationalization support when the account already exists.
Args:
language: Language code for email localization
to: Recipient email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start email register mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
login_url = f"{dify_config.CONSOLE_WEB_URL}/signin"
reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password"
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
language_code=language,
to=to,
template_context={
"to": to,
"login_url": login_url,
"reset_password_url": reset_password_url,
"account_name": account_name,
},
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send email register mail to %s failed", to)

View File

@@ -4,6 +4,7 @@ import time
import click
from celery import shared_task
from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service
@@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
@shared_task(queue="mail")
def send_reset_password_mail_task(language: str, to: str, code: str) -> None:
def send_reset_password_mail_task(language: str, to: str, code: str):
"""
Send reset password email with internationalization support.
@@ -44,3 +45,47 @@ def send_reset_password_mail_task(language: str, to: str, code: str) -> None:
)
except Exception:
logger.exception("Send password reset mail to %s failed", to)
@shared_task(queue="mail")
def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None:
"""
Send reset password email with internationalization support when the account does not exist.
Args:
language: Language code for email localization
to: Recipient email address
"""
if not mail.is_inited():
return
logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
start_at = time.perf_counter()
try:
if is_allow_register:
sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup"
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST,
language_code=language,
to=to,
template_context={
"to": to,
"sign_up_url": sign_up_url,
},
)
else:
email_service = get_email_i18n_service()
email_service.send_email(
email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER,
language_code=language,
to=to,
)
end_at = time.perf_counter()
logger.info(
click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
)
except Exception:
logger.exception("Send password reset mail to %s failed", to)

View File

@@ -1,3 +1,4 @@
import operator
import traceback
import typing
@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
current_version = version
latest_version = manifest.latest_version
def fix_only_checker(latest_version, current_version):
def fix_only_checker(latest_version: str, current_version: str):
latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
current_version_tuple = tuple(int(val) for val in current_version.split("."))
@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
return False
version_checker = {
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
current_version: latest_version != current_version,
TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
}
@@ -146,7 +146,7 @@ def process_tenant_plugin_autoupgrade_check_task(
fg="green",
)
)
task_start_resp = manager.upgrade_plugin(
_ = manager.upgrade_plugin(
tenant_id,
original_unique_identifier,
new_unique_identifier,
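
Replacing the inline lambda with `operator.ne` keeps the LATEST strategy's behaviour: any difference between the installed and latest version counts as upgradeable. Illustrative check:

```python
import operator

# equivalent to: lambda latest_version, current_version: latest_version != current_version
assert operator.ne("1.2.4", "1.2.3") is True   # versions differ -> upgrade candidate
assert operator.ne("1.2.3", "1.2.3") is False  # already on the latest version
```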

View File

@@ -0,0 +1,175 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
# run with threading, thread pool size is 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine, expire_on_commit=False) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = (
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in priority pipeline task")
raise

View File

@@ -0,0 +1,196 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any
import click
from celery import shared_task # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService
@shared_task(queue="pipeline")
def rag_pipeline_run_task(
rag_pipeline_invoke_entities_file_id: str,
tenant_id: str,
):
"""
Async Run rag pipeline
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
rag_pipeline_invoke_entities include:
:param pipeline_id: Pipeline ID
:param user_id: User ID
:param tenant_id: Tenant ID
:param workflow_id: Workflow ID
:param invoke_from: Invoke source (debugger, published, etc.)
:param streaming: Whether to stream results
:param datasource_type: Type of datasource
:param datasource_info: Datasource information dict
:param batch: Batch identifier
:param document_id: Document ID (optional)
:param start_node_id: Starting node ID
:param inputs: Input parameters dict
:param workflow_execution_id: Workflow execution ID
:param workflow_thread_pool_id: Thread pool ID for workflow execution
"""
# run with threading, thread pool size is 10
try:
start_at = time.perf_counter()
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
rag_pipeline_invoke_entities_file_id
)
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
# Get Flask app object for thread context
flask_app = current_app._get_current_object() # type: ignore
with ThreadPoolExecutor(max_workers=10) as executor:
futures = []
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
# Submit task to thread pool with Flask app
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
futures.append(future)
# Wait for all tasks to complete
for future in futures:
try:
future.result() # This will raise any exceptions that occurred in the thread
except Exception:
logging.exception("Error in pipeline task")
end_at = time.perf_counter()
logging.info(
click.style(
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
)
)
except Exception:
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
raise
finally:
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
# Check if there are waiting tasks in the queue
# Use rpop to get the next task from the queue (FIFO order)
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
if next_file_id:
# Process the next waiting task
# Keep the flag set to indicate a task is running
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
rag_pipeline_run_task.delay( # type: ignore
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
if isinstance(next_file_id, bytes)
else next_file_id,
tenant_id=tenant_id,
)
else:
# No more waiting tasks, clear the flag
redis_client.delete(tenant_pipeline_task_key)
file_service = FileService(db.engine)
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
db.session.close()
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
"""Run a single RAG pipeline task within Flask app context."""
# Create Flask application context for this thread
with flask_app.app_context():
try:
rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
user_id = rag_pipeline_invoke_entity_model.user_id
tenant_id = rag_pipeline_invoke_entity_model.tenant_id
pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
workflow_id = rag_pipeline_invoke_entity_model.workflow_id
streaming = rag_pipeline_invoke_entity_model.streaming
workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity
with Session(db.engine) as session:
# Load required entities
account = session.query(Account).where(Account.id == user_id).first()
if not account:
raise ValueError(f"Account {user_id} not found")
tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
if not tenant:
raise ValueError(f"Tenant {tenant_id} not found")
account.current_tenant = tenant
pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
if not pipeline:
raise ValueError(f"Pipeline {pipeline_id} not found")
workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
if not workflow:
raise ValueError(f"Workflow {pipeline.workflow_id} not found")
if workflow_execution_id is None:
workflow_execution_id = str(uuid.uuid4())
# Create application generate entity from dict
entity = RagPipelineGenerateEntity(**application_generate_entity)
# Create workflow repositories
session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
)
workflow_node_execution_repository = (
DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
session_factory=session_factory,
user=account,
app_id=entity.app_config.app_id,
triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
)
)
# Set the user directly in g for preserve_flask_contexts
g._login_user = account
# Copy context for passing to pipeline generator
context = contextvars.copy_context()
# Direct execution without creating another thread
# Since we're already in a thread pool, no need for nested threading
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
pipeline_generator = PipelineGenerator()
# Using protected method intentionally for async execution
pipeline_generator._generate( # type: ignore[attr-defined]
flask_app=flask_app,
context=context,
pipeline=pipeline,
workflow_id=workflow_id,
user=account,
application_generate_entity=entity,
invoke_from=InvokeFrom.PUBLISHED,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
workflow_thread_pool_id=workflow_thread_pool_id,
)
except Exception:
logging.exception("Error in pipeline task")
raise
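
The `finally` block of `rag_pipeline_run_task` chains per-tenant jobs through Redis: a list acts as a FIFO queue and a keyed flag marks that a run is in flight. A hedged sketch of that hand-off using a plain `redis` client; key names mirror the task, but the producer side (`lpush`) is an assumption about how waiting jobs are enqueued elsewhere:

```python
import redis

r = redis.Redis()
tenant_id = "tenant-1"
queue_key = f"tenant_self_pipeline_task_queue:{tenant_id}"
flag_key = f"tenant_pipeline_task:{tenant_id}"

# producer side (assumed): if a run is already in flight, park the next file id in the FIFO
if r.get(flag_key):
    r.lpush(queue_key, "invoke-entities-file-2")
else:
    r.setex(flag_key, 60 * 60, 1)  # mark a run as in flight for up to an hour
    # ... dispatch the Celery task for "invoke-entities-file-2" here ...

# consumer side (mirrors the task's finally block): hand off to the next job or clear the flag
next_file_id = r.rpop(queue_key)  # lpush + rpop => first-in, first-out
if next_file_id:
    r.setex(flag_key, 60 * 60, 1)
    # ... re-dispatch rag_pipeline_run_task with next_file_id ...
else:
    r.delete(flag_key)
```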

View File

@@ -355,6 +355,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
"""
Delete draft variables for an app in batches.
This function now handles cleanup of associated Offload data including:
- WorkflowDraftVariableFile records
- UploadFile records
- Object storage files
Args:
app_id: The ID of the app whose draft variables should be deleted
batch_size: Number of records to delete per batch
@@ -366,22 +371,31 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
raise ValueError("batch_size must be positive")
total_deleted = 0
total_files_deleted = 0
while True:
with db.engine.begin() as conn:
# Get a batch of draft variable IDs
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id FROM workflow_draft_variables
SELECT id, file_id FROM workflow_draft_variables
WHERE app_id = :app_id
LIMIT :batch_size
"""
result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})
draft_var_ids = [row[0] for row in result]
if not draft_var_ids:
rows = list(result)
if not rows:
break
# Delete the batch
draft_var_ids = [row[0] for row in rows]
file_ids = [row[1] for row in rows if row[1] is not None]
# Clean up associated Offload data first
if file_ids:
files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
total_files_deleted += files_deleted
# Delete the draft variables
delete_sql = """
DELETE FROM workflow_draft_variables
WHERE id IN :ids
@@ -392,10 +406,85 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))
logger.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
logger.info(
click.style(
f"Deleted {total_deleted} total draft variables for app {app_id}. "
f"Cleaned up {total_files_deleted} total associated files.",
fg="green",
)
)
return total_deleted
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
"""
Delete Offload data associated with WorkflowDraftVariable file_ids.
This function:
1. Finds WorkflowDraftVariableFile records by file_ids
2. Deletes associated files from object storage
3. Deletes UploadFile records
4. Deletes WorkflowDraftVariableFile records
Args:
conn: Database connection
file_ids: List of WorkflowDraftVariableFile IDs
Returns:
Number of files cleaned up
"""
from extensions.ext_storage import storage
if not file_ids:
return 0
files_deleted = 0
try:
# Get WorkflowDraftVariableFile records and their associated UploadFile keys
query_sql = """
SELECT wdvf.id, uf.key, uf.id as upload_file_id
FROM workflow_draft_variable_files wdvf
JOIN upload_files uf ON wdvf.upload_file_id = uf.id
WHERE wdvf.id IN :file_ids
"""
result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
file_records = list(result)
# Delete from object storage and collect upload file IDs
upload_file_ids = []
for _, storage_key, upload_file_id in file_records:
try:
storage.delete(storage_key)
upload_file_ids.append(upload_file_id)
files_deleted += 1
except Exception:
logging.exception("Failed to delete storage object %s", storage_key)
# Continue with database cleanup even if storage deletion fails
upload_file_ids.append(upload_file_id)
# Delete UploadFile records
if upload_file_ids:
delete_upload_files_sql = """
DELETE FROM upload_files
WHERE id IN :upload_file_ids
"""
conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})
# Delete WorkflowDraftVariableFile records
delete_variable_files_sql = """
DELETE FROM workflow_draft_variable_files
WHERE id IN :file_ids
"""
conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})
except Exception:
logging.exception("Error deleting draft variable offload data:")
# Don't raise, as we want to continue with the main deletion process
return files_deleted
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
while True:
with db.engine.begin() as conn:

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str):
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
if index_node_ids:
try:

View File

@@ -3,26 +3,30 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService
logger = logging.getLogger(__name__)
@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
"""
Async process document
:param dataset_id:
:param document_ids:
:param user_id:
Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
"""
start_at = time.perf_counter()
try:
@@ -30,11 +34,19 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
if not dataset:
logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
return
tenant_id = dataset.tenant_id
user = db.session.query(Account).where(Account.id == user_id).first()
if not user:
logger.info(click.style(f"User not found: {user_id}", fg="red"))
return
tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
if not tenant:
raise ValueError("Tenant not found")
user.current_tenant = tenant
for document_id in document_ids:
retry_indexing_cache_key = f"document_{document_id}_is_retried"
# check document limit
features = FeatureService.get_features(tenant_id)
features = FeatureService.get_features(tenant.id)
try:
if features.billing.enabled:
vector_space = features.vector_space
@@ -69,7 +81,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
@@ -84,8 +98,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
db.session.add(document)
db.session.commit()
indexing_runner = IndexingRunner()
indexing_runner.run([document])
if dataset.runtime_mode == "rag_pipeline":
rag_pipeline_service = RagPipelineService()
rag_pipeline_service.retry_error_document(dataset, document, user)
else:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
redis_client.delete(retry_indexing_cache_key)
except Exception as ex:
document.indexing_status = "error"

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import select
from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
# clean old data
index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index

View File

@@ -0,0 +1,27 @@
"""
Celery tasks for asynchronous workflow execution storage operations.
These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""
import logging
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
_logger = logging.getLogger(__name__)
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@shared_task(queue="workflow_draft_var", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -120,7 +120,7 @@ def _create_workflow_run_from_execution(
return workflow_run
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution):
"""
Update a WorkflowRun database model from a WorkflowExecution domain entity.
"""

View File

@@ -140,9 +140,7 @@ def _create_node_execution_from_domain(
return node_execution
def _update_node_execution_from_domain(
node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
"""
Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
"""