mirror of
https://github.com/langgenius/dify.git
synced 2026-03-13 03:07:09 +00:00
Compare commits
22 Commits
3-6-type-c
...
copilot/su
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
0e4f5cb38c | ||
|
|
c13d1872d4 | ||
|
|
c911de6a6c | ||
|
|
968bf10e1c | ||
|
|
3d77a5ec08 | ||
|
|
41af72449d | ||
|
|
de72bdef71 | ||
|
|
c925d17e8f | ||
|
|
dc2a53d834 | ||
|
|
05ab107e73 | ||
|
|
c016793efb | ||
|
|
a5bcbaebb7 | ||
|
|
f97ade7053 | ||
|
|
a0dcd04546 | ||
|
|
b0138316f0 | ||
|
|
099568f3da | ||
|
|
0623522d04 | ||
|
|
a25d48c5bd | ||
|
|
4f3a020670 | ||
|
|
d2e1177478 | ||
|
|
8a21fd88fd | ||
|
|
1c1bcc67da |
@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
|
||||
|
||||
- Code should usually include type annotations that match the repo’s current Python version (avoid untyped public APIs and “mystery” values).
|
||||
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless there’s a strong reason.
|
||||
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
|
||||
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
|
||||
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
from typing import NotRequired, TypedDict
|
||||
|
||||
|
||||
class UserProfile(TypedDict):
|
||||
user_id: str
|
||||
email: str
|
||||
created_at: datetime
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
|
||||
@@ -1,16 +1,38 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap endpoints that stay reachable while the enterprise license
# is invalid. Each one is needed by the frontend before or during sign-in:
#   system-features -> license status for the expiry UI (GlobalPublicStoreProvider)
#   setup           -> install/setup status check (AppInitializer)
#   init            -> init password validation for a fresh install (InitPasswordPopup)
#   login           -> auto-login after setup completion (InstallForm)
#   version         -> version check (AppContextProvider)
#   activate/check  -> invitation link validation (signin page)
# If any of these were blocked, the signin page would call location.reload()
# on unauthorized_and_force_logout and loop forever.
# NOTE(review): the license gate compares these with str.startswith, so every
# entry also exempts any longer path that merely begins with it - confirm that
# no protected route shares one of these prefixes.
_CONSOLE_EXEMPT_PREFIXES = (
    "/console/api/system-features",
    "/console/api/setup",
    "/console/api/init",
    "/console/api/login",
    "/console/api/version",
    "/console/api/activate/check",
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -31,6 +53,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/")
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status (cached — see EnterpriseService for TTL details)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
if license_status is None:
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
raise UnauthorizedAndForceLogout(
|
||||
"Unable to verify enterprise license. Please contact your administrator."
|
||||
)
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
|
||||
@@ -2668,3 +2668,77 @@ def clean_expired_messages(
|
||||
raise
|
||||
|
||||
click.echo(click.style("messages cleanup completed.", fg="green"))
|
||||
|
||||
|
||||
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
    "--start-from",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    default=None,
    help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
    "--end-before",
    type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
    required=True,
    help="Upper bound (exclusive) for created_at.",
)
@click.option(
    "--filename",
    required=True,
    help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option(
    "--batch-size",
    # Reject zero/negative sizes at the CLI boundary instead of handing them
    # to the pagination code.
    type=click.IntRange(min=1),
    default=1000,
    show_default=True,
    help="Batch size for cursor pagination.",
)
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
    app_id: str,
    start_from: datetime.datetime | None,
    end_before: datetime.datetime,
    filename: str,
    use_cloud_storage: bool,
    batch_size: int,
    dry_run: bool,
) -> None:
    """Export the messages of one app through ``AppMessageExportService``.

    Validates the time window and the base filename, runs the export service,
    and prints the resulting statistics (batches, messages, feedback counts).

    :param app_id: application whose messages are exported
    :param start_from: optional inclusive lower bound for ``created_at``
    :param end_before: exclusive upper bound for ``created_at``
    :param filename: base filename without the ``.jsonl.gz`` suffix
    :param use_cloud_storage: upload to cloud storage instead of a local file
    :param batch_size: page size for cursor pagination (must be >= 1)
    :param dry_run: scan and report statistics without writing a file
    :raises click.UsageError: if ``start_from`` is not before ``end_before``
    :raises click.BadParameter: if the service rejects ``filename``
    """
    if start_from is not None and start_from >= end_before:
        raise click.UsageError("--start-from must be before --end-before.")

    # Imported lazily, as in the original command body.
    from services.retention.conversation.message_export_service import AppMessageExportService

    try:
        validated_filename = AppMessageExportService.validate_export_filename(filename)
    except ValueError as e:
        raise click.BadParameter(str(e), param_hint="--filename") from e

    click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
    start_at = time.perf_counter()

    try:
        service = AppMessageExportService(
            app_id=app_id,
            end_before=end_before,
            filename=validated_filename,
            start_from=start_from,
            batch_size=batch_size,
            use_cloud_storage=use_cloud_storage,
            dry_run=dry_run,
        )
        stats = service.run()

        elapsed = time.perf_counter() - start_at
        click.echo(
            click.style(
                f"export_app_messages: completed in {elapsed:.2f}s\n"
                f"  - Batches: {stats.batches}\n"
                f"  - Total messages: {stats.total_messages}\n"
                f"  - Messages with feedback: {stats.messages_with_feedback}\n"
                f"  - Total feedbacks: {stats.total_feedbacks}",
                fg="green",
            )
        )
    except Exception as e:
        # Log with traceback, report the failure on the console, then re-raise
        # so the CLI still exits with an error.
        elapsed = time.perf_counter() - start_at
        logger.exception("export_app_messages failed")
        click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
        raise
|
||||
|
||||
@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
|
||||
from core.app.task_pipeline.message_file_utils import prepare_file_dict
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_manager import ModelInstance
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
|
||||
from dify_graph.model_runtime.entities.message_entities import (
|
||||
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
"""
|
||||
self._task_state.metadata.usage = self._task_state.llm_result.usage
|
||||
metadata_dict = self._task_state.metadata.model_dump()
|
||||
|
||||
# Fetch files associated with this message
|
||||
files = None
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
|
||||
if message_files:
|
||||
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
|
||||
upload_file_ids = list(
|
||||
dict.fromkeys(
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
)
|
||||
)
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
files_list = []
|
||||
for message_file in message_files:
|
||||
file_dict = prepare_file_dict(message_file, upload_files_map)
|
||||
files_list.append(file_dict)
|
||||
|
||||
files = files_list or None
|
||||
|
||||
return MessageEndStreamResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=self._message_id,
|
||||
metadata=metadata_dict,
|
||||
files=files,
|
||||
)
|
||||
|
||||
def _record_files(self):
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
|
||||
if not message_files:
|
||||
return None
|
||||
|
||||
files_list = []
|
||||
upload_file_ids = [
|
||||
mf.upload_file_id
|
||||
for mf in message_files
|
||||
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
|
||||
]
|
||||
upload_files_map = {}
|
||||
if upload_file_ids:
|
||||
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
|
||||
upload_files_map = {uf.id: uf for uf in upload_files}
|
||||
|
||||
for message_file in message_files:
|
||||
upload_file = None
|
||||
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
|
||||
upload_file = upload_files_map.get(message_file.upload_file_id)
|
||||
|
||||
url = None
|
||||
filename = "file"
|
||||
mime_type = "application/octet-stream"
|
||||
size = 0
|
||||
extension = ""
|
||||
|
||||
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
|
||||
url = message_file.url
|
||||
if message_file.url:
|
||||
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
|
||||
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||
if upload_file:
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
|
||||
filename = upload_file.name
|
||||
mime_type = upload_file.mime_type or "application/octet-stream"
|
||||
size = upload_file.size or 0
|
||||
extension = f".{upload_file.extension}" if upload_file.extension else ""
|
||||
elif message_file.upload_file_id:
|
||||
# Fallback: generate URL even if upload_file not found
|
||||
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
|
||||
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
|
||||
# For tool files, use URL directly if it's HTTP, otherwise sign it
|
||||
if message_file.url.startswith("http"):
|
||||
url = message_file.url
|
||||
filename = message_file.url.split("/")[-1].split("?")[0]
|
||||
else:
|
||||
# Extract tool file id and extension from URL
|
||||
url_parts = message_file.url.split("/")
|
||||
if url_parts:
|
||||
file_part = url_parts[-1].split("?")[0] # Remove query params first
|
||||
# Use rsplit to correctly handle filenames with multiple dots
|
||||
if "." in file_part:
|
||||
tool_file_id, ext = file_part.rsplit(".", 1)
|
||||
extension = f".{ext}"
|
||||
else:
|
||||
tool_file_id = file_part
|
||||
extension = ".bin"
|
||||
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
|
||||
filename = file_part
|
||||
|
||||
transfer_method_value = message_file.transfer_method
|
||||
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
|
||||
file_dict = {
|
||||
"related_id": message_file.id,
|
||||
"extension": extension,
|
||||
"filename": filename,
|
||||
"size": size,
|
||||
"mime_type": mime_type,
|
||||
"transfer_method": transfer_method_value,
|
||||
"type": message_file.type,
|
||||
"url": url or "",
|
||||
"upload_file_id": message_file.upload_file_id or message_file.id,
|
||||
"remote_url": remote_url,
|
||||
}
|
||||
files_list.append(file_dict)
|
||||
return files_list or None
|
||||
|
||||
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
|
||||
"""
|
||||
Agent message to stream response.
|
||||
|
||||
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
76
api/core/app/task_pipeline/message_file_utils.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from core.tools.signature import sign_tool_file
|
||||
from dify_graph.file import helpers as file_helpers
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
|
||||
|
||||
|
||||
def _basename_from_url(url: str) -> str:
    """Return the last path segment of ``url`` without query string or fragment."""
    # Cut the query/fragment first: either may itself contain "/" and would
    # otherwise be mistaken for the final path segment.
    path = url.split("?", 1)[0].split("#", 1)[0]
    return path.rsplit("/", 1)[-1]


def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
    """
    Prepare file dictionary for message end stream response.

    The URL, filename, size, MIME type and extension are derived according to
    the file's transfer method:

    - REMOTE_URL: the stored URL is used as-is; filename and extension come
      from its last path segment.
    - LOCAL_FILE: metadata comes from the matching ``UploadFile`` and the URL
      is signed; when the upload record is missing, only a signed URL is built.
    - TOOL_FILE: absolute http(s) URLs are passed through; anything else is
      treated as ``<tool_file_id>.<ext>`` and signed.

    :param message_file: MessageFile instance
    :param upload_files_map: Dictionary mapping upload_file_id to UploadFile
    :return: Dictionary containing file information
    """
    transfer_method = message_file.transfer_method

    # Defaults used whenever a branch below cannot determine a better value.
    url = None
    filename = "file"
    mime_type = "application/octet-stream"
    size = 0
    extension = ""

    if transfer_method == FileTransferMethod.REMOTE_URL:
        url = message_file.url
        if message_file.url:
            # Keep the "file" default when the URL has no final segment.
            filename = _basename_from_url(message_file.url) or filename
            if "." in filename:
                extension = "." + filename.rsplit(".", 1)[1]
    elif transfer_method == FileTransferMethod.LOCAL_FILE:
        upload_file = None
        if message_file.upload_file_id:
            upload_file = upload_files_map.get(message_file.upload_file_id)

        if upload_file:
            url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
            filename = upload_file.name
            mime_type = upload_file.mime_type or "application/octet-stream"
            size = upload_file.size or 0
            extension = f".{upload_file.extension}" if upload_file.extension else ""
        elif message_file.upload_file_id:
            # Upload record not in the map: still hand out a signed URL.
            url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
    elif transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
        file_part = _basename_from_url(message_file.url)
        if message_file.url.startswith(("http://", "https://")):
            url = message_file.url
            if "." in file_part:
                extension = "." + file_part.rsplit(".", 1)[1]
        else:
            # rsplit keeps ids that themselves contain dots intact.
            if "." in file_part:
                tool_file_id, ext = file_part.rsplit(".", 1)
                extension = f".{ext}"
                # Implausibly long "extensions" fall back to a generic one.
                if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
                    extension = ".bin"
            else:
                tool_file_id = file_part
                extension = ".bin"
            url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
        filename = file_part or filename

    remote_url = message_file.url if transfer_method == FileTransferMethod.REMOTE_URL else ""
    return {
        "related_id": message_file.id,
        "extension": extension,
        "filename": filename,
        "size": size,
        "mime_type": mime_type,
        "transfer_method": transfer_method.value,
        "type": message_file.type,
        "url": url or "",
        "upload_file_id": message_file.upload_file_id or message_file.id,
        "remote_url": remote_url,
    }
|
||||
@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
|
||||
convert_to_agent_apps,
|
||||
create_tenant,
|
||||
delete_archived_workflow_runs,
|
||||
export_app_messages,
|
||||
extract_plugins,
|
||||
extract_unique_plugins,
|
||||
file_usage,
|
||||
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
|
||||
restore_workflow_runs,
|
||||
clean_workflow_runs,
|
||||
clean_expired_messages,
|
||||
export_app_messages,
|
||||
]
|
||||
for cmd in cmds_to_register:
|
||||
app.cli.add_command(cmd)
|
||||
|
||||
@@ -6,6 +6,13 @@ from typing import Any
|
||||
import httpx
|
||||
|
||||
from core.helper.trace_id_helper import generate_traceparent_header
|
||||
from services.errors.enterprise import (
|
||||
EnterpriseAPIBadRequestError,
|
||||
EnterpriseAPIError,
|
||||
EnterpriseAPIForbiddenError,
|
||||
EnterpriseAPINotFoundError,
|
||||
EnterpriseAPIUnauthorizedError,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -41,7 +48,6 @@ class BaseRequest:
|
||||
params: Mapping[str, Any] | None = None,
|
||||
*,
|
||||
timeout: float | httpx.Timeout | None = None,
|
||||
raise_for_status: bool = False,
|
||||
) -> Any:
|
||||
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
|
||||
url = f"{cls.base_url}{endpoint}"
|
||||
@@ -64,10 +70,51 @@ class BaseRequest:
|
||||
request_kwargs["timeout"] = timeout
|
||||
|
||||
response = client.request(method, url, **request_kwargs)
|
||||
if raise_for_status:
|
||||
response.raise_for_status()
|
||||
|
||||
# Validate HTTP status and raise domain-specific errors
|
||||
if not response.is_success:
|
||||
cls._handle_error_response(response)
|
||||
return response.json()
|
||||
|
||||
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
    """
    Handle non-2xx HTTP responses by raising appropriate domain errors.

    The message is taken from the first truthy ``message`` / ``error`` /
    ``detail`` field of a JSON object body; if the body is not a JSON object
    or has none of these fields, a generic "status code + reason" message is
    used. Non-string field values are converted with ``str()`` so the raised
    error always carries a string description.

    :param response: the unsuccessful HTTP response
    :raises EnterpriseAPIBadRequestError: for status 400
    :raises EnterpriseAPIUnauthorizedError: for status 401
    :raises EnterpriseAPIForbiddenError: for status 403
    :raises EnterpriseAPINotFoundError: for status 404
    :raises EnterpriseAPIError: for every other status code
    """
    error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"

    try:
        error_data = response.json()
    except Exception:
        # Body is not JSON: keep the default message.
        logger.debug(
            "Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
        )
    else:
        if isinstance(error_data, dict):
            # Common error response formats:
            # {"error": "...", "message": "..."}
            # {"message": "..."}
            # {"detail": "..."}
            for key in ("message", "error", "detail"):
                candidate = error_data.get(key)
                if candidate:
                    # Nested objects/lists are stringified rather than passed
                    # through as a non-string description.
                    error_message = candidate if isinstance(candidate, str) else str(candidate)
                    break

    status_errors = {
        400: EnterpriseAPIBadRequestError,
        401: EnterpriseAPIUnauthorizedError,
        403: EnterpriseAPIForbiddenError,
        404: EnterpriseAPINotFoundError,
    }
    error_cls = status_errors.get(response.status_code)
    if error_cls is not None:
        raise error_cls(error_message)
    raise EnterpriseAPIError(error_message, status_code=response.status_code)
|
||||
|
||||
|
||||
class EnterpriseRequest(BaseRequest):
|
||||
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")
|
||||
|
||||
@@ -1,15 +1,26 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
from services.enterprise.base import EnterpriseRequest
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
|
||||
# License status cache configuration
|
||||
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
|
||||
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
|
||||
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
|
||||
|
||||
|
||||
class WebAppSettings(BaseModel):
|
||||
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
|
||||
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
|
||||
if self.joined and not self.workspace_id:
|
||||
raise ValueError("workspace_id must be non-empty when joined is True")
|
||||
return self
|
||||
@@ -115,7 +126,6 @@ class EnterpriseService:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
|
||||
raise_for_status=True,
|
||||
)
|
||||
if not isinstance(data, dict):
|
||||
raise ValueError("Invalid response format from enterprise default workspace API")
|
||||
@@ -223,3 +233,64 @@ class EnterpriseService:
|
||||
|
||||
params = {"appId": app_id}
|
||||
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
|
||||
|
||||
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
    """Return the enterprise license status, preferring the Redis cache.

    A cached value is returned when one can be read; otherwise the status is
    fetched from the enterprise API and cached. Valid statuses
    (active/expiring) are cached for 10 minutes and invalid ones
    (inactive/expired/lost) for 30 seconds, so a repaired license is noticed
    quickly while an expired one does not trigger an enterprise API call on
    every request.

    Returns:
        LicenseStatus enum value, or None if enterprise is disabled / unreachable.
    """
    if dify_config.ENTERPRISE_ENABLED:
        cached_status = cls._read_cached_license_status()
        if cached_status is None:
            # Cache miss (or cache read failure): ask the enterprise API.
            return cls._fetch_and_cache_license_status()
        return cached_status
    return None
