feat: archive workflow run logs backend (#31310)

Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Author: 非法操作
Date: 2026-01-23 13:11:56 +08:00 (committed by GitHub)
Parent: 41428432cc
Commit: fa92548cf6
28 changed files with 2691 additions and 48 deletions


@@ -0,0 +1 @@
"""Workflow run retention services."""


@@ -0,0 +1,531 @@
"""
Archive Paid Plan Workflow Run Logs Service.
This service archives workflow run logs that are older than the configured
retention period (default: 90 days) for paid plan users to S3-compatible storage.
Archived tables:
- workflow_runs
- workflow_app_logs
- workflow_node_executions
- workflow_node_execution_offload
- workflow_pauses
- workflow_pause_reasons
- workflow_trigger_logs
"""
import datetime
import io
import json
import logging
import time
import zipfile
from collections.abc import Sequence
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass, field
from typing import Any
import click
from sqlalchemy import inspect
from sqlalchemy.orm import Session, sessionmaker
from configs import dify_config
from core.workflow.enums import WorkflowType
from enums.cloud_plan import CloudPlan
from extensions.ext_database import db
from libs.archive_storage import (
ArchiveStorage,
ArchiveStorageNotConfiguredError,
get_archive_storage,
)
from models.workflow import WorkflowAppLog, WorkflowRun
from repositories.api_workflow_node_execution_repository import DifyAPIWorkflowNodeExecutionRepository
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
from services.billing_service import BillingService
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME, ARCHIVE_SCHEMA_VERSION
logger = logging.getLogger(__name__)
@dataclass
class TableStats:
"""Statistics for a single archived table."""
table_name: str
row_count: int
checksum: str
size_bytes: int
@dataclass
class ArchiveResult:
"""Result of archiving a single workflow run."""
run_id: str
tenant_id: str
success: bool
tables: list[TableStats] = field(default_factory=list)
error: str | None = None
elapsed_time: float = 0.0
@dataclass
class ArchiveSummary:
"""Summary of the entire archive operation."""
total_runs_processed: int = 0
runs_archived: int = 0
runs_skipped: int = 0
runs_failed: int = 0
total_elapsed_time: float = 0.0
class WorkflowRunArchiver:
"""
Archive workflow run logs for paid plan users.
Storage Layout:
{tenant_id}/app_id={app_id}/year={YYYY}/month={MM}/workflow_run_id={run_id}/
└── archive.v1.0.zip
├── manifest.json
├── workflow_runs.jsonl
├── workflow_app_logs.jsonl
├── workflow_node_executions.jsonl
├── workflow_node_execution_offload.jsonl
├── workflow_pauses.jsonl
├── workflow_pause_reasons.jsonl
└── workflow_trigger_logs.jsonl
"""
ARCHIVED_TYPE = [
WorkflowType.WORKFLOW,
WorkflowType.RAG_PIPELINE,
]
ARCHIVED_TABLES = [
"workflow_runs",
"workflow_app_logs",
"workflow_node_executions",
"workflow_node_execution_offload",
"workflow_pauses",
"workflow_pause_reasons",
"workflow_trigger_logs",
]
start_from: datetime.datetime | None
end_before: datetime.datetime
def __init__(
self,
days: int = 90,
batch_size: int = 100,
start_from: datetime.datetime | None = None,
end_before: datetime.datetime | None = None,
workers: int = 1,
tenant_ids: Sequence[str] | None = None,
limit: int | None = None,
dry_run: bool = False,
delete_after_archive: bool = False,
workflow_run_repo: APIWorkflowRunRepository | None = None,
):
"""
Initialize the archiver.
Args:
days: Archive runs older than this many days
batch_size: Number of runs to process per batch
start_from: Optional start time (inclusive) for archiving
end_before: Optional end time (exclusive) for archiving
workers: Number of concurrent workflow runs to archive
tenant_ids: Optional tenant IDs for grayscale rollout
limit: Maximum number of runs to archive (None for unlimited)
dry_run: If True, only preview without making changes
            delete_after_archive: If True, delete runs and related data after archiving
            workflow_run_repo: Optional APIWorkflowRunRepository override (primarily useful in tests)
"""
self.days = days
self.batch_size = batch_size
if start_from or end_before:
if start_from is None or end_before is None:
raise ValueError("start_from and end_before must be provided together")
if start_from >= end_before:
raise ValueError("start_from must be earlier than end_before")
self.start_from = start_from.replace(tzinfo=datetime.UTC)
self.end_before = end_before.replace(tzinfo=datetime.UTC)
else:
self.start_from = None
self.end_before = datetime.datetime.now(datetime.UTC) - datetime.timedelta(days=days)
if workers < 1:
raise ValueError("workers must be at least 1")
self.workers = workers
self.tenant_ids = sorted(set(tenant_ids)) if tenant_ids else []
self.limit = limit
self.dry_run = dry_run
self.delete_after_archive = delete_after_archive
self.workflow_run_repo = workflow_run_repo
def run(self) -> ArchiveSummary:
"""
Main archiving loop.
Returns:
ArchiveSummary with statistics about the operation
"""
summary = ArchiveSummary()
start_time = time.time()
click.echo(
click.style(
self._build_start_message(),
fg="white",
)
)
# Initialize archive storage (will raise if not configured)
try:
if not self.dry_run:
storage = get_archive_storage()
else:
storage = None
except ArchiveStorageNotConfiguredError as e:
click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
return summary
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = self._get_workflow_run_repo()
def _archive_with_session(run: WorkflowRun) -> ArchiveResult:
with session_maker() as session:
return self._archive_run(session, storage, run)
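        # Keyset pagination cursor: (created_at, id) of the last run fetched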
last_seen: tuple[datetime.datetime, str] | None = None
archived_count = 0
with ThreadPoolExecutor(max_workers=self.workers) as executor:
while True:
# Check limit
if self.limit and archived_count >= self.limit:
click.echo(click.style(f"Reached limit of {self.limit} runs", fg="yellow"))
break
# Fetch batch of runs
runs = self._get_runs_batch(last_seen)
if not runs:
break
run_ids = [run.id for run in runs]
with session_maker() as session:
archived_run_ids = repo.get_archived_run_ids(session, run_ids)
last_seen = (runs[-1].created_at, runs[-1].id)
# Filter to paid tenants only
tenant_ids = {run.tenant_id for run in runs}
paid_tenants = self._filter_paid_tenants(tenant_ids)
runs_to_process: list[WorkflowRun] = []
for run in runs:
summary.total_runs_processed += 1
# Skip non-paid tenants
if run.tenant_id not in paid_tenants:
summary.runs_skipped += 1
continue
# Skip already archived runs
if run.id in archived_run_ids:
summary.runs_skipped += 1
continue
# Check limit
if self.limit and archived_count + len(runs_to_process) >= self.limit:
break
runs_to_process.append(run)
if not runs_to_process:
continue
results = list(executor.map(_archive_with_session, runs_to_process))
for run, result in zip(runs_to_process, results):
if result.success:
summary.runs_archived += 1
archived_count += 1
click.echo(
click.style(
f"{'[DRY RUN] Would archive' if self.dry_run else 'Archived'} "
f"run {run.id} (tenant={run.tenant_id}, "
f"tables={len(result.tables)}, time={result.elapsed_time:.2f}s)",
fg="green",
)
)
else:
summary.runs_failed += 1
click.echo(
click.style(
f"Failed to archive run {run.id}: {result.error}",
fg="red",
)
)
summary.total_elapsed_time = time.time() - start_time
click.echo(
click.style(
f"{'[DRY RUN] ' if self.dry_run else ''}Archive complete: "
f"processed={summary.total_runs_processed}, archived={summary.runs_archived}, "
f"skipped={summary.runs_skipped}, failed={summary.runs_failed}, "
f"time={summary.total_elapsed_time:.2f}s",
fg="white",
)
)
return summary
def _get_runs_batch(
self,
last_seen: tuple[datetime.datetime, str] | None,
) -> Sequence[WorkflowRun]:
"""Fetch a batch of workflow runs to archive."""
repo = self._get_workflow_run_repo()
return repo.get_runs_batch_by_time_range(
start_from=self.start_from,
end_before=self.end_before,
last_seen=last_seen,
batch_size=self.batch_size,
run_types=self.ARCHIVED_TYPE,
tenant_ids=self.tenant_ids or None,
)
def _build_start_message(self) -> str:
range_desc = f"before {self.end_before.isoformat()}"
if self.start_from:
range_desc = f"between {self.start_from.isoformat()} and {self.end_before.isoformat()}"
return (
f"{'[DRY RUN] ' if self.dry_run else ''}Starting workflow run archiving "
f"for runs {range_desc} "
f"(batch_size={self.batch_size}, tenant_ids={','.join(self.tenant_ids) or 'all'})"
)
def _filter_paid_tenants(self, tenant_ids: set[str]) -> set[str]:
"""Filter tenant IDs to only include paid tenants."""
if not dify_config.BILLING_ENABLED:
# If billing is not enabled, treat all tenants as paid
return tenant_ids
if not tenant_ids:
return set()
try:
bulk_info = BillingService.get_plan_bulk_with_cache(list(tenant_ids))
except Exception:
logger.exception("Failed to fetch billing plans for tenants")
# On error, skip all tenants in this batch
return set()
        # Filter to paid tenants (PROFESSIONAL or TEAM plans)
paid = set()
for tid, info in bulk_info.items():
if info and info.get("plan") in (CloudPlan.PROFESSIONAL, CloudPlan.TEAM):
paid.add(tid)
return paid
def _archive_run(
self,
session: Session,
storage: ArchiveStorage | None,
run: WorkflowRun,
) -> ArchiveResult:
"""Archive a single workflow run."""
start_time = time.time()
result = ArchiveResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
try:
# Extract data from all tables
table_data, app_logs, trigger_metadata = self._extract_data(session, run)
if self.dry_run:
# In dry run, just report what would be archived
for table_name in self.ARCHIVED_TABLES:
records = table_data.get(table_name, [])
result.tables.append(
TableStats(
table_name=table_name,
row_count=len(records),
checksum="",
size_bytes=0,
)
)
result.success = True
else:
if storage is None:
raise ArchiveStorageNotConfiguredError("Archive storage not configured")
archive_key = self._get_archive_key(run)
# Serialize tables for the archive bundle
table_stats: list[TableStats] = []
table_payloads: dict[str, bytes] = {}
for table_name in self.ARCHIVED_TABLES:
records = table_data.get(table_name, [])
data = ArchiveStorage.serialize_to_jsonl(records)
table_payloads[table_name] = data
checksum = ArchiveStorage.compute_checksum(data)
table_stats.append(
TableStats(
table_name=table_name,
row_count=len(records),
checksum=checksum,
size_bytes=len(data),
)
)
# Generate and upload archive bundle
manifest = self._generate_manifest(run, table_stats)
manifest_data = json.dumps(manifest, indent=2, default=str).encode("utf-8")
archive_data = self._build_archive_bundle(manifest_data, table_payloads)
storage.put_object(archive_key, archive_data)
repo = self._get_workflow_run_repo()
archived_log_count = repo.create_archive_logs(session, run, app_logs, trigger_metadata)
session.commit()
deleted_counts = None
if self.delete_after_archive:
deleted_counts = repo.delete_runs_with_related(
[run],
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
)
logger.info(
"Archived workflow run %s: tables=%s, archived_logs=%s, deleted=%s",
run.id,
{s.table_name: s.row_count for s in table_stats},
archived_log_count,
deleted_counts,
)
result.tables = table_stats
result.success = True
except Exception as e:
logger.exception("Failed to archive workflow run %s", run.id)
result.error = str(e)
session.rollback()
result.elapsed_time = time.time() - start_time
return result
def _extract_data(
self,
session: Session,
run: WorkflowRun,
) -> tuple[dict[str, list[dict[str, Any]]], Sequence[WorkflowAppLog], str | None]:
table_data: dict[str, list[dict[str, Any]]] = {}
table_data["workflow_runs"] = [self._row_to_dict(run)]
repo = self._get_workflow_run_repo()
app_logs = repo.get_app_logs_by_run_id(session, run.id)
table_data["workflow_app_logs"] = [self._row_to_dict(row) for row in app_logs]
node_exec_repo = self._get_workflow_node_execution_repo(session)
node_exec_records = node_exec_repo.get_executions_by_workflow_run(
tenant_id=run.tenant_id,
app_id=run.app_id,
workflow_run_id=run.id,
)
node_exec_ids = [record.id for record in node_exec_records]
offload_records = node_exec_repo.get_offloads_by_execution_ids(session, node_exec_ids)
table_data["workflow_node_executions"] = [self._row_to_dict(row) for row in node_exec_records]
table_data["workflow_node_execution_offload"] = [self._row_to_dict(row) for row in offload_records]
repo = self._get_workflow_run_repo()
pause_records = repo.get_pause_records_by_run_id(session, run.id)
pause_ids = [pause.id for pause in pause_records]
pause_reason_records = repo.get_pause_reason_records_by_run_id(
session,
pause_ids,
)
table_data["workflow_pauses"] = [self._row_to_dict(row) for row in pause_records]
table_data["workflow_pause_reasons"] = [self._row_to_dict(row) for row in pause_reason_records]
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
trigger_records = trigger_repo.list_by_run_id(run.id)
table_data["workflow_trigger_logs"] = [self._row_to_dict(row) for row in trigger_records]
trigger_metadata = trigger_records[0].trigger_metadata if trigger_records else None
return table_data, app_logs, trigger_metadata
@staticmethod
def _row_to_dict(row: Any) -> dict[str, Any]:
mapper = inspect(row).mapper
return {str(column.name): getattr(row, mapper.get_property_by_column(column).key) for column in mapper.columns}
def _get_archive_key(self, run: WorkflowRun) -> str:
"""Get the storage key for the archive bundle."""
created_at = run.created_at
prefix = (
f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
f"month={created_at.strftime('%m')}/workflow_run_id={run.id}"
)
return f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
def _generate_manifest(
self,
run: WorkflowRun,
table_stats: list[TableStats],
) -> dict[str, Any]:
"""Generate a manifest for the archived workflow run."""
return {
"schema_version": ARCHIVE_SCHEMA_VERSION,
"workflow_run_id": run.id,
"tenant_id": run.tenant_id,
"app_id": run.app_id,
"workflow_id": run.workflow_id,
"created_at": run.created_at.isoformat(),
"archived_at": datetime.datetime.now(datetime.UTC).isoformat(),
"tables": {
stat.table_name: {
"row_count": stat.row_count,
"checksum": stat.checksum,
"size_bytes": stat.size_bytes,
}
for stat in table_stats
},
}
def _build_archive_bundle(self, manifest_data: bytes, table_payloads: dict[str, bytes]) -> bytes:
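        """Zip manifest.json plus one JSONL payload per archived table into a single bundle."""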
buffer = io.BytesIO()
with zipfile.ZipFile(buffer, mode="w", compression=zipfile.ZIP_DEFLATED) as archive:
archive.writestr("manifest.json", manifest_data)
for table_name in self.ARCHIVED_TABLES:
data = table_payloads.get(table_name)
if data is None:
raise ValueError(f"Missing archive payload for {table_name}")
archive.writestr(f"{table_name}.jsonl", data)
return buffer.getvalue()
def _delete_trigger_logs(self, session: Session, run_ids: Sequence[str]) -> int:
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
return trigger_repo.delete_by_run_ids(run_ids)
def _delete_node_executions(self, session: Session, runs: Sequence[WorkflowRun]) -> tuple[int, int]:
run_ids = [run.id for run in runs]
return self._get_workflow_node_execution_repo(session).delete_by_runs(session, run_ids)
def _get_workflow_node_execution_repo(
self,
session: Session,
) -> DifyAPIWorkflowNodeExecutionRepository:
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=session.get_bind(), expire_on_commit=False)
return DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(session_maker)
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
if self.workflow_run_repo is not None:
return self.workflow_run_repo
from repositories.factory import DifyAPIRepositoryFactory
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return self.workflow_run_repo
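
For context, a minimal sketch of how this archiver might be driven from a maintenance script. The import path is assumed from this PR's package layout and the flag values are illustrative, not prescribed:

# Hypothetical driver; the import path is an assumption, not part of this diff.
from services.retention.workflow_run import WorkflowRunArchiver

# Dry run first: reports what would be archived without touching storage.
preview = WorkflowRunArchiver(days=90, batch_size=100, dry_run=True).run()

# Real pass: archive with 4 workers and drop source rows once uploaded.
if preview.runs_failed == 0:
    summary = WorkflowRunArchiver(
        days=90,
        batch_size=100,
        workers=4,
        delete_after_archive=True,
    ).run()
    print(f"archived={summary.runs_archived}, skipped={summary.runs_skipped}")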


@@ -0,0 +1,2 @@
ARCHIVE_SCHEMA_VERSION = "1.0"
ARCHIVE_BUNDLE_NAME = f"archive.v{ARCHIVE_SCHEMA_VERSION}.zip"
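
Together with the storage layout documented on WorkflowRunArchiver, these constants determine the full object key for a run's bundle. A sketch with invented identifiers:

# Illustrative only: mirrors _get_archive_key with made-up values.
tenant_id, app_id, run_id = "t-1", "app-2", "run-3"
key = f"{tenant_id}/app_id={app_id}/year=2025/month=03/workflow_run_id={run_id}/{ARCHIVE_BUNDLE_NAME}"
# -> "t-1/app_id=app-2/year=2025/month=03/workflow_run_id=run-3/archive.v1.0.zip"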


@@ -0,0 +1,134 @@
"""
Delete Archived Workflow Run Service.
This service deletes archived workflow run data from the database while keeping
archive logs intact.
"""
import time
from collections.abc import Sequence
from dataclasses import dataclass, field
from datetime import datetime
from sqlalchemy.orm import Session, sessionmaker
from extensions.ext_database import db
from models.workflow import WorkflowRun
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.sqlalchemy_workflow_trigger_log_repository import SQLAlchemyWorkflowTriggerLogRepository
@dataclass
class DeleteResult:
run_id: str
tenant_id: str
success: bool
deleted_counts: dict[str, int] = field(default_factory=dict)
error: str | None = None
elapsed_time: float = 0.0
class ArchivedWorkflowRunDeletion:
def __init__(self, dry_run: bool = False):
self.dry_run = dry_run
self.workflow_run_repo: APIWorkflowRunRepository | None = None
def delete_by_run_id(self, run_id: str) -> DeleteResult:
start_time = time.time()
result = DeleteResult(run_id=run_id, tenant_id="", success=False)
repo = self._get_workflow_run_repo()
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
with session_maker() as session:
run = session.get(WorkflowRun, run_id)
if not run:
result.error = f"Workflow run {run_id} not found"
result.elapsed_time = time.time() - start_time
return result
result.tenant_id = run.tenant_id
if not repo.get_archived_run_ids(session, [run.id]):
result.error = f"Workflow run {run_id} is not archived"
result.elapsed_time = time.time() - start_time
return result
result = self._delete_run(run)
result.elapsed_time = time.time() - start_time
return result
def delete_batch(
self,
tenant_ids: list[str] | None,
start_date: datetime,
end_date: datetime,
limit: int = 100,
) -> list[DeleteResult]:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
results: list[DeleteResult] = []
repo = self._get_workflow_run_repo()
with session_maker() as session:
runs = list(
repo.get_archived_runs_by_time_range(
session=session,
tenant_ids=tenant_ids,
start_date=start_date,
end_date=end_date,
limit=limit,
)
)
for run in runs:
results.append(self._delete_run(run))
return results
def _delete_run(self, run: WorkflowRun) -> DeleteResult:
start_time = time.time()
result = DeleteResult(run_id=run.id, tenant_id=run.tenant_id, success=False)
if self.dry_run:
result.success = True
result.elapsed_time = time.time() - start_time
return result
repo = self._get_workflow_run_repo()
try:
deleted_counts = repo.delete_runs_with_related(
[run],
delete_node_executions=self._delete_node_executions,
delete_trigger_logs=self._delete_trigger_logs,
)
result.deleted_counts = deleted_counts
result.success = True
except Exception as e:
result.error = str(e)
result.elapsed_time = time.time() - start_time
return result
@staticmethod
def _delete_trigger_logs(session: Session, run_ids: Sequence[str]) -> int:
trigger_repo = SQLAlchemyWorkflowTriggerLogRepository(session)
return trigger_repo.delete_by_run_ids(run_ids)
@staticmethod
def _delete_node_executions(
session: Session,
runs: Sequence[WorkflowRun],
) -> tuple[int, int]:
from repositories.factory import DifyAPIRepositoryFactory
run_ids = [run.id for run in runs]
repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker=sessionmaker(bind=session.get_bind(), expire_on_commit=False)
)
return repo.delete_by_runs(session, run_ids)
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
if self.workflow_run_repo is not None:
return self.workflow_run_repo
from repositories.factory import DifyAPIRepositoryFactory
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
sessionmaker(bind=db.engine, expire_on_commit=False)
)
return self.workflow_run_repo
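
A minimal usage sketch, assuming the class is importable from the retention package (the run ID and dates below are placeholders):

# Hypothetical usage; import path assumed from this PR's layout.
from datetime import datetime, timezone

from services.retention.workflow_run import ArchivedWorkflowRunDeletion

deleter = ArchivedWorkflowRunDeletion(dry_run=True)

# Single run: refuses to delete data that was never archived.
result = deleter.delete_by_run_id("00000000-0000-0000-0000-000000000000")
if not result.success:
    print(result.error)

# Batch: sweep archived runs for one tenant inside a time window.
results = deleter.delete_batch(
    tenant_ids=["tenant-1"],
    start_date=datetime(2025, 1, 1, tzinfo=timezone.utc),
    end_date=datetime(2025, 4, 1, tzinfo=timezone.utc),
    limit=100,
)
print(sum(1 for r in results if r.success), "of", len(results), "deleted")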


@@ -0,0 +1,481 @@
"""
Restore Archived Workflow Run Service.
This service restores archived workflow run data from S3-compatible storage
back to the database.
"""
import io
import json
import logging
import time
import zipfile
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from dataclasses import dataclass
from datetime import datetime
from typing import Any, cast
import click
from sqlalchemy.dialects.postgresql import insert as pg_insert
from sqlalchemy.engine import CursorResult
from sqlalchemy.orm import DeclarativeBase, Session, sessionmaker
from extensions.ext_database import db
from libs.archive_storage import (
ArchiveStorage,
ArchiveStorageNotConfiguredError,
get_archive_storage,
)
from models.trigger import WorkflowTriggerLog
from models.workflow import (
WorkflowAppLog,
WorkflowArchiveLog,
WorkflowNodeExecutionModel,
WorkflowNodeExecutionOffload,
WorkflowPause,
WorkflowPauseReason,
WorkflowRun,
)
from repositories.api_workflow_run_repository import APIWorkflowRunRepository
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
logger = logging.getLogger(__name__)
# Mapping of table names to SQLAlchemy models
TABLE_MODELS = {
"workflow_runs": WorkflowRun,
"workflow_app_logs": WorkflowAppLog,
"workflow_node_executions": WorkflowNodeExecutionModel,
"workflow_node_execution_offload": WorkflowNodeExecutionOffload,
"workflow_pauses": WorkflowPause,
"workflow_pause_reasons": WorkflowPauseReason,
"workflow_trigger_logs": WorkflowTriggerLog,
}
SchemaMapper = Callable[[dict[str, Any]], dict[str, Any]]
SCHEMA_MAPPERS: dict[str, dict[str, SchemaMapper]] = {
"1.0": {},
}
@dataclass
class RestoreResult:
"""Result of restoring a single workflow run."""
run_id: str
tenant_id: str
success: bool
restored_counts: dict[str, int]
error: str | None = None
elapsed_time: float = 0.0
class WorkflowRunRestore:
"""
Restore archived workflow run data from storage to database.
This service reads archived data from storage and restores it to the
database tables. It handles idempotency by skipping records that already
exist in the database.
"""
def __init__(self, dry_run: bool = False, workers: int = 1):
"""
Initialize the restore service.
Args:
dry_run: If True, only preview without making changes
workers: Number of concurrent workflow runs to restore
"""
self.dry_run = dry_run
if workers < 1:
raise ValueError("workers must be at least 1")
self.workers = workers
self.workflow_run_repo: APIWorkflowRunRepository | None = None
def _restore_from_run(
self,
run: WorkflowRun | WorkflowArchiveLog,
*,
session_maker: sessionmaker,
) -> RestoreResult:
start_time = time.time()
run_id = run.workflow_run_id if isinstance(run, WorkflowArchiveLog) else run.id
created_at = run.run_created_at if isinstance(run, WorkflowArchiveLog) else run.created_at
result = RestoreResult(
run_id=run_id,
tenant_id=run.tenant_id,
success=False,
restored_counts={},
)
if not self.dry_run:
click.echo(
click.style(
f"Starting restore for workflow run {run_id} (tenant={run.tenant_id})",
fg="white",
)
)
try:
storage = get_archive_storage()
except ArchiveStorageNotConfiguredError as e:
result.error = str(e)
click.echo(click.style(f"Archive storage not configured: {e}", fg="red"))
result.elapsed_time = time.time() - start_time
return result
prefix = (
f"{run.tenant_id}/app_id={run.app_id}/year={created_at.strftime('%Y')}/"
f"month={created_at.strftime('%m')}/workflow_run_id={run_id}"
)
archive_key = f"{prefix}/{ARCHIVE_BUNDLE_NAME}"
try:
archive_data = storage.get_object(archive_key)
except FileNotFoundError:
result.error = f"Archive bundle not found: {archive_key}"
click.echo(click.style(result.error, fg="red"))
result.elapsed_time = time.time() - start_time
return result
with session_maker() as session:
try:
with zipfile.ZipFile(io.BytesIO(archive_data), mode="r") as archive:
try:
manifest = self._load_manifest_from_zip(archive)
except ValueError as e:
result.error = f"Archive bundle invalid: {e}"
click.echo(click.style(result.error, fg="red"))
return result
tables = manifest.get("tables", {})
schema_version = self._get_schema_version(manifest)
for table_name, info in tables.items():
row_count = info.get("row_count", 0)
if row_count == 0:
result.restored_counts[table_name] = 0
continue
if self.dry_run:
result.restored_counts[table_name] = row_count
continue
member_path = f"{table_name}.jsonl"
try:
data = archive.read(member_path)
except KeyError:
click.echo(
click.style(
f" Warning: Table data not found in archive: {member_path}",
fg="yellow",
)
)
result.restored_counts[table_name] = 0
continue
records = ArchiveStorage.deserialize_from_jsonl(data)
restored = self._restore_table_records(
session,
table_name,
records,
schema_version=schema_version,
)
result.restored_counts[table_name] = restored
if not self.dry_run:
click.echo(
click.style(
f" Restored {restored}/{len(records)} records to {table_name}",
fg="white",
)
)
# Verify row counts match manifest
manifest_total = sum(info.get("row_count", 0) for info in tables.values())
restored_total = sum(result.restored_counts.values())
if not self.dry_run:
# Note: restored count might be less than manifest count if records already exist
logger.info(
"Restore verification: manifest_total=%d, restored_total=%d",
manifest_total,
restored_total,
)
# Delete the archive log record after successful restore
repo = self._get_workflow_run_repo()
repo.delete_archive_log_by_run_id(session, run_id)
session.commit()
result.success = True
if not self.dry_run:
click.echo(
click.style(
f"Completed restore for workflow run {run_id}: restored={result.restored_counts}",
fg="green",
)
)
except Exception as e:
logger.exception("Failed to restore workflow run %s", run_id)
result.error = str(e)
session.rollback()
click.echo(click.style(f"Restore failed: {e}", fg="red"))
result.elapsed_time = time.time() - start_time
return result
def _get_workflow_run_repo(self) -> APIWorkflowRunRepository:
if self.workflow_run_repo is not None:
return self.workflow_run_repo
self.workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(
sessionmaker(bind=db.engine, expire_on_commit=False)
)
return self.workflow_run_repo
@staticmethod
def _load_manifest_from_zip(archive: zipfile.ZipFile) -> dict[str, Any]:
try:
data = archive.read("manifest.json")
except KeyError as e:
raise ValueError("manifest.json missing from archive bundle") from e
return json.loads(data.decode("utf-8"))
def _restore_table_records(
self,
session: Session,
table_name: str,
records: list[dict[str, Any]],
*,
schema_version: str,
) -> int:
"""
Restore records to a table.
Uses INSERT ... ON CONFLICT DO NOTHING for idempotency.
Args:
session: Database session
table_name: Name of the table
records: List of record dictionaries
schema_version: Archived schema version from manifest
Returns:
Number of records actually inserted
"""
if not records:
return 0
model = TABLE_MODELS.get(table_name)
if not model:
logger.warning("Unknown table: %s", table_name)
return 0
column_names, required_columns, non_nullable_with_default = self._get_model_column_info(model)
unknown_fields: set[str] = set()
# Apply schema mapping, filter to current columns, then convert datetimes
converted_records = []
for record in records:
mapped = self._apply_schema_mapping(table_name, schema_version, record)
unknown_fields.update(set(mapped.keys()) - column_names)
filtered = {key: value for key, value in mapped.items() if key in column_names}
for key in non_nullable_with_default:
if key in filtered and filtered[key] is None:
filtered.pop(key)
missing_required = [key for key in required_columns if key not in filtered or filtered.get(key) is None]
if missing_required:
missing_cols = ", ".join(sorted(missing_required))
raise ValueError(
f"Missing required columns for {table_name} (schema_version={schema_version}): {missing_cols}"
)
converted = self._convert_datetime_fields(filtered, model)
converted_records.append(converted)
if unknown_fields:
logger.warning(
"Dropped unknown columns for %s (schema_version=%s): %s",
table_name,
schema_version,
", ".join(sorted(unknown_fields)),
)
# Use INSERT ... ON CONFLICT DO NOTHING for idempotency
stmt = pg_insert(model).values(converted_records)
stmt = stmt.on_conflict_do_nothing(index_elements=["id"])
result = session.execute(stmt)
return cast(CursorResult, result).rowcount or 0
def _convert_datetime_fields(
self,
record: dict[str, Any],
model: type[DeclarativeBase] | Any,
) -> dict[str, Any]:
"""Convert ISO datetime strings to datetime objects."""
from sqlalchemy import DateTime
result = dict(record)
for column in model.__table__.columns:
if isinstance(column.type, DateTime):
value = result.get(column.key)
if isinstance(value, str):
try:
result[column.key] = datetime.fromisoformat(value)
except ValueError:
pass
return result
def _get_schema_version(self, manifest: dict[str, Any]) -> str:
schema_version = manifest.get("schema_version")
if not schema_version:
logger.warning("Manifest missing schema_version; defaulting to 1.0")
schema_version = "1.0"
schema_version = str(schema_version)
if schema_version not in SCHEMA_MAPPERS:
raise ValueError(f"Unsupported schema_version {schema_version}. Add a mapping before restoring.")
return schema_version
def _apply_schema_mapping(
self,
table_name: str,
schema_version: str,
record: dict[str, Any],
) -> dict[str, Any]:
# Keep hook for forward/backward compatibility when schema evolves.
mapper = SCHEMA_MAPPERS.get(schema_version, {}).get(table_name)
if mapper is None:
return dict(record)
return mapper(record)
def _get_model_column_info(
self,
model: type[DeclarativeBase] | Any,
) -> tuple[set[str], set[str], set[str]]:
columns = list(model.__table__.columns)
column_names = {column.key for column in columns}
required_columns = {
column.key
for column in columns
if not column.nullable
and column.default is None
and column.server_default is None
and not column.autoincrement
}
non_nullable_with_default = {
column.key
for column in columns
if not column.nullable
and (column.default is not None or column.server_default is not None or column.autoincrement)
}
return column_names, required_columns, non_nullable_with_default
def restore_batch(
self,
tenant_ids: list[str] | None,
start_date: datetime,
end_date: datetime,
limit: int = 100,
) -> list[RestoreResult]:
"""
Restore multiple workflow runs by time range.
Args:
tenant_ids: Optional tenant IDs
start_date: Start date filter
end_date: End date filter
limit: Maximum number of runs to restore (default: 100)
Returns:
List of RestoreResult objects
"""
results: list[RestoreResult] = []
if tenant_ids is not None and not tenant_ids:
return results
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
repo = self._get_workflow_run_repo()
with session_maker() as session:
archive_logs = repo.get_archived_logs_by_time_range(
session=session,
tenant_ids=tenant_ids,
start_date=start_date,
end_date=end_date,
limit=limit,
)
click.echo(
click.style(
f"Found {len(archive_logs)} archived workflow runs to restore",
fg="white",
)
)
def _restore_with_session(archive_log: WorkflowArchiveLog) -> RestoreResult:
return self._restore_from_run(
archive_log,
session_maker=session_maker,
)
with ThreadPoolExecutor(max_workers=self.workers) as executor:
results = list(executor.map(_restore_with_session, archive_logs))
total_counts: dict[str, int] = {}
for result in results:
for table_name, count in result.restored_counts.items():
total_counts[table_name] = total_counts.get(table_name, 0) + count
success_count = sum(1 for result in results if result.success)
if self.dry_run:
click.echo(
click.style(
f"[DRY RUN] Would restore {len(results)} workflow runs: totals={total_counts}",
fg="yellow",
)
)
else:
click.echo(
click.style(
f"Restored {success_count}/{len(results)} workflow runs: totals={total_counts}",
fg="green",
)
)
return results
def restore_by_run_id(
self,
run_id: str,
) -> RestoreResult:
"""
Restore a single workflow run by run ID.
"""
repo = self._get_workflow_run_repo()
archive_log = repo.get_archived_log_by_run_id(run_id)
if not archive_log:
click.echo(click.style(f"Workflow run archive {run_id} not found", fg="red"))
return RestoreResult(
run_id=run_id,
tenant_id="",
success=False,
restored_counts={},
error=f"Workflow run archive {run_id} not found",
)
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
result = self._restore_from_run(archive_log, session_maker=session_maker)
if self.dry_run and result.success:
click.echo(
click.style(
f"[DRY RUN] Would restore workflow run {run_id}: totals={result.restored_counts}",
fg="yellow",
)
)
return result
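
A matching sketch for the restore path (import path and run ID are again assumptions):

# Hypothetical usage; import path assumed from this PR's layout.
from services.retention.workflow_run import WorkflowRunRestore

# Preview: row counts come from the bundle manifest, nothing is written.
restorer = WorkflowRunRestore(dry_run=True, workers=2)
result = restorer.restore_by_run_id("00000000-0000-0000-0000-000000000000")
print(result.restored_counts if result.success else result.error)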


@@ -7,7 +7,7 @@ from sqlalchemy import and_, func, or_, select
from sqlalchemy.orm import Session
from core.workflow.enums import WorkflowExecutionStatus
-from models import Account, App, EndUser, WorkflowAppLog, WorkflowRun
from models import Account, App, EndUser, WorkflowAppLog, WorkflowArchiveLog, WorkflowRun
from models.enums import AppTriggerType, CreatorUserRole
from models.trigger import WorkflowTriggerLog
from services.plugin.plugin_service import PluginService
@@ -173,7 +173,80 @@ class WorkflowAppService:
"data": items,
}
-    def handle_trigger_metadata(self, tenant_id: str, meta_val: str) -> dict[str, Any]:
def get_paginate_workflow_archive_logs(
self,
*,
session: Session,
app_model: App,
page: int = 1,
limit: int = 20,
):
"""
        Get paginated workflow archive logs using SQLAlchemy 2.0 style.
"""
stmt = select(WorkflowArchiveLog).where(
WorkflowArchiveLog.tenant_id == app_model.tenant_id,
WorkflowArchiveLog.app_id == app_model.id,
WorkflowArchiveLog.log_id.isnot(None),
)
stmt = stmt.order_by(WorkflowArchiveLog.run_created_at.desc())
count_stmt = select(func.count()).select_from(stmt.subquery())
total = session.scalar(count_stmt) or 0
offset_stmt = stmt.offset((page - 1) * limit).limit(limit)
logs = list(session.scalars(offset_stmt).all())
account_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.ACCOUNT}
end_user_ids = {log.created_by for log in logs if log.created_by_role == CreatorUserRole.END_USER}
accounts_by_id = {}
if account_ids:
accounts_by_id = {
account.id: account
for account in session.scalars(select(Account).where(Account.id.in_(account_ids))).all()
}
end_users_by_id = {}
if end_user_ids:
end_users_by_id = {
end_user.id: end_user
for end_user in session.scalars(select(EndUser).where(EndUser.id.in_(end_user_ids))).all()
}
items = []
for log in logs:
if log.created_by_role == CreatorUserRole.ACCOUNT:
created_by_account = accounts_by_id.get(log.created_by)
created_by_end_user = None
elif log.created_by_role == CreatorUserRole.END_USER:
created_by_account = None
created_by_end_user = end_users_by_id.get(log.created_by)
else:
created_by_account = None
created_by_end_user = None
items.append(
{
"id": log.id,
"workflow_run": log.workflow_run_summary,
"trigger_metadata": self.handle_trigger_metadata(app_model.tenant_id, log.trigger_metadata),
"created_by_account": created_by_account,
"created_by_end_user": created_by_end_user,
"created_at": log.log_created_at,
}
)
return {
"page": page,
"limit": limit,
"total": total,
"has_more": total > page * limit,
"data": items,
}
def handle_trigger_metadata(self, tenant_id: str, meta_val: str | None) -> dict[str, Any]:
metadata: dict[str, Any] | None = self._safe_json_loads(meta_val)
if not metadata:
return {}
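
A sketch of how the new pagination helper might be called; the service construction and app lookup are assumed context, not shown in this hunk:

# Hypothetical caller for the new archive-log listing.
from sqlalchemy.orm import Session

def list_archive_logs(service, session: Session, app_model) -> None:
    page = service.get_paginate_workflow_archive_logs(
        session=session,
        app_model=app_model,
        page=1,
        limit=20,
    )
    print(page["total"], page["has_more"], len(page["data"]))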