Compare commits

..

1 Commits

Author SHA1 Message Date
Stephen Zhou
b72fc07006 修复 Toast 类型检查} (Wait need close?)}{ 2026-03-06 19:20:00 +08:00
11 changed files with 85 additions and 1210 deletions

View File

@@ -62,22 +62,6 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repo's current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there's a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict


class UserProfile(TypedDict):
    user_id: str
    email: str
    created_at: datetime
    nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -2668,77 +2668,3 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
    "--end-before",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    required=True,
    help="Upper bound (exclusive) for created_at.",
)
@click.option(
    "--filename",
    required=True,
    help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
    app_id: str,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime,
    filename: str,
    use_cloud_storage: bool,
    batch_size: int,
    dry_run: bool,
):
    """
    Run AppMessageExportService for one app and print a summary of the result.

    The time window and the filename are checked before any work starts, so bad input is
    reported as a click usage/parameter error. Any failure during the export itself is
    logged, echoed in red, and re-raised.
    """
    window_is_empty = bool(start_from) and start_from >= end_before
    if window_is_empty:
        raise click.UsageError("--start-from must be before --end-before.")

    # Imported here rather than at module level, as in the original command.
    from services.retention.conversation.message_export_service import AppMessageExportService

    try:
        safe_filename = AppMessageExportService.validate_export_filename(filename)
    except ValueError as exc:
        raise click.BadParameter(str(exc), param_hint="--filename") from exc

    click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
    started = time.perf_counter()

    try:
        exporter = AppMessageExportService(
            app_id=app_id,
            end_before=end_before,
            filename=safe_filename,
            start_from=start_from,
            batch_size=batch_size,
            use_cloud_storage=use_cloud_storage,
            dry_run=dry_run,
        )
        stats = exporter.run()
        elapsed = time.perf_counter() - started
        summary = (
            f"export_app_messages: completed in {elapsed:.2f}s\n"
            f" - Batches: {stats.batches}\n"
            f" - Total messages: {stats.total_messages}\n"
            f" - Messages with feedback: {stats.messages_with_feedback}\n"
            f" - Total feedbacks: {stats.total_feedbacks}"
        )
        click.echo(click.style(summary, fg="green"))
    except Exception as exc:
        elapsed = time.perf_counter() - started
        logger.exception("export_app_messages failed")
        click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {exc}", fg="red"))
        raise

View File

@@ -44,13 +44,14 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -459,40 +460,91 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
    """
    Build the list of file dicts for the current message.

    Loads the MessageFile rows for ``self._message_id``, loads the UploadFile rows they
    reference with a single query, and converts every message file into a plain dict.

    :return: list of file dicts, or None when the message has no files
    """
    with Session(db.engine, expire_on_commit=False) as session:
        message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
        if not message_files:
            return None

        # Collect the referenced upload ids once; dict.fromkeys drops duplicates while keeping
        # order, so the IN clause below never repeats an id.
        upload_file_ids = list(
            dict.fromkeys(
                mf.upload_file_id
                for mf in message_files
                if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
            )
        )
        upload_files_map = {}
        if upload_file_ids:
            upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
            upload_files_map = {uf.id: uf for uf in upload_files}

        files_list = []
        for message_file in message_files:
            upload_file = None
            if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
                upload_file = upload_files_map.get(message_file.upload_file_id)

            # Defaults used whenever a branch below cannot determine the value.
            url = None
            filename = "file"
            mime_type = "application/octet-stream"
            size = 0
            extension = ""

            if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
                url = message_file.url
                if message_file.url:
                    # Last path segment without query params
                    filename = message_file.url.split("/")[-1].split("?")[0]
            elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
                if upload_file:
                    url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
                    filename = upload_file.name
                    mime_type = upload_file.mime_type or "application/octet-stream"
                    size = upload_file.size or 0
                    extension = f".{upload_file.extension}" if upload_file.extension else ""
                elif message_file.upload_file_id:
                    # Fallback: generate URL even if upload_file not found
                    url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
            elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
                # For tool files, use URL directly if it's HTTP, otherwise sign it
                if message_file.url.startswith("http"):
                    url = message_file.url
                    filename = message_file.url.split("/")[-1].split("?")[0]
                else:
                    # str.split always returns at least one element, so the last segment
                    # can be taken without a guard.
                    file_part = message_file.url.split("/")[-1].split("?")[0]
                    # Use rsplit to correctly handle filenames with multiple dots
                    if "." in file_part:
                        tool_file_id, ext = file_part.rsplit(".", 1)
                        extension = f".{ext}"
                    else:
                        tool_file_id = file_part
                        extension = ".bin"
                    url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
                    filename = file_part

            transfer_method_value = message_file.transfer_method
            remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
            files_list.append(
                {
                    "related_id": message_file.id,
                    "extension": extension,
                    "filename": filename,
                    "size": size,
                    "mime_type": mime_type,
                    "transfer_method": transfer_method_value,
                    "type": message_file.type,
                    "url": url or "",
                    "upload_file_id": message_file.upload_file_id or message_file.id,
                    "remote_url": remote_url,
                }
            )
        return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -1,76 +0,0 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def _filename_from_url(url: str) -> str:
    """Return the last ``/``-separated segment of *url* with any ``?query`` part removed."""
    return url.split("/")[-1].split("?")[0]