|
||||
|
||||
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
    """Look up the cached license status in Redis.

    Returns None when the key is absent or empty, and also when reading or
    parsing the cached value raises (the failure is logged as a warning).
    """
    from services.feature_service import LicenseStatus

    try:
        cached = redis_client.get(LICENSE_STATUS_CACHE_KEY)
        if not cached:
            return None
        # The client may hand back bytes or str; normalise before parsing.
        text = cached.decode("utf-8") if isinstance(cached, bytes) else cached
        return LicenseStatus(text)
    except Exception:
        logger.warning("Failed to read license status from cache", exc_info=True)
        return None
|
||||
|
||||
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
    """Query the enterprise API for the license status and store it in Redis.

    Returns None when the response has no "License" section or when the
    lookup raises. A failure to write the cache is only logged; the freshly
    fetched status is still returned.
    """
    from services.feature_service import LicenseStatus

    try:
        license_info = cls.get_info().get("License")
        if not license_info:
            return None

        status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))

        # Healthy licenses are cached longer than unhealthy ones.
        if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING):
            ttl = VALID_LICENSE_CACHE_TTL
        else:
            ttl = INVALID_LICENSE_CACHE_TTL

        try:
            redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
        except Exception:
            logger.warning("Failed to cache license status", exc_info=True)
        return status
    except Exception:
        logger.exception("Failed to get enterprise license status")
        return None
