Compare commits

...

22 Commits

Author SHA1 Message Date
copilot-swe-agent[bot]
0e4f5cb38c refactor: move console_exempt_prefixes to module level in app_factory.py
Co-authored-by: GareArc <52963600+GareArc@users.noreply.github.com>
2026-03-09 07:28:50 +00:00
copilot-swe-agent[bot]
c13d1872d4 Initial plan 2026-03-09 07:27:12 +00:00
GareArc
c911de6a6c fix: exempt setup flow endpoints from license check
Add /console/api/init and /console/api/login to the license exempt
list so that fresh installs can complete setup when the enterprise
license is inactive. Without these exemptions the init password
validation and post-setup auto-login are blocked, causing the setup
page to enter an infinite reload loop.
2026-03-08 23:46:26 -07:00
Xiyuan Chen
968bf10e1c Update api/services/enterprise/enterprise_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 17:35:50 -07:00
Xiyuan Chen
3d77a5ec08 Update api/services/feature_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-08 17:07:17 -07:00
GareArc
41af72449d fix: address PR review feedback on enterprise license enforcement
- Cache invalid license statuses with 30s TTL to prevent DoS amplification
- Return LicenseStatus enum (not raw str) from get_cached_license_status
- Flatten nested try/except into _read_cached_license_status / _fetch_and_cache_license_status helpers
- Escalate log levels from debug to warning with exc_info for cache failures
- Switch before_request license check from fail-open to fail-closed
- Remove dead raise_for_status parameter from BaseRequest.send_request
- Gate license expired_at behind is_authenticated; only expose status to unauthenticated callers (CVE-2025-63387)
- Remove redundant 'not is_console_api' guard in before_request
- Add 8 unit tests for get_cached_license_status
2026-03-08 17:00:12 -07:00
Xiyuan Chen
de72bdef71 Merge branch 'main' into fix/main-enterprise-api-error-handling 2026-03-08 16:28:01 -07:00
CoralGarden52
c925d17e8f chore: add TypedDict related prompt to api/AGENTS.md (#33116)
Some checks failed
Mark stale issues and pull requests / stale (push) Has been cancelled
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-08 07:03:52 +09:00
Angel
dc2a53d834 feat: add files to message end pr32019 (#32242)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: angel.k <angel.kolev@solaredge.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-03-07 20:01:12 +08:00
hj24
05ab107e73 feat: add export app messages (#32990)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-03-07 11:27:15 +08:00
pepsi
c016793efb refactor: pass KnowledgeConfiguration directly instead of dict (#32732)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
Co-authored-by: pepsi <pepsi@pepsidexuniji.local>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-06 22:15:32 +09:00
Coding On Star
a5bcbaebb7 feat(toast): add IToastProps type import to enhance type safety (#33096)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-03-06 19:22:55 +08:00
GareArc
f97ade7053 fix: use LicenseStatus enum instead of raw strings and tighten path prefix matching
Replace raw license status strings with LicenseStatus enum values in
app_factory.py and enterprise_service.py to prevent silent mismatches.
Use trailing-slash prefixes ('/console/api/', '/api/') to avoid false
matches on unrelated paths like /api-docs.
2026-03-05 01:17:49 -08:00
GareArc
a0dcd04546 fix: remove extra exempts 2026-03-05 01:10:23 -08:00
autofix-ci[bot]
b0138316f0 [autofix.ci] apply automated fixes 2026-03-05 09:02:35 +00:00
GareArc
099568f3da fix: expose license status to unauthenticated /system-features callers
After force-logout due to license expiry, the login page calls
/system-features without auth. The license block was gated behind
is_authenticated, so the frontend always saw status='none' instead
of the actual expiry status. Split the guard so license.status and
expired_at are always returned while workspace usage details remain
auth-gated.
2026-03-05 00:48:50 -08:00
GareArc
0623522d04 fix: exempt console bootstrap APIs from license check to prevent infinite reload loop 2026-03-04 22:13:52 -08:00
GareArc
a25d48c5bd feat: add Redis caching for enterprise license status
Cache license status for 10 minutes to reduce HTTP calls to enterprise API.
Only caches license status, not full system features.

Changes:
- Add EnterpriseService.get_cached_license_status() method
- Cache key: enterprise:license:status
- TTL: 600 seconds (10 minutes)
- Graceful degradation: falls back to API call if Redis fails

Performance improvement:
- Before: HTTP call (~50-200ms) on every API request
- After: Redis lookup (~1ms) on cached requests
- Reduces load on enterprise service by ~99%
2026-03-04 21:29:11 -08:00
GareArc
4f3a020670 feat: extend license enforcement to webapp API endpoints
Extend license middleware to also block webapp API (/api/*) when
enterprise license is expired/inactive/lost.

Changes:
- Check both /console/api and /api endpoints
- Add webapp-specific exempt paths:
  - /api/passport (webapp authentication)
  - /api/login, /api/logout, /api/oauth
  - /api/forgot-password
  - /api/system-features (webapp needs this to check license status)

This ensures both console users and webapp users are blocked when
license expires, maintaining consistent enforcement across all APIs.
2026-03-04 20:40:29 -08:00
GareArc
d2e1177478 fix: use UnauthorizedAndForceLogout to trigger frontend logout on license expiry
Change license check to raise UnauthorizedAndForceLogout exception instead
of returning generic JSON response. This ensures proper frontend handling:

Frontend behavior (service/base.ts line 588):
- Checks if code === 'unauthorized_and_force_logout'
- Executes globalThis.location.reload()
- Forces user logout and redirect to login page
- Login page displays license expiration UI (already exists)

Response format:
- HTTP 401 (not 403)
- code: "unauthorized_and_force_logout"
- Triggers frontend reload which clears auth state

This completes the license enforcement flow:
1. Backend blocks all business APIs when license expires
2. Backend returns proper error code to trigger logout
3. Frontend reloads and redirects to login
4. Login page shows license expiration message
2026-03-04 20:40:29 -08:00
GareArc
8a21fd88fd feat: add global license check middleware to block API access on expiry
Add before_request middleware that validates enterprise license status
for all /console/api endpoints when ENTERPRISE_ENABLED is true.

Behavior:
- Checks license status before each console API request
- Returns 403 with clear error message when license is expired/inactive/lost
- Exempts auth endpoints (login, oauth, forgot-password, etc.)
- Exempts /console/api/features so frontend can fetch license status
- Gracefully handles errors to avoid service disruption

This ensures all business APIs are blocked when license expires,
addressing the issue where APIs remained callable after expiry.
2026-03-04 20:40:29 -08:00
GareArc
1c1bcc67da fix: handle enterprise API errors properly to prevent KeyError crashes
When enterprise API returns 403/404, the response contains error JSON
instead of expected data structure. Code was accessing fields directly
causing KeyError → 500 Internal Server Error.

Changes:
- Add enterprise-specific error classes (EnterpriseAPIError, etc.)
- Implement centralized error validation in EnterpriseRequest.send_request()
- Extract error messages from API responses (message/error/detail fields)
- Raise domain-specific errors based on HTTP status codes
- Preserve backward compatibility with raise_for_status parameter

This prevents KeyError crashes and returns proper HTTP error codes
(403/404) instead of 500 errors.
2026-03-04 19:55:03 -08:00
19 changed files with 1586 additions and 104 deletions

View File

@@ -62,6 +62,22 @@ This is the default standard for backend code in this repo. Follow it for new co
- Code should usually include type annotations that match the repos current Python version (avoid untyped public APIs and “mystery” values).
- Prefer modern typing forms (e.g. `list[str]`, `dict[str, int]`) and avoid `Any` unless theres a strong reason.
- For dictionary-like data with known keys and value types, prefer `TypedDict` over `dict[...]` or `Mapping[...]`.
- For optional keys in typed payloads, use `NotRequired[...]` (or `total=False` when most fields are optional).
- Keep `dict[...]` / `Mapping[...]` for truly dynamic key spaces where the key set is unknown.
```python
from datetime import datetime
from typing import NotRequired, TypedDict
class UserProfile(TypedDict):
user_id: str
email: str
created_at: datetime
nickname: NotRequired[str]
```
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
```python

View File

@@ -1,16 +1,38 @@
import logging
import time
from flask import request
from opentelemetry.trace import get_current_span
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
from configs import dify_config
from contexts.wrapper import RecyclableContextVar
from controllers.console.error import UnauthorizedAndForceLogout
from core.logging.context import init_request_context
from dify_app import DifyApp
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
# Console bootstrap APIs exempt from license check:
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
# - setup: install/setup status check (AppInitializer)
# - init: init password validation for fresh install (InitPasswordPopup)
# - login: auto-login after setup completion (InstallForm)
# - version: version check (AppContextProvider)
# - activate/check: invitation link validation (signin page)
# Without these exemptions, the signin page triggers location.reload()
# on unauthorized_and_force_logout, causing an infinite loop.
_CONSOLE_EXEMPT_PREFIXES = (
"/console/api/system-features",
"/console/api/setup",
"/console/api/init",
"/console/api/login",
"/console/api/version",
"/console/api/activate/check",
)
# ----------------------------
# Application Factory Function
@@ -31,6 +53,39 @@ def create_flask_app_with_configs() -> DifyApp:
init_request_context()
RecyclableContextVar.increment_thread_recycles()
# Enterprise license validation for API endpoints (both console and webapp)
# When license expires, block all API access except bootstrap endpoints needed
# for the frontend to load the license expiration page without infinite reloads.
if dify_config.ENTERPRISE_ENABLED:
is_console_api = request.path.startswith("/console/api/")
is_webapp_api = request.path.startswith("/api/")
if is_console_api or is_webapp_api:
if is_console_api:
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
else: # webapp API
is_exempt = request.path.startswith("/api/system-features")
if not is_exempt:
try:
# Check license status (cached — see EnterpriseService for TTL details)
license_status = EnterpriseService.get_cached_license_status()
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
raise UnauthorizedAndForceLogout(
f"Enterprise license is {license_status}. Please contact your administrator."
)
if license_status is None:
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
except UnauthorizedAndForceLogout:
raise
except Exception:
logger.exception("Failed to check enterprise license status")
raise UnauthorizedAndForceLogout(
"Unable to verify enterprise license. Please contact your administrator."
)
# add after request hook for injecting trace headers from OpenTelemetry span context
# Only adds headers when OTEL is enabled and has valid context
@dify_app.after_request

View File

@@ -2668,3 +2668,77 @@ def clean_expired_messages(
raise
click.echo(click.style("messages cleanup completed.", fg="green"))
@click.command("export-app-messages", help="Export messages for an app to JSONL.GZ.")
@click.option("--app-id", required=True, help="Application ID to export messages for.")
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
default=None,
help="Optional lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--filename",
required=True,
help="Base filename (relative path). Do not include suffix like .jsonl.gz.",
)
@click.option("--use-cloud-storage", is_flag=True, default=False, help="Upload to cloud storage instead of local file.")
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for cursor pagination.")
@click.option("--dry-run", is_flag=True, default=False, help="Scan only, print stats without writing any file.")
def export_app_messages(
app_id: str,
start_from: datetime.datetime | None,
end_before: datetime.datetime,
filename: str,
use_cloud_storage: bool,
batch_size: int,
dry_run: bool,
):
if start_from and start_from >= end_before:
raise click.UsageError("--start-from must be before --end-before.")
from services.retention.conversation.message_export_service import AppMessageExportService
try:
validated_filename = AppMessageExportService.validate_export_filename(filename)
except ValueError as e:
raise click.BadParameter(str(e), param_hint="--filename") from e
click.echo(click.style(f"export_app_messages: starting export for app {app_id}.", fg="green"))
start_at = time.perf_counter()
try:
service = AppMessageExportService(
app_id=app_id,
end_before=end_before,
filename=validated_filename,
start_from=start_from,
batch_size=batch_size,
use_cloud_storage=use_cloud_storage,
dry_run=dry_run,
)
stats = service.run()
elapsed = time.perf_counter() - start_at
click.echo(
click.style(
f"export_app_messages: completed in {elapsed:.2f}s\n"
f" - Batches: {stats.batches}\n"
f" - Total messages: {stats.total_messages}\n"
f" - Messages with feedback: {stats.messages_with_feedback}\n"
f" - Total feedbacks: {stats.total_feedbacks}",
fg="green",
)
)
except Exception as e:
elapsed = time.perf_counter() - start_at
logger.exception("export_app_messages failed")
click.echo(click.style(f"export_app_messages: failed after {elapsed:.2f}s - {e}", fg="red"))
raise

View File

@@ -44,14 +44,13 @@ from core.app.entities.task_entities import (
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.app.task_pipeline.message_file_utils import prepare_file_dict
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_manager import ModelInstance
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from dify_graph.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from dify_graph.model_runtime.entities.message_entities import (
@@ -460,91 +459,40 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
"""
self._task_state.metadata.usage = self._task_state.llm_result.usage
metadata_dict = self._task_state.metadata.model_dump()
# Fetch files associated with this message
files = None
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if message_files:
# Fetch all required UploadFile objects in a single query to avoid N+1 problem
upload_file_ids = list(
dict.fromkeys(
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
)
)
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
files_list = []
for message_file in message_files:
file_dict = prepare_file_dict(message_file, upload_files_map)
files_list.append(file_dict)
files = files_list or None
return MessageEndStreamResponse(
task_id=self._application_generate_entity.task_id,
id=self._message_id,
metadata=metadata_dict,
files=files,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -0,0 +1,76 @@
from core.tools.signature import sign_tool_file
from dify_graph.file import helpers as file_helpers
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
MAX_TOOL_FILE_EXTENSION_LENGTH = 10
def prepare_file_dict(message_file: MessageFile, upload_files_map: dict[str, UploadFile]) -> dict:
"""
Prepare file dictionary for message end stream response.
:param message_file: MessageFile instance
:param upload_files_map: Dictionary mapping upload_file_id to UploadFile
:return: Dictionary containing file information
"""
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
if message_file.url.startswith(("http://", "https://")):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
if "." in filename:
extension = "." + filename.rsplit(".", 1)[1]
else:
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0]
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
if len(extension) > MAX_TOOL_FILE_EXTENSION_LENGTH:
extension = ".bin"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method.value
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
return {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}

View File

@@ -13,6 +13,7 @@ def init_app(app: DifyApp):
convert_to_agent_apps,
create_tenant,
delete_archived_workflow_runs,
export_app_messages,
extract_plugins,
extract_unique_plugins,
file_usage,
@@ -66,6 +67,7 @@ def init_app(app: DifyApp):
restore_workflow_runs,
clean_workflow_runs,
clean_expired_messages,
export_app_messages,
]
for cmd in cmds_to_register:
app.cli.add_command(cmd)

View File

@@ -6,6 +6,13 @@ from typing import Any
import httpx
from core.helper.trace_id_helper import generate_traceparent_header
from services.errors.enterprise import (
EnterpriseAPIBadRequestError,
EnterpriseAPIError,
EnterpriseAPIForbiddenError,
EnterpriseAPINotFoundError,
EnterpriseAPIUnauthorizedError,
)
logger = logging.getLogger(__name__)
@@ -41,7 +48,6 @@ class BaseRequest:
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@@ -64,10 +70,51 @@ class BaseRequest:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
# Validate HTTP status and raise domain-specific errors
if not response.is_success:
cls._handle_error_response(response)
return response.json()
@classmethod
def _handle_error_response(cls, response: httpx.Response) -> None:
"""
Handle non-2xx HTTP responses by raising appropriate domain errors.
Attempts to extract error message from JSON response body,
falls back to status text if parsing fails.
"""
error_message = f"Enterprise API request failed: {response.status_code} {response.reason_phrase}"
# Try to extract error message from JSON response
try:
error_data = response.json()
if isinstance(error_data, dict):
# Common error response formats:
# {"error": "...", "message": "..."}
# {"message": "..."}
# {"detail": "..."}
error_message = (
error_data.get("message") or error_data.get("error") or error_data.get("detail") or error_message
)
except Exception:
# If JSON parsing fails, use the default message
logger.debug(
"Failed to parse error response from enterprise API (status=%s)", response.status_code, exc_info=True
)
# Raise specific error based on status code
if response.status_code == 400:
raise EnterpriseAPIBadRequestError(error_message)
elif response.status_code == 401:
raise EnterpriseAPIUnauthorizedError(error_message)
elif response.status_code == 403:
raise EnterpriseAPIForbiddenError(error_message)
elif response.status_code == 404:
raise EnterpriseAPINotFoundError(error_message)
else:
raise EnterpriseAPIError(error_message, status_code=response.status_code)
class EnterpriseRequest(BaseRequest):
base_url = os.environ.get("ENTERPRISE_API_URL", "ENTERPRISE_API_URL")

View File

@@ -1,15 +1,26 @@
from __future__ import annotations
import logging
import uuid
from datetime import datetime
from typing import TYPE_CHECKING
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from extensions.ext_redis import redis_client
from services.enterprise.base import EnterpriseRequest
if TYPE_CHECKING:
from services.feature_service import LicenseStatus
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
# License status cache configuration
LICENSE_STATUS_CACHE_KEY = "enterprise:license:status"
VALID_LICENSE_CACHE_TTL = 600 # 10 minutes — valid licenses are stable
INVALID_LICENSE_CACHE_TTL = 30 # 30 seconds — short so admin fixes are picked up quickly
class WebAppSettings(BaseModel):
@@ -52,7 +63,7 @@ class DefaultWorkspaceJoinResult(BaseModel):
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
def _check_workspace_id_when_joined(self) -> DefaultWorkspaceJoinResult:
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
@@ -115,7 +126,6 @@ class EnterpriseService:
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
@@ -223,3 +233,64 @@ class EnterpriseService:
params = {"appId": app_id}
EnterpriseRequest.send_request("DELETE", "/webapp/clean", params=params)
@classmethod
def get_cached_license_status(cls) -> LicenseStatus | None:
"""Get enterprise license status with Redis caching to reduce HTTP calls.
Caches valid statuses (active/expiring) for 10 minutes and invalid statuses
(inactive/expired/lost) for 30 seconds. The shorter TTL for invalid statuses
balances prompt license-fix detection against DoS mitigation — without
caching, every request on an expired license would hit the enterprise API.
Returns:
LicenseStatus enum value, or None if enterprise is disabled / unreachable.
"""
if not dify_config.ENTERPRISE_ENABLED:
return None
cached = cls._read_cached_license_status()
if cached is not None:
return cached
return cls._fetch_and_cache_license_status()
@classmethod
def _read_cached_license_status(cls) -> LicenseStatus | None:
"""Read license status from Redis cache, returning None on miss or failure."""
from services.feature_service import LicenseStatus
try:
raw = redis_client.get(LICENSE_STATUS_CACHE_KEY)
if raw:
value = raw.decode("utf-8") if isinstance(raw, bytes) else raw
return LicenseStatus(value)
except Exception:
logger.warning("Failed to read license status from cache", exc_info=True)
return None
@classmethod
def _fetch_and_cache_license_status(cls) -> LicenseStatus | None:
"""Fetch license status from enterprise API and cache the result."""
from services.feature_service import LicenseStatus
try:
info = cls.get_info()
license_info = info.get("License")
if not license_info:
return None
status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
ttl = (
VALID_LICENSE_CACHE_TTL
if status in (LicenseStatus.ACTIVE, LicenseStatus.EXPIRING)
else INVALID_LICENSE_CACHE_TTL
)
try:
redis_client.setex(LICENSE_STATUS_CACHE_KEY, ttl, status)
except Exception:
logger.warning("Failed to cache license status", exc_info=True)
return status
except Exception:
logger.exception("Failed to get enterprise license status")
return None

View File

@@ -7,6 +7,7 @@ from . import (
conversation,
dataset,
document,
enterprise,
file,
index,
message,
@@ -21,6 +22,7 @@ __all__ = [
"conversation",
"dataset",
"document",
"enterprise",
"file",
"index",
"message",

View File

@@ -0,0 +1,45 @@
"""Enterprise service errors."""
from services.errors.base import BaseServiceError
class EnterpriseServiceError(BaseServiceError):
"""Base exception for enterprise service errors."""
def __init__(self, description: str | None = None, status_code: int | None = None):
super().__init__(description)
self.status_code = status_code
class EnterpriseAPIError(EnterpriseServiceError):
"""Generic enterprise API error (non-2xx response)."""
pass
class EnterpriseAPINotFoundError(EnterpriseServiceError):
"""Enterprise API returned 404 Not Found."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=404)
class EnterpriseAPIForbiddenError(EnterpriseServiceError):
"""Enterprise API returned 403 Forbidden."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=403)
class EnterpriseAPIUnauthorizedError(EnterpriseServiceError):
"""Enterprise API returned 401 Unauthorized."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=401)
class EnterpriseAPIBadRequestError(EnterpriseServiceError):
"""Enterprise API returned 400 Bad Request."""
def __init__(self, description: str | None = None):
super().__init__(description, status_code=400)

View File

@@ -379,14 +379,19 @@ class FeatureService:
)
features.webapp_auth.sso_config.protocol = enterprise_info.get("SSOEnforcedForWebProtocol", "")
if is_authenticated and (license_info := enterprise_info.get("License")):
# SECURITY NOTE: Only license *status* is exposed to unauthenticated callers
# so the login page can detect an expired/inactive license after force-logout.
# All other license details (expiry date, workspace usage) remain auth-gated.
# This behavior reflects prior internal review of information-leakage risks.
if license_info := enterprise_info.get("License"):
features.license.status = LicenseStatus(license_info.get("status", LicenseStatus.INACTIVE))
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if is_authenticated:
features.license.expired_at = license_info.get("expiredAt", "")
if workspaces_info := license_info.get("workspaces"):
features.license.workspaces.enabled = workspaces_info.get("enabled", False)
features.license.workspaces.limit = workspaces_info.get("limit", 0)
features.license.workspaces.size = workspaces_info.get("used", 0)
if "PluginInstallationPermission" in enterprise_info:
plugin_installation_info = enterprise_info["PluginInstallationPermission"]

View File

@@ -63,7 +63,12 @@ class RagPipelineTransformService:
):
node = self._deal_file_extensions(node)
if node.get("data", {}).get("type") == "knowledge-index":
node = self._deal_knowledge_index(dataset, doc_form, indexing_technique, retrieval_model, node)
knowledge_configuration = KnowledgeConfiguration.model_validate(node.get("data", {}))
if dataset.tenant_id != current_user.current_tenant_id:
raise ValueError("Unauthorized")
node = self._deal_knowledge_index(
knowledge_configuration, dataset, indexing_technique, retrieval_model, node
)
new_nodes.append(node)
if new_nodes:
graph["nodes"] = new_nodes
@@ -155,14 +160,13 @@ class RagPipelineTransformService:
def _deal_knowledge_index(
self,
knowledge_configuration: KnowledgeConfiguration,
dataset: Dataset,
doc_form: str,
indexing_technique: str | None,
retrieval_model: RetrievalSetting | None,
node: dict,
):
knowledge_configuration_dict = node.get("data", {})
knowledge_configuration = KnowledgeConfiguration.model_validate(knowledge_configuration_dict)
if indexing_technique == "high_quality":
knowledge_configuration.embedding_model = dataset.embedding_model

View File

@@ -0,0 +1,304 @@
"""
Export app messages to JSONL.GZ format.
Outputs: conversation_id, message_id, query, answer, inputs (raw JSON),
retriever_resources (from message_metadata), feedback (user feedbacks array).
Uses (created_at, id) cursor pagination and batch-loads feedbacks to avoid N+1.
Does NOT touch Message.inputs / Message.user_feedback properties.
"""
import datetime
import gzip
import json
import logging
import tempfile
from collections import defaultdict
from collections.abc import Generator, Iterable
from pathlib import Path, PurePosixPath
from typing import Any, BinaryIO, cast
import orjson
import sqlalchemy as sa
from pydantic import BaseModel, ConfigDict, Field
from sqlalchemy import select, tuple_
from sqlalchemy.orm import Session
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.model import Message, MessageFeedback
logger = logging.getLogger(__name__)
MAX_FILENAME_BASE_LENGTH = 1024
FORBIDDEN_FILENAME_SUFFIXES = (".jsonl.gz", ".jsonl", ".gz")
class AppMessageExportFeedback(BaseModel):
id: str
app_id: str
conversation_id: str
message_id: str
rating: str
content: str | None = None
from_source: str
from_end_user_id: str | None = None
from_account_id: str | None = None
created_at: str
updated_at: str
model_config = ConfigDict(extra="forbid")
class AppMessageExportRecord(BaseModel):
conversation_id: str
message_id: str
query: str
answer: str
inputs: dict[str, Any]
retriever_resources: list[Any] = Field(default_factory=list)
feedback: list[AppMessageExportFeedback] = Field(default_factory=list)
model_config = ConfigDict(extra="forbid")
class AppMessageExportStats(BaseModel):
batches: int = 0
total_messages: int = 0
messages_with_feedback: int = 0
total_feedbacks: int = 0
model_config = ConfigDict(extra="forbid")
class AppMessageExportService:
@staticmethod
def validate_export_filename(filename: str) -> str:
normalized = filename.strip()
if not normalized:
raise ValueError("--filename must not be empty.")
normalized_lower = normalized.lower()
if normalized_lower.endswith(FORBIDDEN_FILENAME_SUFFIXES):
raise ValueError("--filename must not include .jsonl.gz/.jsonl/.gz suffix; pass base filename only.")
if normalized.startswith("/"):
raise ValueError("--filename must be a relative path; absolute paths are not allowed.")
if "\\" in normalized:
raise ValueError("--filename must use '/' as path separator; '\\' is not allowed.")
if "//" in normalized:
raise ValueError("--filename must not contain empty path segments ('//').")
if len(normalized) > MAX_FILENAME_BASE_LENGTH:
raise ValueError(f"--filename is too long; max length is {MAX_FILENAME_BASE_LENGTH}.")
for ch in normalized:
if ch == "\x00" or ord(ch) < 32 or ord(ch) == 127:
raise ValueError("--filename must not contain control characters or NUL.")
parts = PurePosixPath(normalized).parts
if not parts:
raise ValueError("--filename must include a file name.")
if any(part in (".", "..") for part in parts):
raise ValueError("--filename must not contain '.' or '..' path segments.")
return normalized
@property
def output_gz_name(self) -> str:
return f"{self._filename_base}.jsonl.gz"
@property
def output_jsonl_name(self) -> str:
return f"{self._filename_base}.jsonl"
def __init__(
self,
app_id: str,
end_before: datetime.datetime,
filename: str,
*,
start_from: datetime.datetime | None = None,
batch_size: int = 1000,
use_cloud_storage: bool = False,
dry_run: bool = False,
) -> None:
if start_from and start_from >= end_before:
raise ValueError(f"start_from ({start_from}) must be before end_before ({end_before})")
self._app_id = app_id
self._end_before = end_before
self._start_from = start_from
self._filename_base = self.validate_export_filename(filename)
self._batch_size = batch_size
self._use_cloud_storage = use_cloud_storage
self._dry_run = dry_run
def run(self) -> AppMessageExportStats:
stats = AppMessageExportStats()
logger.info(
"export_app_messages: app_id=%s, start_from=%s, end_before=%s, dry_run=%s, cloud=%s, output_gz=%s",
self._app_id,
self._start_from,
self._end_before,
self._dry_run,
self._use_cloud_storage,
self.output_gz_name,
)
if self._dry_run:
for _ in self._iter_records_with_stats(stats):
pass
self._finalize_stats(stats)
return stats
if self._use_cloud_storage:
self._export_to_cloud(stats)
else:
self._export_to_local(stats)
self._finalize_stats(stats)
return stats
def iter_records(self) -> Generator[AppMessageExportRecord, None, None]:
for batch in self._iter_record_batches():
yield from batch
@staticmethod
def write_jsonl_gz(records: Iterable[AppMessageExportRecord], fileobj: BinaryIO) -> None:
with gzip.GzipFile(fileobj=fileobj, mode="wb") as gz:
for record in records:
gz.write(orjson.dumps(record.model_dump(mode="json")) + b"\n")
def _export_to_local(self, stats: AppMessageExportStats) -> None:
output_path = Path.cwd() / self.output_gz_name
output_path.parent.mkdir(parents=True, exist_ok=True)
with output_path.open("wb") as output_file:
self.write_jsonl_gz(self._iter_records_with_stats(stats), output_file)
def _export_to_cloud(self, stats: AppMessageExportStats) -> None:
with tempfile.SpooledTemporaryFile(max_size=64 * 1024 * 1024) as tmp:
self.write_jsonl_gz(self._iter_records_with_stats(stats), cast(BinaryIO, tmp))
tmp.seek(0)
data = tmp.read()
storage.save(self.output_gz_name, data)
logger.info("export_app_messages: uploaded %d bytes to cloud key=%s", len(data), self.output_gz_name)
def _iter_records_with_stats(self, stats: AppMessageExportStats) -> Generator[AppMessageExportRecord, None, None]:
for record in self.iter_records():
self._update_stats(stats, record)
yield record
@staticmethod
def _update_stats(stats: AppMessageExportStats, record: AppMessageExportRecord) -> None:
stats.total_messages += 1
if record.feedback:
stats.messages_with_feedback += 1
stats.total_feedbacks += len(record.feedback)
def _finalize_stats(self, stats: AppMessageExportStats) -> None:
if stats.total_messages == 0:
stats.batches = 0
return
stats.batches = (stats.total_messages + self._batch_size - 1) // self._batch_size
def _iter_record_batches(self) -> Generator[list[AppMessageExportRecord], None, None]:
cursor: tuple[datetime.datetime, str] | None = None
while True:
rows, cursor = self._fetch_batch(cursor)
if not rows:
break
message_ids = [str(row.id) for row in rows]
feedbacks_map = self._fetch_feedbacks(message_ids)
yield [self._build_record(row, feedbacks_map) for row in rows]
def _fetch_batch(
self, cursor: tuple[datetime.datetime, str] | None
) -> tuple[list[Any], tuple[datetime.datetime, str] | None]:
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(
Message.id,
Message.conversation_id,
Message.query,
Message.answer,
Message._inputs, # pyright: ignore[reportPrivateUsage]
Message.message_metadata,
Message.created_at,
)
.where(
Message.app_id == self._app_id,
Message.created_at < self._end_before,
)
.order_by(Message.created_at, Message.id)
.limit(self._batch_size)
)
if self._start_from:
stmt = stmt.where(Message.created_at >= self._start_from)
if cursor:
stmt = stmt.where(
tuple_(Message.created_at, Message.id)
> tuple_(
sa.literal(cursor[0], type_=sa.DateTime()),
sa.literal(cursor[1], type_=Message.id.type),
)
)
rows = list(session.execute(stmt).all())
if not rows:
return [], cursor
last = rows[-1]
return rows, (last.created_at, last.id)
def _fetch_feedbacks(self, message_ids: list[str]) -> dict[str, list[AppMessageExportFeedback]]:
if not message_ids:
return {}
with Session(db.engine, expire_on_commit=False) as session:
stmt = (
select(MessageFeedback)
.where(
MessageFeedback.message_id.in_(message_ids),
MessageFeedback.from_source == "user",
)
.order_by(MessageFeedback.message_id, MessageFeedback.created_at)
)
feedbacks = list(session.scalars(stmt).all())
result: dict[str, list[AppMessageExportFeedback]] = defaultdict(list)
for feedback in feedbacks:
result[str(feedback.message_id)].append(AppMessageExportFeedback.model_validate(feedback.to_dict()))
return result
@staticmethod
def _build_record(row: Any, feedbacks_map: dict[str, list[AppMessageExportFeedback]]) -> AppMessageExportRecord:
retriever_resources: list[Any] = []
if row.message_metadata:
try:
metadata = json.loads(row.message_metadata)
value = metadata.get("retriever_resources", [])
if isinstance(value, list):
retriever_resources = value
except (json.JSONDecodeError, TypeError):
pass
message_id = str(row.id)
return AppMessageExportRecord(
conversation_id=str(row.conversation_id),
message_id=message_id,
query=row.query,
answer=row.answer,
inputs=row._inputs if isinstance(row._inputs, dict) else {},
retriever_resources=retriever_resources,
feedback=feedbacks_map.get(message_id, []),
)

View File

@@ -358,10 +358,9 @@ class TestFeatureService:
assert result is not None
assert isinstance(result, SystemFeatureModel)
# --- 1. Verify Response Payload Optimization (Data Minimization) ---
# Ensure only essential UI flags are returned to unauthenticated clients
# to keep the payload lightweight and adhere to architectural boundaries.
assert result.license.status == LicenseStatus.NONE
# --- 1. Verify only license *status* is exposed to unauthenticated clients ---
# Detailed license info (expiry, workspaces) remains auth-gated.
assert result.license.status == LicenseStatus.ACTIVE
assert result.license.expired_at == ""
assert result.license.workspaces.enabled is False
assert result.license.workspaces.limit == 0

View File

@@ -0,0 +1,233 @@
import datetime
import json
import uuid
from decimal import Decimal
import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.model import (
App,
AppAnnotationHitHistory,
Conversation,
DatasetRetrieverResource,
Message,
MessageAgentThought,
MessageAnnotation,
MessageChain,
MessageFeedback,
MessageFile,
)
from models.web import SavedMessage
from services.retention.conversation.message_export_service import AppMessageExportService, AppMessageExportStats
class TestAppMessageExportServiceIntegration:
@pytest.fixture(autouse=True)
def cleanup_database(self, db_session_with_containers: Session):
yield
db_session_with_containers.query(DatasetRetrieverResource).delete()
db_session_with_containers.query(AppAnnotationHitHistory).delete()
db_session_with_containers.query(SavedMessage).delete()
db_session_with_containers.query(MessageFile).delete()
db_session_with_containers.query(MessageAgentThought).delete()
db_session_with_containers.query(MessageChain).delete()
db_session_with_containers.query(MessageAnnotation).delete()
db_session_with_containers.query(MessageFeedback).delete()
db_session_with_containers.query(Message).delete()
db_session_with_containers.query(Conversation).delete()
db_session_with_containers.query(App).delete()
db_session_with_containers.query(TenantAccountJoin).delete()
db_session_with_containers.query(Tenant).delete()
db_session_with_containers.query(Account).delete()
db_session_with_containers.commit()
@staticmethod
def _create_app_context(session: Session) -> tuple[App, Conversation]:
account = Account(
email=f"test-{uuid.uuid4()}@example.com",
name="tester",
interface_language="en-US",
status="active",
)
session.add(account)
session.flush()
tenant = Tenant(name=f"tenant-{uuid.uuid4()}", status="normal")
session.add(tenant)
session.flush()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
session.add(join)
session.flush()
app = App(
tenant_id=tenant.id,
name="export-app",
description="integration test app",
mode="chat",
enable_site=True,
enable_api=True,
api_rpm=60,
api_rph=3600,
is_demo=False,
is_public=False,
created_by=account.id,
updated_by=account.id,
)
session.add(app)
session.flush()
conversation = Conversation(
app_id=app.id,
app_model_config_id=str(uuid.uuid4()),
model_provider="openai",
model_id="gpt-4o-mini",
mode="chat",
name="conv",
inputs={"seed": 1},
status="normal",
from_source="api",
from_end_user_id=str(uuid.uuid4()),
)
session.add(conversation)
session.commit()
return app, conversation
@staticmethod
def _create_message(
session: Session,
app: App,
conversation: Conversation,
created_at: datetime.datetime,
*,
query: str,
answer: str,
inputs: dict,
message_metadata: str | None,
) -> Message:
message = Message(
app_id=app.id,
conversation_id=conversation.id,
model_provider="openai",
model_id="gpt-4o-mini",
inputs=inputs,
query=query,
answer=answer,
message=[{"role": "assistant", "content": answer}],
message_tokens=10,
message_unit_price=Decimal("0.001"),
answer_tokens=20,
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
message_metadata=message_metadata,
from_source="api",
from_end_user_id=conversation.from_end_user_id,
created_at=created_at,
)
session.add(message)
session.flush()
return message
def test_iter_records_with_stats(self, db_session_with_containers: Session):
app, conversation = self._create_app_context(db_session_with_containers)
first_inputs = {
"plain": "v1",
"nested": {"a": 1, "b": [1, {"x": True}]},
"list": ["x", 2, {"y": "z"}],
}
second_inputs = {"other": "value", "items": [1, 2, 3]}
base_time = datetime.datetime(2026, 2, 25, 10, 0, 0)
first_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time,
query="q1",
answer="a1",
inputs=first_inputs,
message_metadata=json.dumps({"retriever_resources": [{"dataset_id": "ds-1"}]}),
)
second_message = self._create_message(
db_session_with_containers,
app,
conversation,
created_at=base_time + datetime.timedelta(minutes=1),
query="q2",
answer="a2",
inputs=second_inputs,
message_metadata=None,
)
user_feedback_1 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
content="first",
from_end_user_id=conversation.from_end_user_id,
)
user_feedback_2 = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
content="second",
from_end_user_id=conversation.from_end_user_id,
)
admin_feedback = MessageFeedback(
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)
db_session_with_containers.add_all([user_feedback_1, user_feedback_2, admin_feedback])
user_feedback_1.created_at = base_time + datetime.timedelta(minutes=2)
user_feedback_2.created_at = base_time + datetime.timedelta(minutes=3)
admin_feedback.created_at = base_time + datetime.timedelta(minutes=4)
db_session_with_containers.commit()
service = AppMessageExportService(
app_id=app.id,
start_from=base_time - datetime.timedelta(minutes=1),
end_before=base_time + datetime.timedelta(minutes=10),
filename="unused",
batch_size=1,
dry_run=True,
)
stats = AppMessageExportStats()
records = list(service._iter_records_with_stats(stats))
service._finalize_stats(stats)
assert len(records) == 2
assert records[0].message_id == first_message.id
assert records[1].message_id == second_message.id
assert records[0].inputs == first_inputs
assert records[1].inputs == second_inputs
assert records[0].retriever_resources == [{"dataset_id": "ds-1"}]
assert records[1].retriever_resources == []
assert [feedback.rating for feedback in records[0].feedback] == ["like", "dislike"]
assert [feedback.content for feedback in records[0].feedback] == ["first", "second"]
assert records[1].feedback == []
assert stats.batches == 2
assert stats.total_messages == 2
assert stats.messages_with_feedback == 1
assert stats.total_feedbacks == 2

View File

@@ -0,0 +1,425 @@
"""
Unit tests for EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response method.
This test suite ensures that the files array is correctly populated in the message_end
SSE event, which is critical for vision/image chat responses to render correctly.
Test Coverage:
- Files array populated when MessageFile records exist
- Files array is None when no MessageFile records exist
- Correct signed URL generation for LOCAL_FILE transfer method
- Correct URL handling for REMOTE_URL transfer method
- Correct URL handling for TOOL_FILE transfer method
- Proper file metadata formatting (filename, mime_type, size, extension)
"""
import uuid
from unittest.mock import MagicMock, Mock, patch
import pytest
from sqlalchemy.orm import Session
from core.app.entities.task_entities import MessageEndStreamResponse
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from dify_graph.file.enums import FileTransferMethod
from models.model import MessageFile, UploadFile
class TestMessageEndStreamResponseFiles:
"""Test suite for files array population in message_end SSE event."""
@pytest.fixture
def mock_pipeline(self):
"""Create a mock EasyUIBasedGenerateTaskPipeline instance."""
pipeline = Mock(spec=EasyUIBasedGenerateTaskPipeline)
pipeline._message_id = str(uuid.uuid4())
pipeline._task_state = Mock()
pipeline._task_state.metadata = Mock()
pipeline._task_state.metadata.model_dump = Mock(return_value={"test": "metadata"})
pipeline._task_state.llm_result = Mock()
pipeline._task_state.llm_result.usage = Mock()
pipeline._application_generate_entity = Mock()
pipeline._application_generate_entity.task_id = str(uuid.uuid4())
return pipeline
@pytest.fixture
def mock_message_file_local(self):
"""Create a mock MessageFile with LOCAL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.LOCAL_FILE
message_file.upload_file_id = str(uuid.uuid4())
message_file.url = None
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_remote(self):
"""Create a mock MessageFile with REMOTE_URL transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.REMOTE_URL
message_file.upload_file_id = None
message_file.url = "https://example.com/image.jpg"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_message_file_tool(self):
"""Create a mock MessageFile with TOOL_FILE transfer method."""
message_file = Mock(spec=MessageFile)
message_file.id = str(uuid.uuid4())
message_file.message_id = str(uuid.uuid4())
message_file.transfer_method = FileTransferMethod.TOOL_FILE
message_file.upload_file_id = None
message_file.url = "tool_file_123.png"
message_file.type = "image"
return message_file
@pytest.fixture
def mock_upload_file(self, mock_message_file_local):
"""Create a mock UploadFile."""
upload_file = Mock(spec=UploadFile)
upload_file.id = mock_message_file_local.upload_file_id
upload_file.name = "test_image.png"
upload_file.mime_type = "image/png"
upload_file.size = 1024
upload_file.extension = "png"
return upload_file
def test_message_end_with_no_files(self, mock_pipeline):
"""Test that files array is None when no MessageFile records exist."""
# Arrange
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_session.scalars.return_value.all.return_value = []
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is None
assert result.id == mock_pipeline._message_id
assert result.metadata == {"test": "metadata"}
def test_message_end_with_local_file(self, mock_pipeline, mock_message_file_local, mock_upload_file):
"""Test that files array is populated correctly for LOCAL_FILE transfer method."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_local.id
assert file_dict["filename"] == "test_image.png"
assert file_dict["mime_type"] == "image/png"
assert file_dict["size"] == 1024
assert file_dict["extension"] == ".png"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.LOCAL_FILE.value
assert "https://example.com/signed-url" in file_dict["url"]
assert file_dict["upload_file_id"] == mock_message_file_local.upload_file_id
assert file_dict["remote_url"] == ""
# Verify database queries
# Should be called twice: once for MessageFile, once for UploadFile
assert mock_session.scalars.call_count == 2
mock_get_url.assert_called_once_with(upload_file_id=str(mock_upload_file.id))
def test_message_end_with_remote_url(self, mock_pipeline, mock_message_file_remote):
"""Test that files array is populated correctly for REMOTE_URL transfer method."""
# Arrange
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_remote]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["related_id"] == mock_message_file_remote.id
assert file_dict["filename"] == "image.jpg"
assert file_dict["url"] == "https://example.com/image.jpg"
assert file_dict["extension"] == ".jpg"
assert file_dict["type"] == "image"
assert file_dict["transfer_method"] == FileTransferMethod.REMOTE_URL.value
assert file_dict["remote_url"] == "https://example.com/image.jpg"
assert file_dict["upload_file_id"] == mock_message_file_remote.id
# Verify only one query for message_files is made
mock_session.scalars.assert_called_once()
def test_message_end_with_tool_file_http(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with HTTP URL."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "https://example.com/tool_file.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert file_dict["url"] == "https://example.com/tool_file.png"
assert file_dict["filename"] == "tool_file.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
def test_message_end_with_tool_file_local(self, mock_pipeline, mock_message_file_tool):
"""Test that files array is populated correctly for TOOL_FILE with local path."""
# Arrange
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_123.png"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed-tool-file.png?signature=xyz"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/signed-tool-file.png" in file_dict["url"]
assert file_dict["filename"] == "tool_file_123.png"
assert file_dict["extension"] == ".png"
assert file_dict["transfer_method"] == FileTransferMethod.TOOL_FILE.value
# Verify tool file signing was called
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_123", extension=".png")
def test_message_end_with_tool_file_long_extension(self, mock_pipeline, mock_message_file_tool):
"""Test that TOOL_FILE extensions longer than MAX_TOOL_FILE_EXTENSION_LENGTH fall back to .bin."""
mock_message_file_tool.message_id = mock_pipeline._message_id
mock_message_file_tool.url = "tool_file_abc.verylongextension"
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.sign_tool_file") as mock_sign_tool,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
mock_scalars_result = Mock()
mock_scalars_result.all.return_value = [mock_message_file_tool]
mock_session.scalars.return_value = mock_scalars_result
mock_sign_tool.return_value = "https://example.com/signed.bin"
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
assert result.files is not None
file_dict = result.files[0]
assert file_dict["extension"] == ".bin"
mock_sign_tool.assert_called_once_with(tool_file_id="tool_file_abc", extension=".bin")
def test_message_end_with_multiple_files(
self, mock_pipeline, mock_message_file_local, mock_message_file_remote, mock_upload_file
):
"""Test that files array contains all MessageFile records when multiple exist."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
mock_message_file_remote.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local, mock_message_file_remote]
# Second query: UploadFile (batch query to avoid N+1)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [mock_upload_file]
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/signed-url?signature=abc123"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 2
# Verify both files are present
file_ids = [f["related_id"] for f in result.files]
assert mock_message_file_local.id in file_ids
assert mock_message_file_remote.id in file_ids
def test_message_end_with_local_file_no_upload_file(self, mock_pipeline, mock_message_file_local):
"""Test fallback when UploadFile is not found for LOCAL_FILE."""
# Arrange
mock_message_file_local.message_id = mock_pipeline._message_id
with (
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.db") as mock_db,
patch("core.app.task_pipeline.easy_ui_based_generate_task_pipeline.Session") as mock_session_class,
patch("core.app.task_pipeline.message_file_utils.file_helpers.get_signed_file_url") as mock_get_url,
):
mock_engine = MagicMock()
mock_db.engine = mock_engine
mock_session = MagicMock(spec=Session)
mock_session_class.return_value.__enter__.return_value = mock_session
# Mock database queries
# First query: MessageFile
mock_message_files_result = Mock()
mock_message_files_result.all.return_value = [mock_message_file_local]
# Second query: UploadFile (batch query) - returns empty list (not found)
mock_upload_files_result = Mock()
mock_upload_files_result.all.return_value = [] # UploadFile not found
# Setup scalars to return different results for different queries
call_count = [0] # Use list to allow modification in nested function
def scalars_side_effect(query):
call_count[0] += 1
# First call is for MessageFile, second call is for UploadFile
if call_count[0] == 1:
return mock_message_files_result
else:
return mock_upload_files_result
mock_session.scalars.side_effect = scalars_side_effect
mock_get_url.return_value = "https://example.com/fallback-url?signature=def456"
# Act
result = EasyUIBasedGenerateTaskPipeline._message_end_to_stream_response(mock_pipeline)
# Assert
assert isinstance(result, MessageEndStreamResponse)
assert result.files is not None
assert len(result.files) == 1
file_dict = result.files[0]
assert "https://example.com/fallback-url" in file_dict["url"]
# Verify fallback URL was generated using upload_file_id from message_file
mock_get_url.assert_called_with(upload_file_id=str(mock_message_file_local.upload_file_id))

View File

@@ -1,9 +1,8 @@
"""Unit tests for enterprise service integrations.
This module covers the enterprise-only default workspace auto-join behavior:
- Enterprise mode disabled: no external calls
- Successful join / skipped join: no errors
- Failures (network/invalid response/invalid UUID): soft-fail wrapper must not raise
Covers:
- Default workspace auto-join behavior
- License status caching (get_cached_license_status)
"""
from unittest.mock import patch
@@ -11,6 +10,9 @@ from unittest.mock import patch
import pytest
from services.enterprise.enterprise_service import (
INVALID_LICENSE_CACHE_TTL,
LICENSE_STATUS_CACHE_KEY,
VALID_LICENSE_CACHE_TTL,
DefaultWorkspaceJoinResult,
EnterpriseService,
try_join_default_workspace,
@@ -37,7 +39,6 @@ class TestJoinDefaultWorkspace:
"/default-workspace/members",
json={"account_id": account_id},
timeout=1.0,
raise_for_status=True,
)
def test_join_default_workspace_invalid_response_format_raises(self):
@@ -139,3 +140,134 @@ class TestTryJoinDefaultWorkspace:
# Should not raise even though UUID parsing fails inside join_default_workspace
try_join_default_workspace("not-a-uuid")
# ---------------------------------------------------------------------------
# get_cached_license_status
# ---------------------------------------------------------------------------
_EE_SVC = "services.enterprise.enterprise_service"
class TestGetCachedLicenseStatus:
"""Tests for EnterpriseService.get_cached_license_status."""
def test_returns_none_when_enterprise_disabled(self):
with patch(f"{_EE_SVC}.dify_config") as mock_config:
mock_config.ENTERPRISE_ENABLED = False
assert EnterpriseService.get_cached_license_status() is None
def test_cache_hit_returns_license_status_enum(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = b"active"
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
assert isinstance(result, LicenseStatus)
mock_get_info.assert_not_called()
def test_cache_miss_fetches_api_and_caches_valid_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, VALID_LICENSE_CACHE_TTL, LicenseStatus.ACTIVE
)
def test_cache_miss_fetches_api_and_caches_invalid_status_with_short_ttl(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {"License": {"status": "expired"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRED
mock_redis.setex.assert_called_once_with(
LICENSE_STATUS_CACHE_KEY, INVALID_LICENSE_CACHE_TTL, LicenseStatus.EXPIRED
)
def test_redis_read_failure_falls_through_to_api(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "active"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.ACTIVE
mock_get_info.assert_called_once()
def test_redis_write_failure_still_returns_status(self):
from services.feature_service import LicenseStatus
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_redis.setex.side_effect = ConnectionError("redis down")
mock_get_info.return_value = {"License": {"status": "expiring"}}
result = EnterpriseService.get_cached_license_status()
assert result == LicenseStatus.EXPIRING
def test_api_failure_returns_none(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.side_effect = Exception("network failure")
assert EnterpriseService.get_cached_license_status() is None
def test_api_returns_no_license_info(self):
with (
patch(f"{_EE_SVC}.dify_config") as mock_config,
patch(f"{_EE_SVC}.redis_client") as mock_redis,
patch.object(EnterpriseService, "get_info") as mock_get_info,
):
mock_config.ENTERPRISE_ENABLED = True
mock_redis.get.return_value = None
mock_get_info.return_value = {} # no "License" key
assert EnterpriseService.get_cached_license_status() is None
mock_redis.setex.assert_not_called()

View File

@@ -0,0 +1,43 @@
import datetime
import pytest
from services.retention.conversation.message_export_service import AppMessageExportService
def test_validate_export_filename_accepts_relative_path():
assert AppMessageExportService.validate_export_filename("exports/2026/test01") == "exports/2026/test01"
@pytest.mark.parametrize(
"filename",
[
"test01.jsonl.gz",
"test01.jsonl",
"test01.gz",
"/tmp/test01",
"exports/../test01",
"bad\x00name",
"bad\tname",
"a" * 1025,
],
)
def test_validate_export_filename_rejects_invalid_values(filename: str):
with pytest.raises(ValueError):
AppMessageExportService.validate_export_filename(filename)
def test_service_derives_output_names_from_filename_base():
service = AppMessageExportService(
app_id="736b9b03-20f2-4697-91da-8d00f6325900",
start_from=None,
end_before=datetime.datetime(2026, 3, 1),
filename="exports/2026/test01",
batch_size=1000,
use_cloud_storage=True,
dry_run=True,
)
assert service._filename_base == "exports/2026/test01"
assert service.output_gz_name == "exports/2026/test01.jsonl.gz"
assert service.output_jsonl_name == "exports/2026/test01.jsonl"

View File

@@ -1,5 +1,6 @@
'use client'
import type { ReactNode } from 'react'
import type { IToastProps } from './context'
import { noop } from 'es-toolkit/function'
import * as React from 'react'
import { useEffect, useState } from 'react'