def _dotted_extension(filename: str) -> str:
    """Return the text after the last ``.`` of *filename* with a leading dot, or ``""`` if it has no dot."""
    if "." not in filename:
        return ""
    return "." + filename.rsplit(".", 1)[1]


def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
    """
    Prepare file dictionary for message end stream response.

    :param message_file: MessageFile instance
    :param upload_files_map: Dictionary mapping upload_file_id to UploadFile
    :return: Dictionary containing file information
    """
    transfer_method = message_file.transfer_method

    upload_file = None
    if transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
        upload_file = upload_files_map.get(message_file.upload_file_id)

    # Defaults used whenever a branch below cannot determine the value.
    url = None
    filename = "file"
    mime_type = "application/octet-stream"
    size = 0
    extension = ""

    if transfer_method == FileTransferMethod.REMOTE_URL:
        url = message_file.url
        if message_file.url:
            filename = _filename_from_url(message_file.url)
            extension = _dotted_extension(filename)
    elif transfer_method == FileTransferMethod.LOCAL_FILE:
        if upload_file:
            url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
            filename = upload_file.name
            mime_type = upload_file.mime_type or "application/octet-stream"
            size = upload_file.size or 0
            extension = f".{upload_file.extension}" if upload_file.extension else ""
        elif message_file.upload_file_id:
            # The UploadFile row was not in the map; still sign a URL from the stored id.
            url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
    elif transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
        if message_file.url.startswith(("http://", "https://")):
            # Already an absolute HTTP(S) URL: pass it through unchanged.
            url = message_file.url
            filename = _filename_from_url(message_file.url)
            extension = _dotted_extension(filename)
        else:
            # str.split always returns at least one element, so the last segment is always
            # available and tool_file_id is bound on every path below.
            file_part = _filename_from_url(message_file.url)
            if "." in file_part:
                tool_file_id, ext = file_part.rsplit(".", 1)
                extension = f".{ext}"
                if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
                    extension = ".bin"
            else:
                tool_file_id = file_part
                extension = ".bin"
            url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
            filename = file_part

    return {
        "related_id": message_file.id,
        "extension": extension,
        "filename": filename,
        "size": size,
        "mime_type": mime_type,
        "transfer_method": transfer_method.value,
        "type": message_file.type,
        "url": url or "",
        "upload_file_id": message_file.upload_file_id or message_file.id,
        "remote_url": message_file.url if transfer_method == FileTransferMethod.REMOTE_URL else "",
    }

View File

@@ -13,7 +13,6 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -67,7 +66,6 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -63,12 +63,7 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -160,13 +155,14 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

View File

@@ -1,304 +0,0 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
    """Feedback entry embedded in the ``feedback`` list of an exported message record."""

    # Unknown keys are rejected instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")

    id: str
    app_id: str
    conversation_id: str
    message_id: str
    rating: str
    content: str | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    created_at: str
    updated_at: str
