Mirror of https://github.com/langgenius/dify.git
Merge remote-tracking branch 'origin/main' into feat/trigger
@@ -38,7 +38,7 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
                 logger.exception("Delete annotation index failed when annotation deleted.")
         end_at = time.perf_counter()
         logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
-    except Exception as e:
+    except Exception:
         logger.exception("Annotation deleted index failed")
     finally:
         db.session.close()

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.datasource.vdb.vector_factory import Vector
 from core.rag.models.document import Document
@@ -39,7 +40,7 @@ def enable_annotation_reply_task(
         db.session.close()
         return
 
-    annotations = db.session.query(MessageAnnotation).where(MessageAnnotation.app_id == app_id).all()
+    annotations = db.session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == app_id)).all()
     enable_app_annotation_key = f"enable_app_annotation_{str(app_id)}"
     enable_app_annotation_job_key = f"enable_app_annotation_job_{str(job_id)}"

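Note: most of the mechanical churn in this merge is one migration repeated across the task modules: SQLAlchemy's legacy session.query() API is replaced by a 2.0-style select() statement unwrapped with Session.scalars(). A self-contained sketch of the two equivalent forms (the in-memory model and SQLite engine are illustrative, not part of the diff):

from sqlalchemy import String, create_engine, select
from sqlalchemy.orm import DeclarativeBase, Mapped, Session, mapped_column


class Base(DeclarativeBase):
    pass


class MessageAnnotation(Base):
    __tablename__ = "message_annotations"
    id: Mapped[int] = mapped_column(primary_key=True)
    app_id: Mapped[str] = mapped_column(String(36))


engine = create_engine("sqlite://")
Base.metadata.create_all(engine)

with Session(engine) as session:
    # Legacy 1.x style: Query yields ORM entities directly.
    legacy = session.query(MessageAnnotation).where(MessageAnnotation.app_id == "app-1").all()
    # 2.0 style used throughout this merge: build a Select, then unwrap
    # single-entity rows with Session.scalars().
    modern = session.scalars(select(MessageAnnotation).where(MessageAnnotation.app_id == "app-1")).all()
    assert legacy == modern  # both are lists of MessageAnnotation objects
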
@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -15,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str, file_ids: list[str]):
+def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
     """
     Clean document when document deleted.
     :param document_ids: document ids
@@ -29,12 +30,16 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
     start_at = time.perf_counter()
 
     try:
+        if not doc_form:
+            raise ValueError("doc_form is required")
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
 
         if not dataset:
             raise Exception("Document has no dataset")
 
-        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids)).all()
+        segments = db.session.scalars(
+            select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
+        ).all()
         # check segment is exist
         if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
@@ -59,7 +64,7 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
 
             db.session.commit()
         if file_ids:
-            files = db.session.query(UploadFile).where(UploadFile.id.in_(file_ids)).all()
+            files = db.session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
             for file in files:
                 try:
                     storage.delete(file.key)

@@ -79,7 +79,7 @@ def batch_create_segment_to_index_task(
         # Skip the first row
         df = pd.read_csv(file_path)
         content = []
-        for index, row in df.iterrows():
+        for _, row in df.iterrows():
             if dataset_document.doc_form == "qa_model":
                 data = {"content": row.iloc[0], "answer": row.iloc[1]}
             else:

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -55,8 +56,8 @@ def clean_dataset_task(
             index_struct=index_struct,
             collection_binding_id=collection_binding_id,
         )
-        documents = db.session.query(Document).where(Document.dataset_id == dataset_id).all()
-        segments = db.session.query(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id).all()
+        documents = db.session.scalars(select(Document).where(Document.dataset_id == dataset_id)).all()
+        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.dataset_id == dataset_id)).all()
 
         # Enhanced validation: Check if doc_form is None, empty string, or contains only whitespace
         # This ensures all invalid doc_form values are properly handled
@@ -75,7 +76,7 @@ def clean_dataset_task(
             index_processor = IndexProcessorFactory(doc_form).init_index_processor()
             index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=True)
             logger.info(click.style(f"Successfully cleaned vector database for dataset: {dataset_id}", fg="green"))
-        except Exception as index_cleanup_error:
+        except Exception:
             logger.exception(click.style(f"Failed to clean vector database for dataset {dataset_id}", fg="red"))
             # Continue with document and segment deletion even if vector cleanup fails
             logger.info(
@@ -145,7 +146,7 @@ def clean_dataset_task(
         try:
             db.session.rollback()
             logger.info(click.style(f"Rolled back database session for dataset: {dataset_id}", fg="yellow"))
-        except Exception as rollback_error:
+        except Exception:
             logger.exception("Failed to rollback database session")
 
         logger.exception("Cleaned dataset when dataset deleted failed")

@@ -1,9 +1,9 @@
 import logging
 import time
-from typing import Optional
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from core.tools.utils.web_reader_tool import get_image_upload_file_ids
@@ -16,7 +16,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: Optional[str]):
+def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_id: str | None):
     """
     Clean document when document deleted.
     :param document_id: document id
@@ -35,7 +35,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
         if not dataset:
             raise Exception("Document has no dataset")
 
-        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
         # check segment is exist
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
@@ -34,7 +35,9 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
             document = db.session.query(Document).where(Document.id == document_id).first()
             db.session.delete(document)
 
-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
             index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)

@@ -1,6 +1,5 @@
 import logging
 import time
-from typing import Optional
 
 import click
 from celery import shared_task
@@ -16,7 +15,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def create_segment_to_index_task(segment_id: str, keywords: Optional[list[str]] = None):
+def create_segment_to_index_task(segment_id: str, keywords: list[str] | None = None):
     """
     Async create segment to index
     :param segment_id:

api/tasks/deal_dataset_index_update_task.py (new file, 171 lines)
@@ -0,0 +1,171 @@
import logging
import time

import click
from celery import shared_task  # type: ignore

from core.rag.index_processor.constant.index_type import IndexType
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from core.rag.models.document import ChildDocument, Document
from extensions.ext_database import db
from models.dataset import Dataset, DocumentSegment
from models.dataset import Document as DatasetDocument


@shared_task(queue="dataset")
def deal_dataset_index_update_task(dataset_id: str, action: str):
    """
    Async deal dataset from index
    :param dataset_id: dataset_id
    :param action: action
    Usage: deal_dataset_index_update_task.delay(dataset_id, action)
    """
    logging.info(click.style("Start deal dataset index update: {}".format(dataset_id), fg="green"))
    start_at = time.perf_counter()

    try:
        dataset = db.session.query(Dataset).filter_by(id=dataset_id).first()

        if not dataset:
            raise Exception("Dataset not found")
        index_type = dataset.doc_form or IndexType.PARAGRAPH_INDEX
        index_processor = IndexProcessorFactory(index_type).init_index_processor()
        if action == "upgrade":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .where(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )

            if dataset_documents:
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                for dataset_document in dataset_documents:
                    try:
                        # add from vector index
                        segments = (
                            db.session.query(DocumentSegment)
                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )

                                documents.append(document)
                            # save vector index
                            # clean keywords
                            index_processor.clean(dataset, None, with_keywords=True, delete_child_chunks=False)
                            index_processor.load(dataset, documents, with_keywords=False)
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        elif action == "update":
            dataset_documents = (
                db.session.query(DatasetDocument)
                .where(
                    DatasetDocument.dataset_id == dataset_id,
                    DatasetDocument.indexing_status == "completed",
                    DatasetDocument.enabled == True,
                    DatasetDocument.archived == False,
                )
                .all()
            )
            # add new index
            if dataset_documents:
                # update document status
                dataset_documents_ids = [doc.id for doc in dataset_documents]
                db.session.query(DatasetDocument).where(DatasetDocument.id.in_(dataset_documents_ids)).update(
                    {"indexing_status": "indexing"}, synchronize_session=False
                )
                db.session.commit()

                # clean index
                index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

                for dataset_document in dataset_documents:
                    # update from vector index
                    try:
                        segments = (
                            db.session.query(DocumentSegment)
                            .where(DocumentSegment.document_id == dataset_document.id, DocumentSegment.enabled == True)
                            .order_by(DocumentSegment.position.asc())
                            .all()
                        )
                        if segments:
                            documents = []
                            for segment in segments:
                                document = Document(
                                    page_content=segment.content,
                                    metadata={
                                        "doc_id": segment.index_node_id,
                                        "doc_hash": segment.index_node_hash,
                                        "document_id": segment.document_id,
                                        "dataset_id": segment.dataset_id,
                                    },
                                )
                                if dataset_document.doc_form == IndexType.PARENT_CHILD_INDEX:
                                    child_chunks = segment.get_child_chunks()
                                    if child_chunks:
                                        child_documents = []
                                        for child_chunk in child_chunks:
                                            child_document = ChildDocument(
                                                page_content=child_chunk.content,
                                                metadata={
                                                    "doc_id": child_chunk.index_node_id,
                                                    "doc_hash": child_chunk.index_node_hash,
                                                    "document_id": segment.document_id,
                                                    "dataset_id": segment.dataset_id,
                                                },
                                            )
                                            child_documents.append(child_document)
                                        document.children = child_documents
                                documents.append(document)
                            # save vector index
                            index_processor.load(dataset, documents, with_keywords=False)
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "completed"}, synchronize_session=False
                        )
                        db.session.commit()
                    except Exception as e:
                        db.session.query(DatasetDocument).where(DatasetDocument.id == dataset_document.id).update(
                            {"indexing_status": "error", "error": str(e)}, synchronize_session=False
                        )
                        db.session.commit()
        else:
            # clean collection
            index_processor.clean(dataset, None, with_keywords=False, delete_child_chunks=False)

        end_at = time.perf_counter()
        logging.info(
            click.style("Deal dataset vector index: {} latency: {}".format(dataset_id, end_at - start_at), fg="green")
        )
    except Exception:
        logging.exception("Deal dataset vector index failed")
    finally:
        db.session.close()
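Note: a sketch of how the new task would presumably be dispatched; the queue name and accepted actions come from the file above, while this call site is illustrative and not part of the diff:

# "upgrade" rebuilds keyword data and re-loads the vector index,
# "update" refreshes the vector index (including parent/child chunks),
# and any other action just cleans the collection.
deal_dataset_index_update_task.delay("dataset-uuid", "upgrade")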
@@ -4,6 +4,7 @@ from typing import Literal
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -36,16 +37,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
         if action == "remove":
             index_processor.clean(dataset, None, with_keywords=False)
         elif action == "add":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
+            dataset_documents = db.session.scalars(
+                select(DatasetDocument).where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                .all()
-            )
+            ).all()
 
             if dataset_documents:
                 dataset_documents_ids = [doc.id for doc in dataset_documents]
@@ -89,16 +88,14 @@ def deal_dataset_vector_index_task(dataset_id: str, action: Literal["remove", "a
                 )
                 db.session.commit()
         elif action == "update":
-            dataset_documents = (
-                db.session.query(DatasetDocument)
-                .where(
+            dataset_documents = db.session.scalars(
+                select(DatasetDocument).where(
                     DatasetDocument.dataset_id == dataset_id,
                     DatasetDocument.indexing_status == "completed",
                     DatasetDocument.enabled == True,
                     DatasetDocument.archived == False,
                 )
-                .all()
-            )
+            ).all()
             # add new index
             if dataset_documents:
                 # update document status

@@ -15,7 +15,7 @@ def delete_account_task(account_id):
     account = db.session.query(Account).where(Account.id == account_id).first()
     try:
         BillingService.delete_account(account_id)
-    except Exception as e:
+    except Exception:
         logger.exception("Failed to delete account %s from billing service.", account_id)
         raise

@@ -14,7 +14,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="conversation")
-def delete_conversation_related_data(conversation_id: str) -> None:
+def delete_conversation_related_data(conversation_id: str):
     """
     Delete conversation-related data in the correct order from the database to respect foreign key constraints

@@ -12,7 +12,9 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="dataset")
-def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, document_id: str):
+def delete_segment_from_index_task(
+    index_node_ids: list, dataset_id: str, document_id: str, child_node_ids: list | None = None
+):
     """
     Async Remove segment from index
     :param index_node_ids:
@@ -26,6 +28,7 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
     try:
         dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
         if not dataset:
+            logging.warning("Dataset %s not found, skipping index cleanup", dataset_id)
             return
 
         dataset_document = db.session.query(Document).where(Document.id == document_id).first()
@@ -33,11 +36,19 @@ def delete_segment_from_index_task(index_node_ids: list, dataset_id: str, docume
             return
 
         if not dataset_document.enabled or dataset_document.archived or dataset_document.indexing_status != "completed":
+            logging.info("Document not in valid state for index operations, skipping")
             return
+        doc_form = dataset_document.doc_form
 
-        index_type = dataset_document.doc_form
-        index_processor = IndexProcessorFactory(index_type).init_index_processor()
-        index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+        # Proceed with index cleanup using the index_node_ids directly
+        index_processor = IndexProcessorFactory(doc_form).init_index_processor()
+        index_processor.clean(
+            dataset,
+            index_node_ids,
+            with_keywords=True,
+            delete_child_chunks=True,
+            precomputed_child_node_ids=child_node_ids,
+        )
 
         end_at = time.perf_counter()
         logger.info(click.style(f"Segment deleted from index latency: {end_at - start_at}", fg="green"))

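Note: a hypothetical call site for the widened signature (IDs illustrative); passing child_node_ids lets the cleanup use precomputed child chunk node IDs instead of re-deriving them from the database:

# Child node IDs gathered before the segment rows were deleted.
delete_segment_from_index_task.delay(
    index_node_ids=["node-1", "node-2"],
    dataset_id="dataset-uuid",
    document_id="document-uuid",
    child_node_ids=["child-1", "child-2"],  # optional; defaults to None
)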
@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
 from extensions.ext_database import db
@@ -44,15 +45,13 @@ def disable_segments_from_index_task(segment_ids: list, dataset_id: str, documen
     # sync index processor
     index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
 
-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset_id,
             DocumentSegment.document_id == document_id,
         )
-        .all()
-    )
+    ).all()
 
     if not segments:
         db.session.close()

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.extractor.notion_extractor import NotionExtractor
@@ -46,6 +47,7 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
     page_id = data_source_info["notion_page_id"]
     page_type = data_source_info["type"]
     page_edited_time = data_source_info["last_edited_time"]
+
     data_source_binding = (
         db.session.query(DataSourceOauthBinding)
         .where(
@@ -85,7 +87,9 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
             index_type = document.doc_form
             index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+            segments = db.session.scalars(
+                select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+            ).all()
             index_node_ids = [segment.index_node_id for segment in segments]
 
             # delete from vector index

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,7 +46,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
         index_type = document.doc_form
         index_processor = IndexProcessorFactory(index_type).init_index_processor()
 
-        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
+        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
         if segments:
             index_node_ids = [segment.index_node_id for segment in segments]

@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from configs import dify_config
 from core.indexing_runner import DocumentIsPausedError, IndexingRunner
@@ -27,73 +28,77 @@ def duplicate_document_indexing_task(dataset_id: str, document_ids: list):
     documents = []
     start_at = time.perf_counter()
 
-    dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
-    if dataset is None:
-        logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
-        db.session.close()
-        return
-
-    # check document limit
-    features = FeatureService.get_features(dataset.tenant_id)
-    try:
-        if features.billing.enabled:
-            vector_space = features.vector_space
-            count = len(document_ids)
-            if features.billing.subscription.plan == "sandbox" and count > 1:
-                raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
-            batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
-            if count > batch_upload_limit:
-                raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
-            if 0 < vector_space.limit <= vector_space.size:
-                raise ValueError(
-                    "Your total number of documents plus the number of uploads have over the limit of "
-                    "your subscription."
-                )
-    except Exception as e:
-        for document_id in document_ids:
-            document = (
-                db.session.query(Document)
-                .where(Document.id == document_id, Document.dataset_id == dataset_id)
-                .first()
-            )
-            if document:
-                document.indexing_status = "error"
-                document.error = str(e)
-                document.stopped_at = naive_utc_now()
-                db.session.add(document)
-        db.session.commit()
-        return
-
-    for document_id in document_ids:
-        logger.info(click.style(f"Start process document: {document_id}", fg="green"))
-
-        document = (
-            db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
-        )
-
-        if document:
-            # clean old data
-            index_type = document.doc_form
-            index_processor = IndexProcessorFactory(index_type).init_index_processor()
-
-            segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
-            if segments:
-                index_node_ids = [segment.index_node_id for segment in segments]
-
-                # delete from vector index
-                index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
-
-                for segment in segments:
-                    db.session.delete(segment)
-                db.session.commit()
-
-            document.indexing_status = "parsing"
-            document.processing_started_at = naive_utc_now()
-            documents.append(document)
-            db.session.add(document)
-    db.session.commit()
+    try:
+        dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
+        if dataset is None:
+            logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
+            db.session.close()
+            return
+
+        # check document limit
+        features = FeatureService.get_features(dataset.tenant_id)
+        try:
+            if features.billing.enabled:
+                vector_space = features.vector_space
+                count = len(document_ids)
+                if features.billing.subscription.plan == "sandbox" and count > 1:
+                    raise ValueError("Your current plan does not support batch upload, please upgrade your plan.")
+                batch_upload_limit = int(dify_config.BATCH_UPLOAD_LIMIT)
+                if count > batch_upload_limit:
+                    raise ValueError(f"You have reached the batch upload limit of {batch_upload_limit}.")
+                current = int(getattr(vector_space, "size", 0) or 0)
+                limit = int(getattr(vector_space, "limit", 0) or 0)
+                if limit > 0 and (current + count) > limit:
+                    raise ValueError(
+                        "Your total number of documents plus the number of uploads have exceeded the limit of "
+                        "your subscription."
+                    )
+        except Exception as e:
+            for document_id in document_ids:
+                document = (
+                    db.session.query(Document)
+                    .where(Document.id == document_id, Document.dataset_id == dataset_id)
+                    .first()
+                )
+                if document:
+                    document.indexing_status = "error"
+                    document.error = str(e)
+                    document.stopped_at = naive_utc_now()
+                    db.session.add(document)
+            db.session.commit()
+            return
+
+        for document_id in document_ids:
+            logger.info(click.style(f"Start process document: {document_id}", fg="green"))
+
+            document = (
+                db.session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
+            )
+
+            if document:
+                # clean old data
+                index_type = document.doc_form
+                index_processor = IndexProcessorFactory(index_type).init_index_processor()
+
+                segments = db.session.scalars(
+                    select(DocumentSegment).where(DocumentSegment.document_id == document_id)
+                ).all()
+                if segments:
+                    index_node_ids = [segment.index_node_id for segment in segments]
+
+                    # delete from vector index
+                    index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
+
+                    for segment in segments:
+                        db.session.delete(segment)
+                    db.session.commit()
+
+                document.indexing_status = "parsing"
+                document.processing_started_at = naive_utc_now()
+                documents.append(document)
+                db.session.add(document)
+        db.session.commit()
+    finally:
+        db.session.close()
 
     try:
         indexing_runner = IndexingRunner()
         indexing_runner.run(documents)
         end_at = time.perf_counter()

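Note: the subscription check above changes from "is the vector space already full" to "would this batch push usage past the limit". A standalone sketch of the difference (values illustrative):

def old_check(limit: int, size: int) -> bool:
    # Old behaviour: reject only when the space is already at or over its limit.
    return 0 < limit <= size


def new_check(limit: int, size: int, count: int) -> bool:
    # New behaviour: reject when this upload batch would exceed the limit.
    return limit > 0 and (size + count) > limit


# A batch of 5 into a space with only 3 slots left used to slip through:
assert not old_check(limit=10, size=7)          # old check admits the batch
assert new_check(limit=10, size=7, count=5)     # new check rejects it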
@@ -3,6 +3,7 @@ import time
 
 import click
 from celery import shared_task
+from sqlalchemy import select
 
 from core.rag.index_processor.constant.index_type import IndexType
 from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -45,15 +46,13 @@ def enable_segments_to_index_task(segment_ids: list, dataset_id: str, document_i
     # sync index processor
     index_processor = IndexProcessorFactory(dataset_document.doc_form).init_index_processor()
 
-    segments = (
-        db.session.query(DocumentSegment)
-        .where(
+    segments = db.session.scalars(
+        select(DocumentSegment).where(
             DocumentSegment.id.in_(segment_ids),
             DocumentSegment.dataset_id == dataset_id,
             DocumentSegment.document_id == document_id,
         )
-        .all()
-    )
+    ).all()
     if not segments:
         logger.info(click.style(f"Segments not found: {segment_ids}", fg="cyan"))
         db.session.close()

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_deletion_success_task(to: str, language: str = "en-US") -> None:
+def send_deletion_success_task(to: str, language: str = "en-US"):
     """
     Send account deletion success email with internationalization support.

@@ -46,7 +46,7 @@ def send_deletion_success_task(to: str, language: str = "en-US"):
 
 
 @shared_task(queue="mail")
-def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US") -> None:
+def send_account_deletion_verification_code(to: str, code: str, language: str = "en-US"):
     """
     Send account deletion verification code email with internationalization support.

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_change_mail_task(language: str, to: str, code: str, phase: str) -> None:
+def send_change_mail_task(language: str, to: str, code: str, phase: str):
     """
     Send change email notification with internationalization support.

@@ -43,7 +43,7 @@ def send_change_mail_task(language: str, to: str, code: str, phase: str)
 
 
 @shared_task(queue="mail")
-def send_change_mail_completed_notification_task(language: str, to: str) -> None:
+def send_change_mail_completed_notification_task(language: str, to: str):
     """
     Send change email completed notification with internationalization support.

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_email_code_login_mail_task(language: str, to: str, code: str) -> None:
+def send_email_code_login_mail_task(language: str, to: str, code: str):
     """
     Send email code login email with internationalization support.

@@ -1,17 +1,46 @@
 import logging
 import time
+from collections.abc import Mapping
+from typing import Any
 
 import click
 from celery import shared_task
 from flask import render_template_string
+from jinja2.runtime import Context
+from jinja2.sandbox import ImmutableSandboxedEnvironment
 
+from configs import dify_config
+from configs.feature import TemplateMode
 from extensions.ext_mail import mail
 from libs.email_i18n import get_email_i18n_service
 
 logger = logging.getLogger(__name__)
 
 
+class SandboxedEnvironment(ImmutableSandboxedEnvironment):
+    def __init__(self, timeout: int, *args: Any, **kwargs: Any):
+        self._timeout_time = time.time() + timeout
+        super().__init__(*args, **kwargs)
+
+    def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
+        if time.time() > self._timeout_time:
+            raise TimeoutError("Template rendering timeout")
+        return super().call(context, obj, *args, **kwargs)
+
+
+def _render_template_with_strategy(body: str, substitutions: Mapping[str, str]) -> str:
+    mode = dify_config.MAIL_TEMPLATING_MODE
+    timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
+    if mode == TemplateMode.UNSAFE:
+        return render_template_string(body, **substitutions)
+    if mode == TemplateMode.SANDBOX:
+        tmpl = SandboxedEnvironment(timeout=timeout).from_string(body)
+        return tmpl.render(substitutions)
+    if mode == TemplateMode.DISABLED:
+        return body
+    raise ValueError(f"Unsupported mail templating mode: {mode}")
+
+
 @shared_task(queue="mail")
 def send_inner_email_task(to: list[str], subject: str, body: str, substitutions: Mapping[str, str]):
     if not mail.is_inited():
@@ -21,7 +50,7 @@ def send_inner_email_task(to: list[str], subject: str, body: str, substitutions:
     start_at = time.perf_counter()
 
     try:
-        html_content = render_template_string(body, **substitutions)
+        html_content = _render_template_with_strategy(body, substitutions)
 
         email_service = get_email_i18n_service()
         email_service.send_raw_email(to=to, subject=subject, html_content=html_content)

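Note: the point of the SANDBOX mode above is that template payloads cannot reach Python internals. A small standalone illustration with plain Jinja2, independent of Dify's config plumbing:

from jinja2.exceptions import SecurityError
from jinja2.sandbox import ImmutableSandboxedEnvironment

env = ImmutableSandboxedEnvironment()

# Ordinary substitutions still render fine in the sandbox.
print(env.from_string("Hello {{ name }}!").render({"name": "Ada"}))

# A classic SSTI probe is rejected instead of exposing object internals.
try:
    env.from_string("{{ ''.__class__.__mro__ }}").render({})
except SecurityError as exc:
    print(f"blocked: {exc}")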
@@ -12,7 +12,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str) -> None:
+def send_invite_member_mail_task(language: str, to: str, token: str, inviter_name: str, workspace_name: str):
     """
     Send invite member email with internationalization support.

@@ -11,7 +11,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str) -> None:
+def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspace: str):
     """
     Send owner transfer confirmation email with internationalization support.

@@ -52,7 +52,7 @@ def send_owner_transfer_confirm_task(language: str, to: str, code: str, workspac
 
 
 @shared_task(queue="mail")
-def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str) -> None:
+def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace: str, new_owner_email: str):
     """
     Send old owner transfer notification email with internationalization support.

@@ -93,7 +93,7 @@ def send_old_owner_transfer_notify_email_task(language: str, to: str, workspace:
 
 
 @shared_task(queue="mail")
-def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str) -> None:
+def send_new_owner_transfer_notify_email_task(language: str, to: str, workspace: str):
     """
     Send new owner transfer notification email with internationalization support.

api/tasks/mail_register_task.py (new file, 87 lines)
@@ -0,0 +1,87 @@
import logging
import time

import click
from celery import shared_task

from configs import dify_config
from extensions.ext_mail import mail
from libs.email_i18n import EmailType, get_email_i18n_service

logger = logging.getLogger(__name__)


@shared_task(queue="mail")
def send_email_register_mail_task(language: str, to: str, code: str) -> None:
    """
    Send email register email with internationalization support.

    Args:
        language: Language code for email localization
        to: Recipient email address
        code: Email register code
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.EMAIL_REGISTER,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "code": code,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send email register mail to %s failed", to)


@shared_task(queue="mail")
def send_email_register_mail_task_when_account_exist(language: str, to: str, account_name: str) -> None:
    """
    Send email register email with internationalization support when account exist.

    Args:
        language: Language code for email localization
        to: Recipient email address
    """
    if not mail.is_inited():
        return

    logger.info(click.style(f"Start email register mail to {to}", fg="green"))
    start_at = time.perf_counter()

    try:
        login_url = f"{dify_config.CONSOLE_WEB_URL}/signin"
        reset_password_url = f"{dify_config.CONSOLE_WEB_URL}/reset-password"

        email_service = get_email_i18n_service()
        email_service.send_email(
            email_type=EmailType.EMAIL_REGISTER_WHEN_ACCOUNT_EXIST,
            language_code=language,
            to=to,
            template_context={
                "to": to,
                "login_url": login_url,
                "reset_password_url": reset_password_url,
                "account_name": account_name,
            },
        )

        end_at = time.perf_counter()
        logger.info(
            click.style(f"Send email register mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
        )
    except Exception:
        logger.exception("Send email register mail to %s failed", to)
@@ -4,6 +4,7 @@ import time
 import click
 from celery import shared_task
 
+from configs import dify_config
 from extensions.ext_mail import mail
 from libs.email_i18n import EmailType, get_email_i18n_service

@@ -11,7 +12,7 @@ logger = logging.getLogger(__name__)
 
 
 @shared_task(queue="mail")
-def send_reset_password_mail_task(language: str, to: str, code: str) -> None:
+def send_reset_password_mail_task(language: str, to: str, code: str):
     """
     Send reset password email with internationalization support.

@@ -44,3 +45,47 @@ def send_reset_password_mail_task(language: str, to: str, code: str):
         )
     except Exception:
         logger.exception("Send password reset mail to %s failed", to)
+
+
+@shared_task(queue="mail")
+def send_reset_password_mail_task_when_account_not_exist(language: str, to: str, is_allow_register: bool) -> None:
+    """
+    Send reset password email with internationalization support when account not exist.
+
+    Args:
+        language: Language code for email localization
+        to: Recipient email address
+    """
+    if not mail.is_inited():
+        return
+
+    logger.info(click.style(f"Start password reset mail to {to}", fg="green"))
+    start_at = time.perf_counter()
+
+    try:
+        if is_allow_register:
+            sign_up_url = f"{dify_config.CONSOLE_WEB_URL}/signup"
+            email_service = get_email_i18n_service()
+            email_service.send_email(
+                email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST,
+                language_code=language,
+                to=to,
+                template_context={
+                    "to": to,
+                    "sign_up_url": sign_up_url,
+                },
+            )
+        else:
+            email_service = get_email_i18n_service()
+            email_service.send_email(
+                email_type=EmailType.RESET_PASSWORD_WHEN_ACCOUNT_NOT_EXIST_NO_REGISTER,
+                language_code=language,
+                to=to,
+            )
+
+        end_at = time.perf_counter()
+        logger.info(
+            click.style(f"Send password reset mail to {to} succeeded: latency: {end_at - start_at}", fg="green")
+        )
+    except Exception:
+        logger.exception("Send password reset mail to %s failed", to)

@@ -1,3 +1,4 @@
+import operator
 import traceback
 import typing

@@ -118,7 +119,7 @@ def process_tenant_plugin_autoupgrade_check_task(
                 current_version = version
                 latest_version = manifest.latest_version
 
-                def fix_only_checker(latest_version, current_version):
+                def fix_only_checker(latest_version: str, current_version: str):
                     latest_version_tuple = tuple(int(val) for val in latest_version.split("."))
                     current_version_tuple = tuple(int(val) for val in current_version.split("."))

@@ -130,8 +131,7 @@ def process_tenant_plugin_autoupgrade_check_task(
                         return False
 
                 version_checker = {
-                    TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: lambda latest_version,
-                    current_version: latest_version != current_version,
+                    TenantPluginAutoUpgradeStrategy.StrategySetting.LATEST: operator.ne,
                     TenantPluginAutoUpgradeStrategy.StrategySetting.FIX_ONLY: fix_only_checker,
                 }

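Note: swapping the two-argument lambda for operator.ne is behaviour-preserving, since both return True exactly when the two version strings differ. A quick standalone check:

import operator

old_checker = lambda latest_version, current_version: latest_version != current_version

# The old lambda and its stdlib replacement agree on every input pair.
for pair in [("1.2.3", "1.2.3"), ("1.2.4", "1.2.3")]:
    assert old_checker(*pair) == operator.ne(*pair)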
@@ -146,7 +146,7 @@ def process_tenant_plugin_autoupgrade_check_task(
                             fg="green",
                         )
                     )
-                    task_start_resp = manager.upgrade_plugin(
+                    _ = manager.upgrade_plugin(
                         tenant_id,
                         original_unique_identifier,
                         new_unique_identifier,

api/tasks/rag_pipeline/priority_rag_pipeline_run_task.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import contextvars
import json
import logging
import time
import uuid
from collections.abc import Mapping
from concurrent.futures import ThreadPoolExecutor
from typing import Any

import click
from celery import shared_task  # type: ignore
from flask import current_app, g
from sqlalchemy.orm import Session, sessionmaker

from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
from core.repositories.factory import DifyCoreRepositoryFactory
from extensions.ext_database import db
from models.account import Account, Tenant
from models.dataset import Pipeline
from models.enums import WorkflowRunTriggeredFrom
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.file_service import FileService


@shared_task(queue="priority_pipeline")
def priority_rag_pipeline_run_task(
    rag_pipeline_invoke_entities_file_id: str,
    tenant_id: str,
):
    """
    Async Run rag pipeline
    :param rag_pipeline_invoke_entities: Rag pipeline invoke entities
    rag_pipeline_invoke_entities include:
    :param pipeline_id: Pipeline ID
    :param user_id: User ID
    :param tenant_id: Tenant ID
    :param workflow_id: Workflow ID
    :param invoke_from: Invoke source (debugger, published, etc.)
    :param streaming: Whether to stream results
    :param datasource_type: Type of datasource
    :param datasource_info: Datasource information dict
    :param batch: Batch identifier
    :param document_id: Document ID (optional)
    :param start_node_id: Starting node ID
    :param inputs: Input parameters dict
    :param workflow_execution_id: Workflow execution ID
    :param workflow_thread_pool_id: Thread pool ID for workflow execution
    """
    # run with threading, thread pool size is 10

    try:
        start_at = time.perf_counter()
        rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
            rag_pipeline_invoke_entities_file_id
        )
        rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)

        # Get Flask app object for thread context
        flask_app = current_app._get_current_object()  # type: ignore

        with ThreadPoolExecutor(max_workers=10) as executor:
            futures = []
            for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
                # Submit task to thread pool with Flask app
                future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
                futures.append(future)

            # Wait for all tasks to complete
            for future in futures:
                try:
                    future.result()  # This will raise any exceptions that occurred in the thread
                except Exception:
                    logging.exception("Error in pipeline task")
        end_at = time.perf_counter()
        logging.info(
            click.style(
                f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
            )
        )
    except Exception:
        logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
        raise
    finally:
        file_service = FileService(db.engine)
        file_service.delete_file(rag_pipeline_invoke_entities_file_id)
        db.session.close()


def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine, expire_on_commit=False) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in priority pipeline task")
            raise
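Note: both pipeline tasks fan work out to a thread pool while keeping Flask context usable per thread; the task unwraps the current_app proxy with _get_current_object() because the proxy itself is not safe to hand to other threads. The core pattern, reduced to a standalone sketch (the app and job function are illustrative):

from concurrent.futures import ThreadPoolExecutor

from flask import Flask, current_app

app = Flask("pipeline-demo")


def job(flask_app: Flask, n: int) -> str:
    # Each worker thread pushes its own app context; current_app and g
    # would otherwise be unbound inside the pool's threads.
    with flask_app.app_context():
        return f"{current_app.name}:{n}"


with ThreadPoolExecutor(max_workers=4) as executor:
    futures = [executor.submit(job, app, i) for i in range(3)]
    print([f.result() for f in futures])  # ['pipeline-demo:0', ...]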
api/tasks/rag_pipeline/rag_pipeline_run_task.py (new file, 196 lines)
@@ -0,0 +1,196 @@
|
||||
import contextvars
|
||||
import json
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Any
|
||||
|
||||
import click
|
||||
from celery import shared_task # type: ignore
|
||||
from flask import current_app, g
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, RagPipelineGenerateEntity
|
||||
from core.app.entities.rag_pipeline_invoke_entities import RagPipelineInvokeEntity
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant
|
||||
from models.dataset import Pipeline
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@shared_task(queue="pipeline")
|
||||
def rag_pipeline_run_task(
|
||||
rag_pipeline_invoke_entities_file_id: str,
|
||||
tenant_id: str,
|
||||
):
|
||||
"""
|
||||
Async Run rag pipeline
|
||||
:param rag_pipeline_invoke_entities: Rag pipeline invoke entities
|
||||
rag_pipeline_invoke_entities include:
|
||||
:param pipeline_id: Pipeline ID
|
||||
:param user_id: User ID
|
||||
:param tenant_id: Tenant ID
|
||||
:param workflow_id: Workflow ID
|
||||
:param invoke_from: Invoke source (debugger, published, etc.)
|
||||
:param streaming: Whether to stream results
|
||||
:param datasource_type: Type of datasource
|
||||
:param datasource_info: Datasource information dict
|
||||
:param batch: Batch identifier
|
||||
:param document_id: Document ID (optional)
|
||||
:param start_node_id: Starting node ID
|
||||
:param inputs: Input parameters dict
|
||||
:param workflow_execution_id: Workflow execution ID
|
||||
:param workflow_thread_pool_id: Thread pool ID for workflow execution
|
||||
"""
|
||||
# run with threading, thread pool size is 10
|
||||
|
||||
try:
|
||||
start_at = time.perf_counter()
|
||||
rag_pipeline_invoke_entities_content = FileService(db.engine).get_file_content(
|
||||
rag_pipeline_invoke_entities_file_id
|
||||
)
|
||||
rag_pipeline_invoke_entities = json.loads(rag_pipeline_invoke_entities_content)
|
||||
|
||||
# Get Flask app object for thread context
|
||||
flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
with ThreadPoolExecutor(max_workers=10) as executor:
|
||||
futures = []
|
||||
for rag_pipeline_invoke_entity in rag_pipeline_invoke_entities:
|
||||
# Submit task to thread pool with Flask app
|
||||
future = executor.submit(run_single_rag_pipeline_task, rag_pipeline_invoke_entity, flask_app)
|
||||
futures.append(future)
|
||||
|
||||
# Wait for all tasks to complete
|
||||
for future in futures:
|
||||
try:
|
||||
future.result() # This will raise any exceptions that occurred in the thread
|
||||
except Exception:
|
||||
logging.exception("Error in pipeline task")
|
||||
end_at = time.perf_counter()
|
||||
logging.info(
|
||||
click.style(
|
||||
f"tenant_id: {tenant_id} , Rag pipeline run completed. Latency: {end_at - start_at}s", fg="green"
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logging.exception(click.style(f"Error running rag pipeline, tenant_id: {tenant_id}", fg="red"))
|
||||
raise
|
||||
finally:
|
||||
tenant_self_pipeline_task_queue = f"tenant_self_pipeline_task_queue:{tenant_id}"
|
||||
tenant_pipeline_task_key = f"tenant_pipeline_task:{tenant_id}"
|
||||
|
||||
# Check if there are waiting tasks in the queue
|
||||
# Use rpop to get the next task from the queue (FIFO order)
|
||||
next_file_id = redis_client.rpop(tenant_self_pipeline_task_queue)
|
||||
|
||||
if next_file_id:
|
||||
# Process the next waiting task
|
||||
# Keep the flag set to indicate a task is running
|
||||
redis_client.setex(tenant_pipeline_task_key, 60 * 60, 1)
|
||||
rag_pipeline_run_task.delay( # type: ignore
|
||||
rag_pipeline_invoke_entities_file_id=next_file_id.decode("utf-8")
|
||||
if isinstance(next_file_id, bytes)
|
||||
else next_file_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
# No more waiting tasks, clear the flag
|
||||
redis_client.delete(tenant_pipeline_task_key)
|
||||
file_service = FileService(db.engine)
|
||||
file_service.delete_file(rag_pipeline_invoke_entities_file_id)
|
||||
db.session.close()
|
||||
|
||||
|
||||
def run_single_rag_pipeline_task(rag_pipeline_invoke_entity: Mapping[str, Any], flask_app):
    """Run a single RAG pipeline task within Flask app context."""
    # Create Flask application context for this thread
    with flask_app.app_context():
        try:
            rag_pipeline_invoke_entity_model = RagPipelineInvokeEntity(**rag_pipeline_invoke_entity)
            user_id = rag_pipeline_invoke_entity_model.user_id
            tenant_id = rag_pipeline_invoke_entity_model.tenant_id
            pipeline_id = rag_pipeline_invoke_entity_model.pipeline_id
            workflow_id = rag_pipeline_invoke_entity_model.workflow_id
            streaming = rag_pipeline_invoke_entity_model.streaming
            workflow_execution_id = rag_pipeline_invoke_entity_model.workflow_execution_id
            workflow_thread_pool_id = rag_pipeline_invoke_entity_model.workflow_thread_pool_id
            application_generate_entity = rag_pipeline_invoke_entity_model.application_generate_entity

            with Session(db.engine) as session:
                # Load required entities
                account = session.query(Account).where(Account.id == user_id).first()
                if not account:
                    raise ValueError(f"Account {user_id} not found")

                tenant = session.query(Tenant).where(Tenant.id == tenant_id).first()
                if not tenant:
                    raise ValueError(f"Tenant {tenant_id} not found")
                account.current_tenant = tenant

                pipeline = session.query(Pipeline).where(Pipeline.id == pipeline_id).first()
                if not pipeline:
                    raise ValueError(f"Pipeline {pipeline_id} not found")

                workflow = session.query(Workflow).where(Workflow.id == pipeline.workflow_id).first()
                if not workflow:
                    raise ValueError(f"Workflow {pipeline.workflow_id} not found")

                if workflow_execution_id is None:
                    workflow_execution_id = str(uuid.uuid4())

                # Create application generate entity from dict
                entity = RagPipelineGenerateEntity(**application_generate_entity)

                # Create workflow repositories
                session_factory = sessionmaker(bind=db.engine, expire_on_commit=False)
                workflow_execution_repository = DifyCoreRepositoryFactory.create_workflow_execution_repository(
                    session_factory=session_factory,
                    user=account,
                    app_id=entity.app_config.app_id,
                    triggered_from=WorkflowRunTriggeredFrom.RAG_PIPELINE_RUN,
                )

                workflow_node_execution_repository = (
                    DifyCoreRepositoryFactory.create_workflow_node_execution_repository(
                        session_factory=session_factory,
                        user=account,
                        app_id=entity.app_config.app_id,
                        triggered_from=WorkflowNodeExecutionTriggeredFrom.RAG_PIPELINE_RUN,
                    )
                )

                # Set the user directly in g for preserve_flask_contexts
                g._login_user = account

                # Copy context for passing to pipeline generator
                context = contextvars.copy_context()

                # Direct execution without creating another thread
                # Since we're already in a thread pool, no need for nested threading
                from core.app.apps.pipeline.pipeline_generator import PipelineGenerator

                pipeline_generator = PipelineGenerator()
                # Using protected method intentionally for async execution
                pipeline_generator._generate(  # type: ignore[attr-defined]
                    flask_app=flask_app,
                    context=context,
                    pipeline=pipeline,
                    workflow_id=workflow_id,
                    user=account,
                    application_generate_entity=entity,
                    invoke_from=InvokeFrom.PUBLISHED,
                    workflow_execution_repository=workflow_execution_repository,
                    workflow_node_execution_repository=workflow_node_execution_repository,
                    streaming=streaming,
                    workflow_thread_pool_id=workflow_thread_pool_id,
                )
        except Exception:
            logging.exception("Error in pipeline task")
            raise
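A sketch of how a batch of invoke entities could be fanned out to this function on a thread pool (an assumption about the caller; the real dispatch may differ). current_app._get_current_object() unwraps the Flask proxy so every worker thread receives the real app object:

from concurrent.futures import ThreadPoolExecutor

from flask import current_app


def run_pipeline_batch(invoke_entities: list[dict]) -> None:
    # Hypothetical driver (assumption): each entity runs inside its own app context.
    flask_app = current_app._get_current_object()
    with ThreadPoolExecutor(max_workers=4) as pool:
        for entity in invoke_entities:
            pool.submit(run_single_rag_pipeline_task, entity, flask_app)
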
@@ -355,6 +355,11 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
    """
    Delete draft variables for an app in batches.

    This function now handles cleanup of associated Offload data including:
    - WorkflowDraftVariableFile records
    - UploadFile records
    - Object storage files

    Args:
        app_id: The ID of the app whose draft variables should be deleted
        batch_size: Number of records to delete per batch
@@ -366,22 +371,31 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
        raise ValueError("batch_size must be positive")

    total_deleted = 0
    total_files_deleted = 0

    while True:
        with db.engine.begin() as conn:
            # Get a batch of draft variable IDs
            # Get a batch of draft variable IDs along with their file_ids
            query_sql = """
                SELECT id FROM workflow_draft_variables
                SELECT id, file_id FROM workflow_draft_variables
                WHERE app_id = :app_id
                LIMIT :batch_size
            """
            result = conn.execute(sa.text(query_sql), {"app_id": app_id, "batch_size": batch_size})

            draft_var_ids = [row[0] for row in result]
            if not draft_var_ids:
            rows = list(result)
            if not rows:
                break

            # Delete the batch
            draft_var_ids = [row[0] for row in rows]
            file_ids = [row[1] for row in rows if row[1] is not None]

            # Clean up associated Offload data first
            if file_ids:
                files_deleted = _delete_draft_variable_offload_data(conn, file_ids)
                total_files_deleted += files_deleted

            # Delete the draft variables
            delete_sql = """
                DELETE FROM workflow_draft_variables
                WHERE id IN :ids
@@ -392,10 +406,85 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:

        logger.info(click.style(f"Deleted {batch_deleted} draft variables (batch) for app {app_id}", fg="green"))

    logger.info(click.style(f"Deleted {total_deleted} total draft variables for app {app_id}", fg="green"))
    logger.info(
        click.style(
            f"Deleted {total_deleted} total draft variables for app {app_id}. "
            f"Cleaned up {total_files_deleted} total associated files.",
            fg="green",
        )
    )
    return total_deleted

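Usage is a single call; the loop above terminates when a SELECT returns no rows, so clearing N variables costs roughly ceil(N / batch_size) + 1 short transactions. A sketch (the call site is an assumption):

# Hypothetical caller (assumption): smaller batches keep each transaction short.
deleted = delete_draft_variables_batch(app_id, batch_size=500)
logger.info("removed %d draft variables for app %s", deleted, app_id)
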
def _delete_draft_variable_offload_data(conn, file_ids: list[str]) -> int:
    """
    Delete Offload data associated with WorkflowDraftVariable file_ids.

    This function:
    1. Finds WorkflowDraftVariableFile records by file_ids
    2. Deletes associated files from object storage
    3. Deletes UploadFile records
    4. Deletes WorkflowDraftVariableFile records

    Args:
        conn: Database connection
        file_ids: List of WorkflowDraftVariableFile IDs

    Returns:
        Number of files cleaned up
    """
    from extensions.ext_storage import storage

    if not file_ids:
        return 0

    files_deleted = 0

    try:
        # Get WorkflowDraftVariableFile records and their associated UploadFile keys
        query_sql = """
            SELECT wdvf.id, uf.key, uf.id as upload_file_id
            FROM workflow_draft_variable_files wdvf
            JOIN upload_files uf ON wdvf.upload_file_id = uf.id
            WHERE wdvf.id IN :file_ids
        """
        result = conn.execute(sa.text(query_sql), {"file_ids": tuple(file_ids)})
        file_records = list(result)

        # Delete from object storage and collect upload file IDs
        upload_file_ids = []
        for _, storage_key, upload_file_id in file_records:
            try:
                storage.delete(storage_key)
                upload_file_ids.append(upload_file_id)
                files_deleted += 1
            except Exception:
                logging.exception("Failed to delete storage object %s", storage_key)
                # Continue with database cleanup even if storage deletion fails
                upload_file_ids.append(upload_file_id)

        # Delete UploadFile records
        if upload_file_ids:
            delete_upload_files_sql = """
                DELETE FROM upload_files
                WHERE id IN :upload_file_ids
            """
            conn.execute(sa.text(delete_upload_files_sql), {"upload_file_ids": tuple(upload_file_ids)})

        # Delete WorkflowDraftVariableFile records
        delete_variable_files_sql = """
            DELETE FROM workflow_draft_variable_files
            WHERE id IN :file_ids
        """
        conn.execute(sa.text(delete_variable_files_sql), {"file_ids": tuple(file_ids)})

    except Exception:
        logging.exception("Error deleting draft variable offload data:")
        # Don't raise, as we want to continue with the main deletion process

    return files_deleted

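The raw `IN :ids`-style parameters above rely on the DBAPI driver adapting a Python tuple into a parenthesized value list. A driver-portable alternative, if one were wanted, is SQLAlchemy's expanding bindparam (a sketch of the technique, not what the code above does):

import sqlalchemy as sa

stmt = sa.text("DELETE FROM upload_files WHERE id IN :ids").bindparams(
    sa.bindparam("ids", expanding=True)  # rendered as IN (:ids_1, :ids_2, ...) at execute time
)
conn.execute(stmt, {"ids": list(upload_file_ids)})
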
def _delete_records(query_sql: str, params: dict, delete_func: Callable, name: str) -> None:
    while True:
        with db.engine.begin() as conn:

@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
@@ -45,7 +46,7 @@ def remove_document_from_index_task(document_id: str):

    index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

    segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document.id).all()
    segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document.id)).all()
    index_node_ids = [segment.index_node_id for segment in segments]
    if index_node_ids:
        try:

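The recurring change across these hunks is the same SQLAlchemy 2.0-style migration: build an explicit select() statement and let Session.scalars() unwrap single-entity rows, rather than using the legacy Query API. Schematically (doc_id stands in for whichever filter each task applies):

from sqlalchemy import select

# Legacy 1.x Query API
segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == doc_id).all()

# 2.0 style: explicit select() executed via scalars()
segments = db.session.scalars(
    select(DocumentSegment).where(DocumentSegment.document_id == doc_id)
).all()
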
@@ -3,26 +3,30 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from libs.datetime_utils import naive_utc_now
from models.account import Account, Tenant
from models.dataset import Dataset, Document, DocumentSegment
from services.feature_service import FeatureService
from services.rag_pipeline.rag_pipeline import RagPipelineService

logger = logging.getLogger(__name__)


@shared_task(queue="dataset")
def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
def retry_document_indexing_task(dataset_id: str, document_ids: list[str], user_id: str):
    """
    Async process document
    :param dataset_id:
    :param document_ids:
    :param user_id:

    Usage: retry_document_indexing_task.delay(dataset_id, document_ids)
    Usage: retry_document_indexing_task.delay(dataset_id, document_ids, user_id)
    """
    start_at = time.perf_counter()
    try:
@@ -30,11 +34,19 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
        if not dataset:
            logger.info(click.style(f"Dataset not found: {dataset_id}", fg="red"))
            return
        tenant_id = dataset.tenant_id
        user = db.session.query(Account).where(Account.id == user_id).first()
        if not user:
            logger.info(click.style(f"User not found: {user_id}", fg="red"))
            return
        tenant = db.session.query(Tenant).where(Tenant.id == dataset.tenant_id).first()
        if not tenant:
            raise ValueError("Tenant not found")
        user.current_tenant = tenant

        for document_id in document_ids:
            retry_indexing_cache_key = f"document_{document_id}_is_retried"
            # check document limit
            features = FeatureService.get_features(tenant_id)
            features = FeatureService.get_features(tenant.id)
            try:
                if features.billing.enabled:
                    vector_space = features.vector_space
@@ -69,7 +81,9 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                # clean old data
                index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

                segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
                segments = db.session.scalars(
                    select(DocumentSegment).where(DocumentSegment.document_id == document_id)
                ).all()
                if segments:
                    index_node_ids = [segment.index_node_id for segment in segments]
                    # delete from vector index
@@ -84,8 +98,12 @@ def retry_document_indexing_task(dataset_id: str, document_ids: list[str]):
                db.session.add(document)
                db.session.commit()

                indexing_runner = IndexingRunner()
                indexing_runner.run([document])
                if dataset.runtime_mode == "rag_pipeline":
                    rag_pipeline_service = RagPipelineService()
                    rag_pipeline_service.retry_error_document(dataset, document, user)
                else:
                    indexing_runner = IndexingRunner()
                    indexing_runner.run([document])
                redis_client.delete(retry_indexing_cache_key)
            except Exception as ex:
                document.indexing_status = "error"

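Because the task signature gained a required user_id, existing call sites must now pass the acting account. A sketch (where the caller obtains current_user is an assumption):

# Hypothetical call site (assumption): the third positional argument is now required
retry_document_indexing_task.delay(dataset_id, document_ids, current_user.id)
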
@@ -3,6 +3,7 @@ import time

import click
from celery import shared_task
from sqlalchemy import select

from core.indexing_runner import IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -63,7 +64,7 @@ def sync_website_document_indexing_task(dataset_id: str, document_id: str):
        # clean old data
        index_processor = IndexProcessorFactory(document.doc_form).init_index_processor()

        segments = db.session.query(DocumentSegment).where(DocumentSegment.document_id == document_id).all()
        segments = db.session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
        if segments:
            index_node_ids = [segment.index_node_id for segment in segments]
            # delete from vector index

27
api/tasks/workflow_draft_var_tasks.py
Normal file
@@ -0,0 +1,27 @@
"""
Celery tasks for asynchronous workflow execution storage operations.

These tasks provide asynchronous storage capabilities for workflow execution data,
improving performance by offloading storage operations to background workers.
"""

import logging

from celery import shared_task  # type: ignore[import-untyped]
from sqlalchemy.orm import Session

from extensions.ext_database import db

_logger = logging.getLogger(__name__)

from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService


@shared_task(queue="workflow_draft_var", bind=True, max_retries=3, default_retry_delay=60)
def save_workflow_execution_task(
    self,
    deletions: list[DraftVarFileDeletion],
):
    with Session(bind=db.engine) as session, session.begin():
        srv = WorkflowDraftVariableService(session=session)
        srv.delete_workflow_draft_variable_file(deletions=deletions)

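With bind=True the task receives itself as the first argument, so a caller passes only the deletion list; on failure Celery retries up to 3 times with a 60-second delay. A sketch (how the deletions list is built is left out here, since DraftVarFileDeletion's fields aren't shown in this diff):

# deletions: list[DraftVarFileDeletion] produced by the draft-variable service layer
save_workflow_execution_task.delay(deletions)
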
@@ -120,7 +120,7 @@ def _create_workflow_run_from_execution(
    return workflow_run


def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution) -> None:
def _update_workflow_run_from_execution(workflow_run: WorkflowRun, execution: WorkflowExecution):
    """
    Update a WorkflowRun database model from a WorkflowExecution domain entity.
    """

@@ -140,9 +140,7 @@ def _create_node_execution_from_domain(
    return node_execution


def _update_node_execution_from_domain(
    node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution
) -> None:
def _update_node_execution_from_domain(node_execution: WorkflowNodeExecutionModel, execution: WorkflowNodeExecution):
    """
    Update a WorkflowNodeExecutionModel database model from a WorkflowNodeExecution domain entity.
    """