|
||||
|
||||
@@ -7,6 +7,7 @@ from . import (
|
||||
conversation,
|
||||
dataset,
|
||||
document,
|
||||
enterprise,
|
||||
file,
|
||||
index,
|
||||
message,
|
||||
@@ -21,6 +22,7 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"enterprise",
|
||||
"file",
|
||||
"index",
|
||||
"message",
|
||||
|
||||
45
api/services/errors/enterprise.py
Normal file
45
api/services/errors/enterprise.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Enterprise service errors."""
|
||||
|
||||
from services.errors.base import BaseServiceError
|
||||
|
||||
|
||||
class EnterpriseServiceError(BaseServiceError):
    """Root of the enterprise service exception hierarchy.

    :param description: message forwarded to ``BaseServiceError``
    :param status_code: HTTP status code associated with the failure, if any
    """

    def __init__(self, description: str | None = None, status_code: int | None = None):
        super().__init__(description)
        # Kept on the instance so handlers can branch on the upstream status.
        self.status_code: int | None = status_code
|
||||
|
||||
|
||||
class EnterpriseAPIError(EnterpriseServiceError):
    """Generic enterprise API error (non-2xx response).

    Adds nothing to ``EnterpriseServiceError``; the status code is supplied by
    the caller through the inherited ``status_code`` argument.
    """
|
||||
|
||||
|
||||
class EnterpriseAPINotFoundError(EnterpriseServiceError):
    """Raised when the enterprise API answers 404 Not Found.

    ``status_code`` is fixed to 404.
    """

    def __init__(self, description: str | None = None):
        super().__init__(description=description, status_code=404)
|
||||
|
||||
|
||||
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
    """Raised when the enterprise API answers 403 Forbidden.

    ``status_code`` is fixed to 403.
    """

    def __init__(self, description: str | None = None):
        super().__init__(description=description, status_code=403)
|
||||
|
||||
|
||||
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
    """Raised when the enterprise API answers 401 Unauthorized.

    ``status_code`` is fixed to 401.
    """

    def __init__(self, description: str | None = None):
        super().__init__(description=description, status_code=401)
|
||||
|
||||
|
||||
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
    """Raised when the enterprise API answers 400 Bad Request.

    ``status_code`` is fixed to 400.
    """

    def __init__(self, description: str | None = None):
        super().__init__(description=description, status_code=400)
|
||||
@@ -379,14 +379,19 @@ class FeatureService:
|
||||
)
|
||||
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
|
||||
|
||||
if is_authenticated and (license_info := enterprise_info.get("License")):
|
||||
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
|
||||
# so the login page can detect an expired/inactive license after force-logout.
|
||||
# All other license details (expiry date, workspace usage) remain auth-gated.
|
||||
# This behavior reflects prior internal review of information-leakage risks.
|
||||
if license_info := enterprise_info.get("License"):
|
||||
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
if is_authenticated:
|
||||
features.license.expired_at = license_info.get("expiredAt", "")
|
||||
if workspaces_info := license_info.get("workspaces"):
|
||||
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
|
||||
features.license.workspaces.limit = workspaces_info.get("limit", 0)
|
||||
features.license.workspaces.size = workspaces_info.get("used", 0)
|
||||
|
||||
if "PluginInstallationPermission" in enterprise_info:
|
||||
plugin_installation_info = enterprise_info["PluginInstallationPermission"]
|
||||
|
||||
@@ -63,7 +63,12 @@ class RagPipelineTransformService:
|
||||
):
|
||||
node = self._deal_file_extensions(node)
|
||||
if node.get("data", {}).get("type") == "knowledge-index":
|
||||
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
|
||||
if dataset.tenant_id != current_user.current_tenant_id:
|
||||
raise ValueError("Unauthorized")
|
||||
node = self._deal_knowledge_index(
|
||||
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
|
||||
)
|
||||
new_nodes.append(node)
|
||||
if new_nodes:
|
||||
graph["nodes"] = new_nodes
|
||||
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
|
||||
|
||||
def _deal_knowledge_index(
|
||||
self,
|
||||
knowledge_configuration: KnowledgeConfiguration,
|
||||
dataset: Dataset,
|
||||
doc_form: str,
|
||||
indexing_technique: str | None,
|
||||
retrieval_model: RetrievalSetting | None,
|
||||
node: dict,
|
||||
):
|
||||
knowledge_configuration_dict = node.get("data", {})
|
||||
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
|
||||
|
||||
if indexing_technique == "high_quality":
|
||||
knowledge_configuration.embedding_model = dataset.embedding_model
|
||||
|
||||
304
api/services/retention/conversation/message_export_service.py
Normal file
304
api/services/retention/conversation/message_export_service.py
Normal file
@@ -0,0 +1,304 @@
|
||||
"""
|
||||
Export app messages to JSONL.GZ format.
|
||||
|
||||
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
|
||||
retriever_resources (from message_metadata), feedback (user feedbacks array).
|
||||
|
||||
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
|
||||
Does NOT touch Message.inputs / Message.user_feedback properties.
|
||||
"""
|
||||
|
||||
import datetime
|
||||
import gzip
|
||||
import json
|
||||
import logging
|
||||
import tempfile
|
||||
from collections import defaultdict
|
||||
from collections.abc import Generator, Iterable
|
||||
from pathlib import Path, PurePosixPath
|
||||
from typing import Any, BinaryIO, cast
|
||||
|
||||
import orjson
|
||||
import sqlalchemy as sa
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
from sqlalchemy import select, tuple_
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message, MessageFeedback
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
MAX_FILENAME_BASE_LENGTH = 1024
|
||||
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
|
||||
|
||||
|
||||
class AppMessageExportFeedback(BaseModel):
    """A single feedback entry embedded in an exported message record.

    The export service builds instances by validating the output of
    ``MessageFeedback.to_dict()``.
    """

    # Unknown keys raise a validation error instead of being dropped silently.
    model_config = ConfigDict(extra="forbid")

    id: str
    app_id: str
    conversation_id: str
    message_id: str
    rating: str
    content: str | None = None
    from_source: str
    from_end_user_id: str | None = None
    from_account_id: str | None = None
    # Timestamps are carried as strings; the exact format is whatever the
    # producer supplies (presumably ISO 8601 -- confirm against to_dict()).
    created_at: str
    updated_at: str
|
||||
|
||||
|
||||
class AppMessageExportRecord(BaseModel):
    """One exported message, serialized as a single JSON line in the output file."""

    # Unknown keys raise a validation error instead of being dropped silently.
    model_config = ConfigDict(extra="forbid")

    conversation_id: str
    message_id: str
    query: str
    answer: str
    # Raw inputs mapping stored on the message.
    inputs: dict[str, Any]
    # Value of the "retriever_resources" key of the message metadata, if any.
    retriever_resources: list[Any] = Field(default_factory=list)
    # Feedback entries attached to this message; empty when there are none.
    feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
|
||||
|
||||
|
||||
class AppMessageExportStats(BaseModel):
    """Counters accumulated while an export (or dry run) iterates over messages."""

    # Unknown keys raise a validation error instead of being dropped silently.
    model_config = ConfigDict(extra="forbid")

    # Number of batches needed for the exported messages.
    batches: int = 0
    # Number of messages yielded by the export.
    total_messages: int = 0
    # Number of messages that carry at least one feedback entry.
    messages_with_feedback: int = 0
    # Total number of feedback entries across all messages.
    total_feedbacks: int = 0
|
||||
|
||||
|
||||
class AppMessageExportService:
    """Export the messages of one app, with their user feedback, as JSONL.GZ.

    Messages are read in ``(created_at, id)`` order using keyset pagination and
    the feedback rows of each batch are loaded with a single query. The result
    is written to ``<filename>.jsonl.gz`` either below the current working
    directory or through the ``storage`` extension.
    """

    _app_id: str
    _end_before: datetime.datetime
    _start_from: datetime.datetime | None
    _filename_base: str
    _batch_size: int
    _use_cloud_storage: bool
    _dry_run: bool

    @staticmethod
    def validate_export_filename(filename: str) -> str:
        """Validate a base export filename and return it without surrounding whitespace.

        Args:
            filename: Relative, ``/``-separated base name without the
                ``.jsonl`` / ``.gz`` / ``.jsonl.gz`` suffix.

        Returns:
            The stripped filename.

        Raises:
            ValueError: If the name is empty, carries a forbidden suffix, is
                absolute, contains ``\\``, ``//``, control characters, ``.`` or
                ``..`` segments, has no final file name, or is too long.
        """
        normalized = filename.strip()
        if not normalized:
            raise ValueError("--filename must not be empty.")

        if normalized.lower().endswith(FORBIDDEN_FILENAME_SUFFIXES):
            raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")

        if normalized.startswith("/"):
            raise ValueError("--filename must be a relative path; absolute paths are not allowed.")

        if "\\" in normalized:
            raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")

        if "//" in normalized:
            raise ValueError("--filename must not contain empty path segments ('//').")

        if len(normalized) > MAX_FILENAME_BASE_LENGTH:
            raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")

        # ord(ch) < 32 already covers NUL; 127 is DEL.
        if any(ord(ch) < 32 or ord(ch) == 127 for ch in normalized):
            raise ValueError("--filename must not contain control characters or NUL.")

        # "." and "./" collapse to an empty parts tuple; a trailing "/" would
        # otherwise produce an output name such as "dir/.jsonl.gz".
        if not PurePosixPath(normalized).parts or normalized.endswith("/"):
            raise ValueError("--filename must include a file name.")

        # Inspect the raw segments: PurePosixPath silently drops "." segments,
        # so checking its parts would let "./x" or "a/./b" through.
        if any(segment in (".", "..") for segment in normalized.split("/")):
            raise ValueError("--filename must not contain '.' or '..' path segments.")

        return normalized

    @property
    def output_gz_name(self) -> str:
        """Name (or storage key) of the compressed export file."""
        return f"{self._filename_base}.jsonl.gz"

    @property
    def output_jsonl_name(self) -> str:
        """Name of the uncompressed JSONL file."""
        return f"{self._filename_base}.jsonl"

    def __init__(
        self,
        app_id: str,
        end_before: datetime.datetime,
        filename: str,
        *,
        start_from: datetime.datetime | None = None,
        batch_size: int = 1000,
        use_cloud_storage: bool = False,
        dry_run: bool = False,
    ) -> None:
        """Configure an export run.

        Args:
            app_id: App whose messages are exported.
            end_before: Exclusive upper bound on ``Message.created_at``.
            filename: Base output name, validated by ``validate_export_filename``.
            start_from: Optional inclusive lower bound on ``Message.created_at``.
            batch_size: Number of messages fetched per query; must be positive.
            use_cloud_storage: Upload through ``storage`` instead of writing locally.
            dry_run: Iterate and count only; write nothing.

        Raises:
            ValueError: If the time range is empty or inverted, ``batch_size``
                is not positive, or ``filename`` is invalid.
        """
        if start_from is not None and start_from >= end_before:
            raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")

        # LIMIT 0 would end the pagination loop immediately and report an empty
        # export; a negative LIMIT is rejected by the database.
        if batch_size < 1:
            raise ValueError(f"batch_size ({batch_size}) must be a positive integer")

        self._app_id = app_id
        self._end_before = end_before
        self._start_from = start_from
        self._filename_base = self.validate_export_filename(filename)
        self._batch_size = batch_size
        self._use_cloud_storage = use_cloud_storage
        self._dry_run = dry_run

    def run(self) -> AppMessageExportStats:
        """Run the export (or dry run) and return the collected statistics."""
        stats = AppMessageExportStats()

        logger.info(
            "export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
            self._app_id,
            self._start_from,
            self._end_before,
            self._dry_run,
            self._use_cloud_storage,
            self.output_gz_name,
        )

        if self._dry_run:
            # Drain the generator only for its side effect of updating stats.
            for _ in self._iter_records_with_stats(stats):
                pass
            self._finalize_stats(stats)
            return stats

        if self._use_cloud_storage:
            self._export_to_cloud(stats)
        else:
            self._export_to_local(stats)

        self._finalize_stats(stats)
        return stats

    def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
        """Yield every export record in ``(created_at, id)`` order."""
        for batch in self._iter_record_batches():
            yield from batch

    @staticmethod
    def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
        """Write ``records`` to ``fileobj`` as gzip-compressed JSON Lines.

        ``fileobj`` is left open; only the gzip stream is finished.
        """
        with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
            for record in records:
                gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")

    def _export_to_local(self, stats: AppMessageExportStats) -> None:
        """Write the export below the current working directory."""
        output_path = Path.cwd() / self.output_gz_name
        output_path.parent.mkdir(parents=True, exist_ok=True)
        with output_path.open("wb") as output_file:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)

    def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
        """Build the archive in a spooled temp file, then upload it via ``storage``."""
        # Stays in memory up to 64 MiB, then spills to disk.
        with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
            self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
            tmp.seek(0)
            data = tmp.read()

        storage.save(self.output_gz_name, data)
        logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)

    def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
        """Yield records while updating ``stats`` for each one."""
        for record in self.iter_records():
            self._update_stats(stats, record)
            yield record

    @staticmethod
    def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
        """Account for one record in ``stats``."""
        stats.total_messages += 1
        if record.feedback:
            stats.messages_with_feedback += 1
            stats.total_feedbacks += len(record.feedback)

    def _finalize_stats(self, stats: AppMessageExportStats) -> None:
        """Derive the batch count from the message total (ceiling division)."""
        if stats.total_messages == 0:
            stats.batches = 0
            return
        stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size

    def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
        """Yield one list of records per fetched batch until no rows remain."""
        cursor: tuple[datetime.datetime, str] | None = None
        while True:
            rows, cursor = self._fetch_batch(cursor)
            if not rows:
                break

            # One feedback query per batch instead of one per message.
            message_ids = [str(row.id) for row in rows]
            feedbacks_map = self._fetch_feedbacks(message_ids)
            yield [self._build_record(row, feedbacks_map) for row in rows]

    def _fetch_batch(
        self, cursor: tuple[datetime.datetime, str] | None
    ) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
        """Fetch the next page of message rows after ``cursor``.

        Returns:
            The rows and the cursor for the following page. When no rows are
            found the incoming cursor is returned unchanged.
        """
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(
                    Message.id,
                    Message.conversation_id,
                    Message.query,
                    Message.answer,
                    Message._inputs,  # pyright: ignore[reportPrivateUsage]
                    Message.message_metadata,
                    Message.created_at,
                )
                .where(
                    Message.app_id == self._app_id,
                    Message.created_at < self._end_before,
                )
                .order_by(Message.created_at, Message.id)
                .limit(self._batch_size)
            )

            if self._start_from is not None:
                stmt = stmt.where(Message.created_at >= self._start_from)

            if cursor is not None:
                # Row-value comparison keeps pagination stable when several
                # messages share the same created_at.
                stmt = stmt.where(
                    tuple_(Message.created_at, Message.id)
                    > tuple_(
                        sa.literal(cursor[0], type_=sa.DateTime()),
                        sa.literal(cursor[1], type_=Message.id.type),
                    )
                )

            rows = list(session.execute(stmt).all())

        if not rows:
            return [], cursor

        last = rows[-1]
        return rows, (last.created_at, last.id)

    def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
        """Load user feedback for ``message_ids``, grouped by message id.

        Only rows with ``from_source == "user"`` are returned, ordered by
        creation time within each message.
        """
        if not message_ids:
            return {}

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = (
                select(MessageFeedback)
                .where(
                    MessageFeedback.message_id.in_(message_ids),
                    MessageFeedback.from_source == "user",
                )
                .order_by(MessageFeedback.message_id, MessageFeedback.created_at)
            )
            feedbacks = list(session.scalars(stmt).all())

        result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
        for feedback in feedbacks:
            result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
        return result

    @staticmethod
    def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
        """Convert one message row and its feedback into an export record.

        Metadata that is missing, not valid JSON, or not a JSON object yields
        an empty ``retriever_resources`` list instead of failing the export.
        """
        retriever_resources: list[Any] = []
        if row.message_metadata:
            try:
                metadata = json.loads(row.message_metadata)
            except (json.JSONDecodeError, TypeError):
                metadata = None
            # Valid JSON may still be a list, string or number, which has no .get().
            if isinstance(metadata, dict):
                value = metadata.get("retriever_resources", [])
                if isinstance(value, list):
                    retriever_resources = value

        message_id = str(row.id)
        return AppMessageExportRecord(
            conversation_id=str(row.conversation_id),
            message_id=message_id,
            query=row.query,
            answer=row.answer,
            inputs=row._inputs if isinstance(row._inputs, dict) else {},
            retriever_resources=retriever_resources,
            feedback=feedbacks_map.get(message_id, []),
        )
|
||||
@@ -358,10 +358,9 @@ class TestFeatureService:
|
||||
assert result is not None
|
||||
assert isinstance(result, SystemFeatureModel)
|
||||
|
||||
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
|
||||
# Ensure only essential UI flags are returned to unauthenticated clients
|
||||
# to keep the payload lightweight and adhere to architectural boundaries.
|
||||
assert result.license.status == LicenseStatus.NONE
|
||||
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
|
||||
# Detailed license info (expiry, workspaces) remains auth-gated.
|
||||
assert result.license.status == LicenseStatus.ACTIVE
|
||||
assert result.license.expired_at == ""
|
||||
assert result.license.workspaces.enabled is False
|
||||
assert result.license.workspaces.limit == 0
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
import datetime
|
||||
import json
|
||||
import uuid
|
||||
from decimal import Decimal
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
Conversation,
|
||||
DatasetRetrieverResource,
|
||||
Message,
|
||||
MessageAgentThought,
|
||||
MessageAnnotation,
|
||||
MessageChain,
|
||||
MessageFeedback,
|
||||
MessageFile,
|
||||
)
|
||||
from models.web import SavedMessage
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
|
||||
|
||||
|
||||
class TestAppMessageExportServiceIntegration:
    """Database-backed tests for ``AppMessageExportService``."""

    @pytest.fixture(autouse=True)
    def cleanup_database(self, db_session_with_containers: Session):
        """Clear every table touched by these tests after each test has run."""
        yield
        # Tables are emptied in this fixed order, then committed once.
        models_to_clear = (
            DatasetRetrieverResource,
            AppAnnotationHitHistory,
            SavedMessage,
            MessageFile,
            MessageAgentThought,
            MessageChain,
            MessageAnnotation,
            MessageFeedback,
            Message,
            Conversation,
            App,
            TenantAccountJoin,
            Tenant,
            Account,
        )
        for model in models_to_clear:
            db_session_with_containers.query(model).delete()
        db_session_with_containers.commit()

    @staticmethod
    def _create_app_context(session: Session) -> tuple[App, Conversation]:
        """Persist an account, tenant, membership, app and conversation; return the last two."""
        owner = Account(
            email=f"test-{uuid.uuid4()}@example.com",
            name="tester",
            interface_language="en-US",
            status="active",
        )
        session.add(owner)
        session.flush()

        workspace = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
        session.add(workspace)
        session.flush()

        membership = TenantAccountJoin(
            tenant_id=workspace.id,
            account_id=owner.id,
            role=TenantAccountRole.OWNER,
            current=True,
        )
        session.add(membership)
        session.flush()

        app = App(
            tenant_id=workspace.id,
            name="export-app",
            description="integration test app",
            mode="chat",
            enable_site=True,
            enable_api=True,
            api_rpm=60,
            api_rph=3600,
            is_demo=False,
            is_public=False,
            created_by=owner.id,
            updated_by=owner.id,
        )
        session.add(app)
        session.flush()

        conversation = Conversation(
            app_id=app.id,
            app_model_config_id=str(uuid.uuid4()),
            model_provider="openai",
            model_id="gpt-4o-mini",
            mode="chat",
            name="conv",
            inputs={"seed": 1},
            status="normal",
            from_source="api",
            from_end_user_id=str(uuid.uuid4()),
        )
        session.add(conversation)
        session.commit()
        return app, conversation

    @staticmethod
    def _create_message(
        session: Session,
        app: App,
        conversation: Conversation,
        created_at: datetime.datetime,
        *,
        query: str,
        answer: str,
        inputs: dict,
        message_metadata: str | None,
    ) -> Message:
        """Add and flush one message for ``conversation`` with an explicit ``created_at``."""
        new_message = Message(
            app_id=app.id,
            conversation_id=conversation.id,
            model_provider="openai",
            model_id="gpt-4o-mini",
            inputs=inputs,
            query=query,
            answer=answer,
            message=[{"role": "assistant", "content": answer}],
            message_tokens=10,
            message_unit_price=Decimal("0.001"),
            answer_tokens=20,
            answer_unit_price=Decimal("0.002"),
            total_price=Decimal("0.003"),
            currency="USD",
            message_metadata=message_metadata,
            from_source="api",
            from_end_user_id=conversation.from_end_user_id,
            created_at=created_at,
        )
        session.add(new_message)
        session.flush()
        return new_message

    def test_iter_records_with_stats(self, db_session_with_containers: Session):
        """Records keep order, inputs and user feedback; admin feedback is excluded."""
        session = db_session_with_containers
        app, conversation = self._create_app_context(session)

        first_inputs = {
            "plain": "v1",
            "nested": {"a": 1, "b": [1, {"x": True}]},
            "list": ["x", 2, {"y": "z"}],
        }
        second_inputs = {"other": "value", "items": [1, 2, 3]}

        base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
        first_message = self._create_message(
            session,
            app,
            conversation,
            created_at=base_time,
            query="q1",
            answer="a1",
            inputs=first_inputs,
            message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
        )
        second_message = self._create_message(
            session,
            app,
            conversation,
            created_at=base_time + datetime.timedelta(minutes=1),
            query="q2",
            answer="a2",
            inputs=second_inputs,
            message_metadata=None,
        )

        # All three feedback rows point at the first message.
        feedback_target = {
            "app_id": app.id,
            "conversation_id": conversation.id,
            "message_id": first_message.id,
        }
        user_feedback_1 = MessageFeedback(
            **feedback_target,
            rating="like",
            from_source="user",
            content="first",
            from_end_user_id=conversation.from_end_user_id,
        )
        user_feedback_2 = MessageFeedback(
            **feedback_target,
            rating="dislike",
            from_source="user",
            content="second",
            from_end_user_id=conversation.from_end_user_id,
        )
        admin_feedback = MessageFeedback(
            **feedback_target,
            rating="like",
            from_source="admin",
            content="should-be-filtered",
            from_account_id=str(uuid.uuid4()),
        )
        all_feedback = [user_feedback_1, user_feedback_2, admin_feedback]
        session.add_all(all_feedback)
        # Timestamps are assigned after add_all: base_time + 2, + 3 and + 4 minutes.
        for minutes, feedback in enumerate(all_feedback, start=2):
            feedback.created_at = base_time + datetime.timedelta(minutes=minutes)
        session.commit()

        service = AppMessageExportService(
            app_id=app.id,
            start_from=base_time - datetime.timedelta(minutes=1),
            end_before=base_time + datetime.timedelta(minutes=10),
            filename="unused",
            batch_size=1,
            dry_run=True,
        )
        stats = AppMessageExportStats()
        records = list(service._iter_records_with_stats(stats))
        service._finalize_stats(stats)

        assert len(records) == 2
        first_record, second_record = records
        assert first_record.message_id == first_message.id
        assert second_record.message_id == second_message.id

        assert first_record.inputs == first_inputs
        assert second_record.inputs == second_inputs

        assert first_record.retriever_resources == [{"dataset_id": "ds-1"}]
        assert second_record.retriever_resources == []

        assert [item.rating for item in first_record.feedback] == ["like", "dislike"]
        assert [item.content for item in first_record.feedback] == ["first", "second"]
        assert second_record.feedback == []

        assert stats.batches == 2
        assert stats.total_messages == 2
        assert stats.messages_with_feedback == 1
        assert stats.total_feedbacks == 2
|
||||
@@ -0,0 +1,425 @@
|
||||
"""
|
||||
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
|
||||
|
||||
This test suite ensures that the files array is correctly populated in the message_end
|
||||
SSE event, which is critical for vision/image chat responses to render correctly.
|
||||
|
||||
Test Coverage:
|
||||
- Files array populated when MessageFile records exist
|
||||
- Files array is None when no MessageFile records exist
|
||||
- Correct signed URL generation for LOCAL_FILE transfer method
|
||||
- Correct URL handling for REMOTE_URL transfer method
|
||||
- Correct URL handling for TOOL_FILE transfer method
|
||||
- Proper file metadata formatting (filename, mime_type, size, extension)
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.task_entities import MessageEndStreamResponse
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from models.model import MessageFile, UploadFile
|
||||
|
||||
|
||||
class TestMessageEndStreamResponseFiles:
|
||||
"""Test suite for files array population in message_end SSE event."""
|
||||
|
||||
@pytest.fixture
def mock_pipeline(self):
    """Build a spec'd pipeline mock carrying the attributes set below."""
    # Assemble the nested task state first, then attach it to the pipeline.
    task_state = Mock()
    task_state.metadata = Mock()
    task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
    task_state.llm_result = Mock()
    task_state.llm_result.usage = Mock()

    generate_entity = Mock()
    generate_entity.task_id = str(uuid.uuid4())

    pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
    pipeline._message_id = str(uuid.uuid4())
    pipeline._task_state = task_state
    pipeline._application_generate_entity = generate_entity
    return pipeline
|
||||
|
||||
@pytest.fixture
def mock_message_file_local(self):
    """Build a MessageFile mock that uses the LOCAL_FILE transfer method."""
    attributes = {
        "id": str(uuid.uuid4()),
        "message_id": str(uuid.uuid4()),
        "transfer_method": FileTransferMethod.LOCAL_FILE,
        "upload_file_id": str(uuid.uuid4()),
        "url": None,
        "type": "image",
    }
    local_file = Mock(spec=MessageFile)
    for attr_name, attr_value in attributes.items():
        setattr(local_file, attr_name, attr_value)
    return local_file
|
||||
|
||||
@pytest.fixture
def mock_message_file_remote(self):
    """Build a MessageFile mock that uses the REMOTE_URL transfer method."""
    attributes = {
        "id": str(uuid.uuid4()),
        "message_id": str(uuid.uuid4()),
        "transfer_method": FileTransferMethod.REMOTE_URL,
        "upload_file_id": None,
        "url": "https://example.com/image.jpg",
        "type": "image",
    }
    remote_file = Mock(spec=MessageFile)
    for attr_name, attr_value in attributes.items():
        setattr(remote_file, attr_name, attr_value)
    return remote_file
|
||||
|
||||
@pytest.fixture
def mock_message_file_tool(self):
    """Build a MessageFile mock that uses the TOOL_FILE transfer method."""
    attributes = {
        "id": str(uuid.uuid4()),
        "message_id": str(uuid.uuid4()),
        "transfer_method": FileTransferMethod.TOOL_FILE,
        "upload_file_id": None,
        "url": "tool_file_123.png",
        "type": "image",
    }
    tool_file = Mock(spec=MessageFile)
    for attr_name, attr_value in attributes.items():
        setattr(tool_file, attr_name, attr_value)
    return tool_file
|
||||
|
||||
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
    """Build an UploadFile mock whose id matches the local message file's upload_file_id."""
    stored_file = Mock(spec=UploadFile)
    # Share the id so the upload record can be matched to the message file.
    stored_file.id = mock_message_file_local.upload_file_id
    stored_file.name = "test_image.png"
    stored_file.mime_type = "image/png"
    stored_file.size = 1024
    stored_file.extension = "png"
    return stored_file
|
||||
|
||||
def test_message_end_with_no_files(self, mock_pipeline):
    """With no MessageFile rows, the response's ``files`` is None."""
    pipeline_module = "core.app.task_pipeline.easy_ui_based_generate_task_pipeline"
    with (
        patch(f"{pipeline_module}.db") as mock_db,
        patch(f"{pipeline_module}.Session") as mock_session_class,
    ):
        # Arrange: the session's scalars().all() yields no rows.
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session
        session.scalars.return_value.all.return_value = []

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is None
        assert response.id == mock_pipeline._message_id
        assert response.metadata == {"test": "metadata"}
|
||||
|
||||
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
    """A LOCAL_FILE message file is reported with upload metadata and a signed URL."""
    # Arrange
    mock_message_file_local.message_id = mock_pipeline._message_id
    pipeline_module = "core.app.task_pipeline.easy_ui_based_generate_task_pipeline"

    with (
        patch(f"{pipeline_module}.db") as mock_db,
        patch(f"{pipeline_module}.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Result of the MessageFile query.
        message_files_result = Mock()
        message_files_result.all.return_value = [mock_message_file_local]

        # Result of the batched UploadFile query (avoids N+1).
        upload_files_result = Mock()
        upload_files_result.all.return_value = [mock_upload_file]

        # The first scalars() call gets the message files; every later call
        # gets the upload files.
        first_call_results = iter([message_files_result])

        def scalars_side_effect(query):
            return next(first_call_results, upload_files_result)

        session.scalars.side_effect = scalars_side_effect
        mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_entry = response.files[0]
        assert file_entry["related_id"] == mock_message_file_local.id
        assert file_entry["filename"] == "test_image.png"
        assert file_entry["mime_type"] == "image/png"
        assert file_entry["size"] == 1024
        assert file_entry["extension"] == ".png"
        assert file_entry["type"] == "image"
        assert file_entry["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
        assert "https://example.com/signed-url" in file_entry["url"]
        assert file_entry["upload_file_id"] == mock_message_file_local.upload_file_id
        assert file_entry["remote_url"] == ""

        # Exactly two queries: one for MessageFile, one for UploadFile.
        assert session.scalars.call_count == 2
        mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
|
||||
|
||||
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
    """A REMOTE_URL message file is reported with its original URL."""
    # Arrange
    mock_message_file_remote.message_id = mock_pipeline._message_id
    pipeline_module = "core.app.task_pipeline.easy_ui_based_generate_task_pipeline"

    with (
        patch(f"{pipeline_module}.db") as mock_db,
        patch(f"{pipeline_module}.Session") as mock_session_class,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single remote message file.
        query_result = Mock()
        query_result.all.return_value = [mock_message_file_remote]
        session.scalars.return_value = query_result

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_entry = response.files[0]
        assert file_entry["related_id"] == mock_message_file_remote.id
        assert file_entry["filename"] == "image.jpg"
        assert file_entry["url"] == "https://example.com/image.jpg"
        assert file_entry["extension"] == ".jpg"
        assert file_entry["type"] == "image"
        assert file_entry["transfer_method"] == FileTransferMethod.REMOTE_URL.value
        assert file_entry["remote_url"] == "https://example.com/image.jpg"
        assert file_entry["upload_file_id"] == mock_message_file_remote.id

        # Only the message_files query is issued.
        session.scalars.assert_called_once()
|
||||
|
||||
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
    """A TOOL_FILE message file with an HTTP URL is reported with that URL unchanged."""
    # Arrange
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "https://example.com/tool_file.png"
    pipeline_module = "core.app.task_pipeline.easy_ui_based_generate_task_pipeline"

    with (
        patch(f"{pipeline_module}.db") as mock_db,
        patch(f"{pipeline_module}.Session") as mock_session_class,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single tool message file.
        query_result = Mock()
        query_result.all.return_value = [mock_message_file_tool]
        session.scalars.return_value = query_result

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_entry = response.files[0]
        assert file_entry["url"] == "https://example.com/tool_file.png"
        assert file_entry["filename"] == "tool_file.png"
        assert file_entry["extension"] == ".png"
        assert file_entry["transfer_method"] == FileTransferMethod.TOOL_FILE.value
|
||||
|
||||
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
    """A TOOL_FILE message file with a local path is reported with a signed tool-file URL."""
    # Arrange
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "tool_file_123.png"
    pipeline_module = "core.app.task_pipeline.easy_ui_based_generate_task_pipeline"

    with (
        patch(f"{pipeline_module}.db") as mock_db,
        patch(f"{pipeline_module}.Session") as mock_session_class,
        patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
    ):
        mock_db.engine = MagicMock()
        session = MagicMock(spec=Session)
        mock_session_class.return_value.__enter__.return_value = session

        # Every scalars() call returns the single tool message file.
        query_result = Mock()
        query_result.all.return_value = [mock_message_file_tool]
        session.scalars.return_value = query_result

        mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        file_entry = response.files[0]
        assert "https://example.com/signed-tool-file.png" in file_entry["url"]
        assert file_entry["filename"] == "tool_file_123.png"
        assert file_entry["extension"] == ".png"
        assert file_entry["transfer_method"] == FileTransferMethod.TOOL_FILE.value

        # The signer receives the file stem as id and the dotted extension.
        mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
|
||||
|
||||
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
    """Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
    # Arrange: the url carries an over-long extension.
    mock_message_file_tool.message_id = mock_pipeline._message_id
    mock_message_file_tool.url = "tool_file_abc.verylongextension"

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as db_patch,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as session_factory,
        patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as sign_tool_file,
    ):
        db_patch.engine = MagicMock()

        # session.scalars(...).all() hands back the single tool file record.
        query_result = Mock()
        query_result.all.return_value = [mock_message_file_tool]
        db_session = MagicMock(spec=Session)
        db_session.scalars.return_value = query_result
        session_factory.return_value.__enter__.return_value = db_session

        sign_tool_file.return_value = "https://example.com/signed.bin"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert: both the reported extension and the one passed to the signer are ".bin".
        assert response.files is not None
        entry = response.files[0]
        assert entry["extension"] == ".bin"
        sign_tool_file.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
|
||||
|
||||
def test_message_end_with_multiple_files(
    self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
    """Check that every MessageFile record appears in the files array when several exist."""
    # Arrange: both message files belong to the pipeline's message.
    for message_file in (mock_message_file_local, mock_message_file_remote):
        message_file.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as db_patch,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as session_factory,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as signed_url_for,
    ):
        db_patch.engine = MagicMock()

        db_session = MagicMock(spec=Session)
        session_factory.return_value.__enter__.return_value = db_session

        # Result of the first query (MessageFile rows).
        message_file_rows = Mock()
        message_file_rows.all.return_value = [mock_message_file_local, mock_message_file_remote]

        # Result of every later query (UploadFile rows, fetched in one batch to avoid N+1).
        upload_file_rows = Mock()
        upload_file_rows.all.return_value = [mock_upload_file]

        # The first scalars() call drains the iterator; all later calls get the default.
        first_call_only = iter([message_file_rows])

        def scalars_side_effect(query):
            return next(first_call_only, upload_file_rows)

        db_session.scalars.side_effect = scalars_side_effect
        signed_url_for.return_value = "https://example.com/signed-url?signature=abc123"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 2

        # Both records must be represented, matched by their related_id.
        related_ids = [entry["related_id"] for entry in response.files]
        assert mock_message_file_local.id in related_ids
        assert mock_message_file_remote.id in related_ids
|
||||
|
||||
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
    """Check the fallback path taken when no UploadFile row is found for a LOCAL_FILE."""
    # Arrange
    mock_message_file_local.message_id = mock_pipeline._message_id

    with (
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as db_patch,
        patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as session_factory,
        patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as signed_url_for,
    ):
        db_patch.engine = MagicMock()

        db_session = MagicMock(spec=Session)
        session_factory.return_value.__enter__.return_value = db_session

        # Result of the first query (MessageFile rows).
        message_file_rows = Mock()
        message_file_rows.all.return_value = [mock_message_file_local]

        # Result of every later query (UploadFile batch lookup): nothing found.
        upload_file_rows = Mock()
        upload_file_rows.all.return_value = []

        # The first scalars() call drains the iterator; all later calls get the default.
        first_call_only = iter([message_file_rows])

        def scalars_side_effect(query):
            return next(first_call_only, upload_file_rows)

        db_session.scalars.side_effect = scalars_side_effect
        signed_url_for.return_value = "https://example.com/fallback-url?signature=def456"

        # Act
        response = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)

        # Assert
        assert isinstance(response, MessageEndStreamResponse)
        assert response.files is not None
        assert len(response.files) == 1

        entry = response.files[0]
        assert "https://example.com/fallback-url" in entry["url"]
        # The fallback URL is signed using the upload_file_id stored on the message file itself.
        signed_url_for.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))
|
||||
@@ -1,9 +1,8 @@
|
||||
"""Unit tests for enterprise service integrations.
|
||||
|
||||
This module covers the enterprise-only default workspace auto-join behavior:
|
||||
- Enterprise mode disabled: no external calls
|
||||
- Successful join / skipped join: no errors
|
||||
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
|
||||
Covers:
|
||||
- Default workspace auto-join behavior
|
||||
- License status caching (get_cached_license_status)
|
||||
"""
|
||||
|
||||
from unittest.mock import patch
|
||||
@@ -11,6 +10,9 @@ from unittest.mock import patch
|
||||
import pytest
|
||||
|
||||
from services.enterprise.enterprise_service import (
|
||||
INVALID_LICENSE_CACHE_TTL,
|
||||
LICENSE_STATUS_CACHE_KEY,
|
||||
VALID_LICENSE_CACHE_TTL,
|
||||
DefaultWorkspaceJoinResult,
|
||||
EnterpriseService,
|
||||
try_join_default_workspace,
|
||||
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
|
||||
"/default-workspace/members",
|
||||
json={"account_id": account_id},
|
||||
timeout=1.0,
|
||||
raise_for_status=True,
|
||||
)
|
||||
|
||||
def test_join_default_workspace_invalid_response_format_raises(self):
|
||||
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
|
||||
|
||||
# Should not raise even though UUID parsing fails inside join_default_workspace
|
||||
try_join_default_workspace("not-a-uuid")
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# get_cached_license_status
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
# Dotted path of the module whose globals (dify_config, redis_client) are patched below.
_EE_SVC = "services.enterprise.enterprise_service"


class TestGetCachedLicenseStatus:
    """Tests for EnterpriseService.get_cached_license_status.

    Patches are applied as stacked decorators. The decorator closest to the
    function supplies the first mock argument, so the argument order is always
    (config, redis, get_info).
    """

    @patch(f"{_EE_SVC}.dify_config")
    def test_returns_none_when_enterprise_disabled(self, config):
        config.ENTERPRISE_ENABLED = False

        assert EnterpriseService.get_cached_license_status() is None

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_cache_hit_returns_license_status_enum(self, config, redis, get_info):
        from services.feature_service import LicenseStatus

        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = b"active"

        status = EnterpriseService.get_cached_license_status()

        assert status == LicenseStatus.ACTIVE
        assert isinstance(status, LicenseStatus)
        # A cache hit must not reach the enterprise API.
        get_info.assert_not_called()

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_cache_miss_fetches_api_and_caches_valid_status(self, config, redis, get_info):
        from services.feature_service import LicenseStatus

        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = None
        get_info.return_value = {"License": {"status": "active"}}

        status = EnterpriseService.get_cached_license_status()

        assert status == LicenseStatus.ACTIVE
        redis.setex.assert_called_once_with(
            LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
        )

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self, config, redis, get_info):
        from services.feature_service import LicenseStatus

        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = None
        get_info.return_value = {"License": {"status": "expired"}}

        status = EnterpriseService.get_cached_license_status()

        assert status == LicenseStatus.EXPIRED
        redis.setex.assert_called_once_with(
            LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
        )

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_redis_read_failure_falls_through_to_api(self, config, redis, get_info):
        from services.feature_service import LicenseStatus

        config.ENTERPRISE_ENABLED = True
        redis.get.side_effect = ConnectionError("redis down")
        get_info.return_value = {"License": {"status": "active"}}

        status = EnterpriseService.get_cached_license_status()

        assert status == LicenseStatus.ACTIVE
        get_info.assert_called_once()

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_redis_write_failure_still_returns_status(self, config, redis, get_info):
        from services.feature_service import LicenseStatus

        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = None
        redis.setex.side_effect = ConnectionError("redis down")
        get_info.return_value = {"License": {"status": "expiring"}}

        status = EnterpriseService.get_cached_license_status()

        assert status == LicenseStatus.EXPIRING

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_api_failure_returns_none(self, config, redis, get_info):
        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = None
        get_info.side_effect = Exception("network failure")

        assert EnterpriseService.get_cached_license_status() is None

    @patch.object(EnterpriseService, "get_info")
    @patch(f"{_EE_SVC}.redis_client")
    @patch(f"{_EE_SVC}.dify_config")
    def test_api_returns_no_license_info(self, config, redis, get_info):
        config.ENTERPRISE_ENABLED = True
        redis.get.return_value = None
        get_info.return_value = {}  # no "License" key

        assert EnterpriseService.get_cached_license_status() is None
        # Nothing should be written to the cache when the payload has no license section.
        redis.setex.assert_not_called()
|
||||
|
||||
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
43
api/tests/unit_tests/services/test_export_app_messages.py
Normal file
@@ -0,0 +1,43 @@
|
||||
import datetime
|
||||
|
||||
import pytest
|
||||
|
||||
from services.retention.conversation.message_export_service import AppMessageExportService
|
||||
|
||||
|
||||
def test_validate_export_filename_accepts_relative_path():
    """A nested relative path without a suffix is returned unchanged by the validator."""
    validated = AppMessageExportService.validate_export_filename("exports/2026/test01")

    assert validated == "exports/2026/test01"
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
    "filename",
    [
        # Names that already carry an output suffix.
        "test01.jsonl.gz",
        "test01.jsonl",
        "test01.gz",
        # Absolute path.
        "/tmp/test01",
        # Parent-directory traversal.
        "exports/../test01",
        # Control characters (NUL, tab).
        "bad\x00name",
        "bad\tname",
        # 1025 characters long.
        "a" * 1025,
    ],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
    """Each listed filename must make the validator raise ValueError."""
    with pytest.raises(ValueError):
        AppMessageExportService.validate_export_filename(filename)
|
||||
|
||||
|
||||
def test_service_derives_output_names_from_filename_base():
    """The service keeps the given filename as its base and appends the output suffixes to it."""
    cutoff = datetime.datetime(2026, 3, 1)

    exporter = AppMessageExportService(
        app_id="736b9b03-20f2-4697-91da-8d00f6325900",
        start_from=None,
        end_before=cutoff,
        filename="exports/2026/test01",
        batch_size=1000,
        use_cloud_storage=True,
        dry_run=True,
    )

    assert exporter._filename_base == "exports/2026/test01"
    assert exporter.output_gz_name == "exports/2026/test01.jsonl.gz"
    assert exporter.output_jsonl_name == "exports/2026/test01.jsonl"
|
||||
@@ -1,5 +1,6 @@
|
||||
'use client'
|
||||
import type { ReactNode } from 'react'
|
||||
import type { IToastProps } from './context'
|
||||
import { noop } from 'es-toolkit/function'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useState } from 'react'
|
||||
|
||||
Reference in New Issue
Block a user