class AppMessageExportRecord(BaseModel):
    """One exported message: a single line of the JSONL output."""

    # Unknown keys are rejected instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")

    conversation_id: str
    message_id: str
    query: str
    answer: str
    inputs: dict[str, Any]
    retriever_resources: list[Any] = Field(default_factory=list)
    feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
class AppMessageExportStats(BaseModel):
    """Counters collected while an export runs; every counter starts at zero."""

    # Unknown keys are rejected instead of being silently dropped.
    model_config = ConfigDict(extra="forbid")

    batches: int = 0
    total_messages: int = 0
    messages_with_feedback: int = 0
    total_feedbacks: int = 0
class AppMessageExportService:
    """
    Export one app's messages, together with their user feedback, as gzip-compressed JSON Lines.

    Messages are read in ``(created_at, id)`` order with cursor pagination, and the feedback
    rows of each page are loaded with one query. The result is written to a file under the
    current working directory, or handed to ``storage.save`` when cloud storage is requested.
    A dry run scans the rows and only collects statistics.
    """

    _app_id: str
    _end_before: datetime.datetime
    _start_from: datetime.datetime | None
    _filename_base: str
    _batch_size: int
    _use_cloud_storage: bool
    _dry_run: bool

    @staticmethod
    def validate_export_filename(filename: str) -> str:
        """
        Validate a base filename for the export and return it with surrounding whitespace stripped.

        :param filename: relative, '/'-separated base name without a .jsonl/.gz suffix
        :return: the stripped filename
        :raises ValueError: if the name is empty, carries a forbidden suffix, is absolute, uses
            backslashes, has empty / '.' / '..' segments, is too long, or contains control characters
        """
        normalized = filename.strip()
        if not normalized:
            raise ValueError("--filename must not be empty.")

        if normalized.lower().endswith(FORBIDDEN_FILENAME_SUFFIXES):
            raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")

        if normalized.startswith("/"):
            raise ValueError("--filename must be a relative path; absolute paths are not allowed.")

        if "\\" in normalized:
            raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")

        if "//" in normalized:
            raise ValueError("--filename must not contain empty path segments ('//').")

        if len(normalized) > MAX_FILENAME_BASE_LENGTH:
            raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")

        # Code points below 32 (which includes NUL) and DEL (127) are control characters.
        if any(ord(ch) < 32 or ord(ch) == 127 for ch in normalized):
            raise ValueError("--filename must not contain control characters or NUL.")

        parts = PurePosixPath(normalized).parts
        if not parts:
            raise ValueError("--filename must include a file name.")

        if any(part in (".", "..") for part in parts):
            raise ValueError("--filename must not contain '.' or '..' path segments.")

        return normalized

    @property
    def output_gz_name(self) -> str:
        """Name of the compressed output: the base filename plus ``.jsonl.gz``."""
        return f"{self._filename_base}.jsonl.gz"

    @property
    def output_jsonl_name(self) -> str:
        """Name of the uncompressed form: the base filename plus ``.jsonl``."""
        return f"{self._filename_base}.jsonl"

    def __init__(
        self,
        app_id: str,
        end_before: datetime.datetime,
        filename: str,
        *,
        start_from: datetime.datetime | None = None,
        batch_size: int = 1000,
        use_cloud_storage: bool = False,
        dry_run: bool = False,
    ) -> None:
        """
        :param app_id: app whose messages are exported
        :param end_before: exclusive upper bound for ``Message.created_at``
        :param filename: base output name, checked with ``validate_export_filename``
        :param start_from: optional inclusive lower bound for ``Message.created_at``
        :param batch_size: number of messages fetched per query; must be positive
        :param use_cloud_storage: save through ``storage.save`` instead of a local file
        :param dry_run: scan and count only, write nothing
        :raises ValueError: if the time window is empty, ``batch_size`` is not positive, or
            the filename is invalid
        """
        if start_from and start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
        # A zero limit would make every fetch return no rows and silently export nothing.
        if batch_size <= 0:
            raise ValueError(f"batch_size ({batch_size}) must be a positive integer")

        self._app_id = app_id
        self._end_before = end_before
        self._start_from = start_from
        self._filename_base = self.validate_export_filename(filename)
        self._batch_size = batch_size
        self._use_cloud_storage = use_cloud_storage
        self._dry_run = dry_run

    def run(self) -> AppMessageExportStats:
        """Run the export (or the dry-run scan) and return the collected statistics."""
        stats = AppMessageExportStats()

        logger.info(
            "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
            self._app_id,
            self._start_from,
            self._end_before,
            self._dry_run,
            self._use_cloud_storage,
            self.output_gz_name,
        )

        if self._dry_run:
            # Drain the generator purely for its side effect on ``stats``.
            for _ in self._iter_records_with_stats(stats):
                pass
        elif self._use_cloud_storage:
            self._export_to_cloud(stats)
        else:
            self._export_to_local(stats)

        self._finalize_stats(stats)
        return stats

    def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
        """Yield export records one at a time, in ``(created_at, id)`` order."""
        for batch in self._iter_record_batches():
            yield from batch

    @staticmethod
    def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
        """Write *records* to *fileobj* as gzip-compressed JSON, one record per line."""
        with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
            for record in records:
                gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")

    def _export_to_local(self, stats: AppMessageExportStats) -> None:
        """Write the export to ``<cwd>/<output_gz_name>``, creating parent directories as needed."""
        output_path = Path.cwd() / self.output_gz_name
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as output_file:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)

    def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
        """Build the archive in a spooled temporary file, then save it under ``output_gz_name``."""
        # The spool stays in memory up to 64 MiB and rolls over to disk beyond that.
        with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
            tmp.seek(0)
            data = tmp.read()
            storage.save(self.output_gz_name, data)
            logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)

    def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
        """Yield every record while updating *stats* for each one."""
        for record in self.iter_records():
            self._update_stats(stats, record)
            yield record

    @staticmethod
    def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
        """Add one record to the message and feedback counters of *stats*."""
        stats.total_messages += 1
        if record.feedback:
            stats.messages_with_feedback += 1
            stats.total_feedbacks += len(record.feedback)

    def _finalize_stats(self, stats: AppMessageExportStats) -> None:
        """Set ``stats.batches`` to ``ceil(total_messages / batch_size)`` (0 when nothing was exported)."""
        if stats.total_messages == 0:
            stats.batches = 0
            return
        stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size

    def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
        """Yield one list of records per fetched page until a page comes back empty."""
        cursor: tuple[datetime.datetime, str] | None = None
        while True:
            rows, cursor = self._fetch_batch(cursor)
            if not rows:
                break

            message_ids = [str(row.id) for row in rows]
            feedbacks_map = self._fetch_feedbacks(message_ids)
            yield [self._build_record(row, feedbacks_map) for row in rows]

    def _fetch_batch(
        self, cursor: tuple[datetime.datetime, str] | None
    ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
        """
        Fetch the next page of message rows after *cursor*.

        :param cursor: ``(created_at, id)`` of the last row already returned, or None for the first page
        :return: the rows and the cursor for the following page; an empty page returns the cursor unchanged
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(
                    Message.id,
                    Message.conversation_id,
                    Message.query,
                    Message.answer,
                    Message._inputs,  # pyright: ignore[reportPrivateUsage]
                    Message.message_metadata,
                    Message.created_at,
                )
                .where(
                    Message.app_id == self._app_id,
                    Message.created_at < self._end_before,
                )
                .order_by(Message.created_at, Message.id)
                .limit(self._batch_size)
            )

            if self._start_from:
                stmt = stmt.where(Message.created_at >= self._start_from)

            if cursor:
                # Row-value comparison: strictly after the last (created_at, id) already seen.
                stmt = stmt.where(
                    tuple_(Message.created_at, Message.id)
                    > tuple_(
                        sa.literal(cursor[0], type_=sa.DateTime()),
                        sa.literal(cursor[1], type_=Message.id.type),
                    )
                )

            rows = list(session.execute(stmt).all())

            if not rows:
                return [], cursor

            last = rows[-1]
            return rows, (last.created_at, last.id)

    def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
        """Load the ``from_source == "user"`` feedbacks of *message_ids*, grouped by message id."""
        if not message_ids:
            return {}

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(MessageFeedback)
                .where(
                    MessageFeedback.message_id.in_(message_ids),
                    MessageFeedback.from_source == "user",
                )
                .order_by(MessageFeedback.message_id, MessageFeedback.created_at)
            )
            feedbacks = list(session.scalars(stmt).all())

            result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
            for feedback in feedbacks:
                result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
            return result

    @staticmethod
    def _extract_retriever_resources(message_metadata: Any) -> list[Any]:
        """
        Return the ``retriever_resources`` list stored in a message's metadata JSON.

        Anything unusable (empty value, invalid JSON, a JSON document that is not an object,
        or a ``retriever_resources`` entry that is not a list) yields an empty list, so one
        malformed row cannot abort the export.
        """
        if not message_metadata:
            return []
        try:
            metadata = json.loads(message_metadata)
        except (json.JSONDecodeError, TypeError):
            return []
        # Valid JSON may still be a list, string or number, none of which has ``.get``.
        if not isinstance(metadata, dict):
            return []
        value = metadata.get("retriever_resources", [])
        return value if isinstance(value, list) else []

    @staticmethod
    def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
        """Convert one fetched message row and its feedbacks into an export record."""
        message_id = str(row.id)
        return AppMessageExportRecord(
            conversation_id=str(row.conversation_id),
            message_id=message_id,
            query=row.query,
            answer=row.answer,
            inputs=row._inputs if isinstance(row._inputs, dict) else {},
            retriever_resources=AppMessageExportService._extract_retriever_resources(row.message_metadata),
            feedback=feedbacks_map.get(message_id, []),
        )

View File

@@ -1,233 +0,0 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
# Integration test for AppMessageExportService; it reads and writes real rows through the
# db_session_with_containers fixture.
class TestAppMessageExportServiceIntegration:
# Runs after every test (code after the yield): deletes all rows from the listed tables and commits.
# NOTE(review): the delete order presumably puts dependent tables before the rows they
# reference - confirm against the actual foreign keys before reordering.
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
# Creates an account, a tenant, their join row, an app and a conversation, flushing after
# each add so the generated ids can be used by the next object, then commits.
# Returns the (app, conversation) pair.
@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()
app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation
# Adds one Message for the given app/conversation with an explicit created_at and flushes
# (no commit) so message.id is available to the caller.
@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message
# End-to-end check of _iter_records_with_stats: two messages one minute apart, the first
# with metadata and three feedbacks (two from "user", one from "admin"), the second with neither.
def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)
# Nested dicts and lists, compared verbatim against the exported inputs below.
first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)
user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
# The assertions below expect this admin feedback to be absent from the export.
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
# created_at is assigned after add_all, with increasing values, so the expected feedback
# order ("first", then "second") is deterministic.
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()
# batch_size=1 with two messages means each message arrives in its own page.
service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)
# Records come back in creation order.
assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id
assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs
# message_metadata=None yields an empty retriever_resources list.
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []
# Two messages at batch_size=1 give two batches; only the two user feedbacks are counted.
assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2

View File

@@ -1,425 +0,0 @@
"""
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
This test suite ensures that the files array is correctly populated in the message_end
SSE event, which is critical for vision/image chat responses to render correctly.
Test Coverage:
- Files array populated when MessageFile records exist
- Files array is None when no MessageFile records exist
- Correct signed URL generation for LOCAL_FILE transfer method
- Correct URL handling for REMOTE_URL transfer method
- Correct URL handling for TOOL_FILE transfer method
- Proper file metadata formatting (filename, mime_type, size, extension)
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
class TestMessageEndStreamResponseFiles:
"""Test suite for files array population in message_end SSE event."""
@pytest.fixture
def mock_pipeline(self):
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
pipeline._message_id = str(uuid.uuid4())
pipeline._task_state = Mock()
pipeline._task_state.metadata = Mock()
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
pipeline._task_state.llm_result = Mock()
pipeline._task_state.llm_result.usage = Mock()
pipeline._application_generate_entity = Mock()
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
return pipeline
@pytest.fixture
def mock_message_file_local(self):
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_remote(self):
"""Create a mock MessageFile with REMOTE_URL transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_tool(self):
"""Create a mock MessageFile with TOOL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
"""Create a mock UploadFile."""
upload_file = Mock(spec=UploadFile)
upload_file.id = mock_message_file_local.upload_file_id
upload_file.name = "test_image.png"
upload_file.mime_type = "image/png"
upload_file.size = 1024
upload_file.extension = "png"
return upload_file
def test_message_end_with_no_files(self, mock_pipeline):
"""Test that files array is None when no MessageFile records exist."""
# Arrange
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = []
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is None
assert result.id == mock_pipeline._message_id
assert result.metadata == {"test": "metadata"}
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
    """A LOCAL_FILE message file yields one file entry built from the UploadFile mock and a signed URL."""
    # Arrange
    mock_message_file_local.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Result handed to the first query (MessageFile rows).
        message_files_result = Mock()
        message_files_result.all.return_value = [mock_message_file_local]
        # Result handed to the second query (UploadFile rows, fetched as a batch to avoid N+1).
        upload_files_result = Mock()
        upload_files_result.all.return_value = [mock_upload_file]

        # The MessageFile result is served exactly once; any later call gets the UploadFile result.
        first_call_results = iter([message_files_result])

        def scalars_side_effect(query):
            return next(first_call_results, upload_files_result)

        session.scalars.side_effect = scalars_side_effect
        mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_dict = response.files[0]
        expected_fields = {
            "related_id": mock_message_file_local.id,
            "filename": "test_image.png",
            "mime_type": "image/png",
            "size": 1024,
            "extension": ".png",
            "type": "image",
            "transfer_method": FileTransferMethod.LOCAL_FILE.value,
            "upload_file_id": mock_message_file_local.upload_file_id,
            "remote_url": "",
        }
        for field, expected in expected_fields.items():
            assert file_dict[field] == expected
        assert "https://example.com/signed-url" in file_dict["url"]

        # Exactly two queries are expected: one for MessageFile, one for UploadFile.
        assert session.scalars.call_count == 2
        mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
    """A REMOTE_URL message file yields one file entry that echoes the remote URL."""
    # Arrange
    mock_message_file_remote.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single remote message file.
        scalars_result = Mock()
        scalars_result.all.return_value = [mock_message_file_remote]
        session.scalars.return_value = scalars_result

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_dict = response.files[0]
        expected_fields = {
            "related_id": mock_message_file_remote.id,
            "filename": "image.jpg",
            "url": "https://example.com/image.jpg",
            "extension": ".jpg",
            "type": "image",
            "transfer_method": FileTransferMethod.REMOTE_URL.value,
            "remote_url": "https://example.com/image.jpg",
            "upload_file_id": mock_message_file_remote.id,
        }
        for field, expected in expected_fields.items():
            assert file_dict[field] == expected

        # Only the MessageFile query should have been issued.
        session.scalars.assert_called_once()
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
    """A TOOL_FILE message file whose URL is already HTTP(S) keeps that URL in the file entry."""
    # Arrange
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "https://example.com/tool_file.png"

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single tool message file.
        scalars_result = Mock()
        scalars_result.all.return_value = [mock_message_file_tool]
        session.scalars.return_value = scalars_result

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_dict = response.files[0]
        expected_fields = {
            "url": "https://example.com/tool_file.png",
            "filename": "tool_file.png",
            "extension": ".png",
            "transfer_method": FileTransferMethod.TOOL_FILE.value,
        }
        for field, expected in expected_fields.items():
            assert file_dict[field] == expected
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
    """A TOOL_FILE message file with a bare file name gets its URL from ``sign_tool_file``."""
    # Arrange
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "tool_file_123.png"

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single tool message file.
        scalars_result = Mock()
        scalars_result.all.return_value = [mock_message_file_tool]
        session.scalars.return_value = scalars_result
        mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_dict = response.files[0]
        assert "https://example.com/signed-tool-file.png" in file_dict["url"]
        expected_fields = {
            "filename": "tool_file_123.png",
            "extension": ".png",
            "transfer_method": FileTransferMethod.TOOL_FILE.value,
        }
        for field, expected in expected_fields.items():
            assert file_dict[field] == expected

        # The signer must receive the file name split into id and extension.
        mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
    """TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH are replaced by ``.bin``."""
    # Arrange
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "tool_file_abc.verylongextension"

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single tool message file.
        scalars_result = Mock()
        scalars_result.all.return_value = [mock_message_file_tool]
        session.scalars.return_value = scalars_result
        mock_sign_tool.return_value = "https://example.com/signed.bin"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert response.files is not None
        file_dict = response.files[0]
        assert file_dict["extension"] == ".bin"
        mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
def test_message_end_with_multiple_files(
    self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
    """Every MessageFile returned by the query shows up in the response's files list."""
    # Arrange
    attached_files = [mock_message_file_local, mock_message_file_remote]
    for attached in attached_files:
        attached.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Result handed to the first query (MessageFile rows).
        message_files_result = Mock()
        message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
        # Result handed to the second query (UploadFile rows, fetched as a batch to avoid N+1).
        upload_files_result = Mock()
        upload_files_result.all.return_value = [mock_upload_file]

        # The MessageFile result is served exactly once; any later call gets the UploadFile result.
        first_call_results = iter([message_files_result])

        def scalars_side_effect(query):
            return next(first_call_results, upload_files_result)

        session.scalars.side_effect = scalars_side_effect
        mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 2

        # Both message files must be represented by their ids.
        file_ids = [entry["related_id"] for entry in response.files]
        for attached in attached_files:
            assert attached.id in file_ids
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
    """A LOCAL_FILE entry still gets a signed URL when the UploadFile query returns nothing."""
    # Arrange
    mock_message_file_local.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Result handed to the first query (MessageFile rows).
        message_files_result = Mock()
        message_files_result.all.return_value = [mock_message_file_local]
        # Result handed to the second query (UploadFile batch): empty, i.e. the upload is missing.
        upload_files_result = Mock()
        upload_files_result.all.return_value = []

        # The MessageFile result is served exactly once; any later call gets the UploadFile result.
        first_call_results = iter([message_files_result])

        def scalars_side_effect(query):
            return next(first_call_results, upload_files_result)

        session.scalars.side_effect = scalars_side_effect
        mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_dict = response.files[0]
        assert "https://example.com/fallback-url" in file_dict["url"]
        # The fallback URL must be signed with the upload_file_id stored on the message file.
        mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))

View File

@@ -1,43 +0,0 @@
import datetime
import pytest
from services.retention.conversation.message_export_service import AppMessageExportService
def test_validate_export_filename_accepts_relative_path():
    """A suffix-free relative path is returned unchanged by the validator."""
    relative_base = "exports/2026/test01"
    validated = AppMessageExportService.validate_export_filename(relative_base)
    assert validated == relative_base
@pytest.mark.parametrize(
    "filename",
    (
        # Names that already carry an output suffix.
        "test01.jsonl.gz",
        "test01.jsonl",
        "test01.gz",
        # Absolute path and parent-directory traversal.
        "/tmp/test01",
        "exports/../test01",
        # Names containing control characters.
        "bad\x00name",
        "bad\tname",
        # A 1025-character name.
        "a" * 1025,
    ),
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
    """Each parametrized filename must make the validator raise ``ValueError``."""
    rejects = pytest.raises(ValueError)
    with rejects:
        AppMessageExportService.validate_export_filename(filename)
def test_service_derives_output_names_from_filename_base():
    """The service keeps the filename base and appends ``.jsonl`` / ``.jsonl.gz`` for its outputs."""
    filename_base = "exports/2026/test01"
    cutoff = datetime.datetime(2026, 3, 1)

    export_service = AppMessageExportService(
        app_id="736b9b03-20f2-4697-91da-8d00f6325900",
        start_from=None,
        end_before=cutoff,
        filename=filename_base,
        batch_size=1000,
        use_cloud_storage=True,
        dry_run=True,
    )

    assert export_service._filename_base == filename_base
    assert export_service.output_gz_name == "exports/2026/test01.jsonl.gz"
    assert export_service.output_jsonl_name == "exports/2026/test01.jsonl"

View File

@@ -1,12 +1,12 @@
'use client'
import type { ReactNode } from 'react'
import type { IToastProps } from './context'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useEffect, useState } from 'react'
import { createRoot } from 'react-dom/client'
import ActionButton from '@/app/components/base/action-button'
import { cn } from '@/utils/classnames'
import type { IToastProps } from './context'
import { ToastContext, useToastContext } from './context'
export type ToastHandle = {