Compare commits

..

10 Commits

Author SHA1 Message Date
Yansong Zhang
6473c1419b Merge remote-tracking branch 'origin/main' into feat/notification
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
2026-03-02 14:58:00 +08:00
Yansong Zhang
d1a0b9695c fix linter 2026-03-02 14:55:24 +08:00
Yansong Zhang
3147e44a0b add notification use saas 2026-03-02 14:55:15 +08:00
Varun Chawla
9ddbc1c0fb fix: map all NodeType values to span kinds in Arize Phoenix tracing (#32059)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Crazywoola <100913391+crazywoola@users.noreply.github.com>
2026-03-02 14:54:26 +08:00
Stable Genius
42a8d962a0 refactor: remove tests and core/rag from pyrefly excludes (#32801)
Co-authored-by: Stable Genius <259448942+stablegenius49@users.noreply.github.com>
2026-03-02 15:31:29 +09:00
ふるい
8af110a87e refactor: use unified diff format in pyrefly-diff workflow (#32828) 2026-03-02 15:28:12 +09:00
wangxiaolei
cc127f5b62 fix: fix chat assistant response mode blocking is not work (#32394)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-02 14:05:04 +08:00
autofix-ci[bot]
c243e91668 [autofix.ci] apply automated fixes 2026-02-10 08:30:13 +00:00
Yansong Zhang
004fbbe52b add notification logic for backend 2026-02-10 16:13:06 +08:00
Yansong Zhang
63fb0ddde5 add notification logic for backend 2026-02-10 16:12:59 +08:00
12 changed files with 486 additions and 296 deletions

View File

@@ -2598,29 +2598,15 @@ def migrate_oss(
@click.option(
"--start-from",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Lower bound (inclusive) for created_at.",
)
@click.option(
"--end-before",
type=click.DateTime(formats=["%Y-%m-%d", "%Y-%m-%dT%H:%M:%S"]),
required=False,
default=None,
required=True,
help="Upper bound (exclusive) for created_at.",
)
@click.option(
"--from-days-ago",
type=int,
default=None,
help="Relative lower bound in days ago (inclusive). Must be used with --before-days.",
)
@click.option(
"--before-days",
type=int,
default=None,
help="Relative upper bound in days ago (exclusive). Required for relative mode.",
)
@click.option("--batch-size", default=1000, show_default=True, help="Batch size for selecting messages.")
@click.option(
"--graceful-period",
@@ -2632,10 +2618,8 @@ def migrate_oss(
def clean_expired_messages(
batch_size: int,
graceful_period: int,
start_from: datetime.datetime | None,
end_before: datetime.datetime | None,
from_days_ago: int | None,
before_days: int | None,
start_from: datetime.datetime,
end_before: datetime.datetime,
dry_run: bool,
):
"""
@@ -2646,64 +2630,18 @@ def clean_expired_messages(
start_at = time.perf_counter()
try:
abs_mode = start_from is not None and end_before is not None
rel_mode = before_days is not None
if abs_mode and rel_mode:
raise click.UsageError(
"Options are mutually exclusive: use either (--start-from,--end-before) "
"or (--from-days-ago,--before-days)."
)
if from_days_ago is not None and before_days is None:
raise click.UsageError("--from-days-ago must be used together with --before-days.")
if (start_from is None) ^ (end_before is None):
raise click.UsageError("Both --start-from and --end-before are required when using absolute time range.")
if not abs_mode and not rel_mode:
raise click.UsageError(
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])."
)
if rel_mode:
if before_days < 0:
raise click.UsageError("--before-days must be >= 0.")
if from_days_ago is not None:
if from_days_ago < 0:
raise click.UsageError("--from-days-ago must be >= 0.")
if from_days_ago <= before_days:
raise click.UsageError("--from-days-ago must be greater than --before-days.")
# Create policy based on billing configuration
# NOTE: graceful_period will be ignored when billing is disabled.
policy = create_message_clean_policy(graceful_period_days=graceful_period)
# Create and run the cleanup service
if abs_mode:
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
elif from_days_ago is None:
service = MessagesCleanService.from_days(
policy=policy,
days=before_days,
batch_size=batch_size,
dry_run=dry_run,
)
else:
now = datetime.datetime.now()
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=now - datetime.timedelta(days=from_days_ago),
end_before=now - datetime.timedelta(days=before_days),
batch_size=batch_size,
dry_run=dry_run,
)
service = MessagesCleanService.from_time_range(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=batch_size,
dry_run=dry_run,
)
stats = service.run()
end_at = time.perf_counter()

View File

@@ -39,6 +39,7 @@ from . import (
feature,
human_input_form,
init_validate,
notification,
ping,
setup,
spec,
@@ -184,6 +185,7 @@ __all__ = [
"model_config",
"model_providers",
"models",
"notification",
"oauth",
"oauth_server",
"ops_trace",

View File

@@ -1,3 +1,5 @@
import csv
import io
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
@@ -6,7 +8,7 @@ from flask import request
from flask_restx import Resource
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from werkzeug.exceptions import NotFound, Unauthorized
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
from configs import dify_config
from constants.languages import supported_language
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
from extensions.ext_database import db
from libs.token import extract_access_token
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
from services.billing_service import BillingService
P = ParamSpec("P")
R = TypeVar("R")
@@ -277,3 +280,170 @@ class DeleteExploreBannerApi(Resource):
db.session.commit()
return {"result": "success"}, 204
class LangContentPayload(BaseModel):
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
title: str = Field(...)
body: str = Field(...)
cta_label: str = Field(...)
cta_url: str = Field(...)
class UpsertNotificationPayload(BaseModel):
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
contents: list[LangContentPayload] = Field(..., min_length=1)
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
status: str = Field(default="active", description="'active' | 'inactive'")
class BatchAddNotificationAccountsPayload(BaseModel):
notification_id: str = Field(...)
user_email: list[str] = Field(..., description="List of account email addresses")
console_ns.schema_model(
UpsertNotificationPayload.__name__,
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
console_ns.schema_model(
BatchAddNotificationAccountsPayload.__name__,
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
)
@console_ns.route("/admin/upsert_notification")
class UpsertNotificationApi(Resource):
@console_ns.doc("upsert_notification")
@console_ns.doc(
description=(
"Create or update an in-product notification. "
"Supply notification_id to update an existing one; omit it to create a new one. "
"Pass at least one language variant in contents (zh / en / jp)."
)
)
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
@console_ns.response(200, "Notification upserted successfully")
@only_edition_cloud
@admin_required
def post(self):
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
result = BillingService.upsert_notification(
contents=[c.model_dump() for c in payload.contents],
frequency=payload.frequency,
status=payload.status,
notification_id=payload.notification_id,
start_time=payload.start_time,
end_time=payload.end_time,
)
return {"result": "success", "notification_id": result.get("notificationId")}, 200
@console_ns.route("/admin/batch_add_notification_accounts")
class BatchAddNotificationAccountsApi(Resource):
@console_ns.doc("batch_add_notification_accounts")
@console_ns.doc(
description=(
"Register target accounts for a notification by email address. "
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
"plus a 'notification_id' field. "
"Emails that do not match any account are silently skipped."
)
)
@console_ns.response(200, "Accounts added successfully")
@only_edition_cloud
@admin_required
def post(self):
from models.account import Account
if "file" in request.files:
notification_id = request.form.get("notification_id", "").strip()
if not notification_id:
raise BadRequest("notification_id is required.")
emails = self._parse_emails_from_file()
else:
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
notification_id = payload.notification_id
emails = payload.user_email
if not emails:
raise BadRequest("No valid email addresses provided.")
# Resolve emails → account IDs in chunks to avoid large IN-clause
account_ids: list[str] = []
chunk_size = 500
for i in range(0, len(emails), chunk_size):
chunk = emails[i : i + chunk_size]
rows = db.session.execute(
select(Account.id, Account.email).where(Account.email.in_(chunk))
).all()
account_ids.extend(str(row.id) for row in rows)
if not account_ids:
raise BadRequest("None of the provided emails matched an existing account.")
# Send to dify-saas in batches of 1000
total_count = 0
batch_size = 1000
for i in range(0, len(account_ids), batch_size):
batch = account_ids[i : i + batch_size]
result = BillingService.batch_add_notification_accounts(
notification_id=notification_id,
account_ids=batch,
)
total_count += result.get("count", 0)
return {
"result": "success",
"emails_provided": len(emails),
"accounts_matched": len(account_ids),
"count": total_count,
}, 200
@staticmethod
def _parse_emails_from_file() -> list[str]:
"""Parse email addresses from an uploaded CSV or TXT file."""
file = request.files["file"]
if not file.filename:
raise BadRequest("Uploaded file has no filename.")
filename_lower = file.filename.lower()
if not filename_lower.endswith((".csv", ".txt")):
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
try:
content = file.read().decode("utf-8")
except UnicodeDecodeError:
try:
file.seek(0)
content = file.read().decode("gbk")
except UnicodeDecodeError:
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
emails: list[str] = []
if filename_lower.endswith(".csv"):
reader = csv.reader(io.StringIO(content))
for row in reader:
for cell in row:
cell = cell.strip()
if cell:
emails.append(cell)
else:
for line in content.splitlines():
line = line.strip()
if line:
emails.append(line)
# Deduplicate while preserving order
seen: set[str] = set()
unique_emails: list[str] = []
for email in emails:
if email.lower() not in seen:
seen.add(email.lower())
unique_emails.append(email)
return unique_emails

View File

@@ -0,0 +1,80 @@
from flask_restx import Resource
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
from libs.login import current_account_with_tenant, login_required
from services.billing_service import BillingService
# Notification content is stored under three lang tags.
_FALLBACK_LANG = "en"
# Maps dify interface_language prefixes to notification lang tags.
# Any unrecognised prefix falls back to _FALLBACK_LANG.
_LANG_MAP: dict[str, str] = {
"zh": "zh",
"ja": "jp",
}
def _resolve_lang(interface_language: str | None) -> str:
"""Derive the notification lang tag from the user's interface_language.
e.g. "zh-Hans""zh", "ja-JP""jp", "en-US" / None → "en"
"""
if not interface_language:
return _FALLBACK_LANG
prefix = interface_language.split("-")[0].lower()
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
def _pick_lang_content(contents: dict, lang: str) -> dict:
"""Return the single LangContent for *lang*, falling back to English."""
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
@console_ns.route("/notification")
class NotificationApi(Resource):
@console_ns.doc("get_notification")
@console_ns.doc(
description=(
"Return the active in-product notification for the current user "
"in their interface language (falls back to English if unavailable). "
"Calling this endpoint also marks the notification as seen; subsequent "
"calls return should_show=false when frequency='once'."
),
responses={
200: "Success — inspect should_show to decide whether to render the modal",
401: "Unauthorized",
},
)
@setup_required
@login_required
@account_initialization_required
@only_edition_cloud
def get(self):
current_user, _ = current_account_with_tenant()
result = BillingService.get_account_notification(str(current_user.id))
# Proto JSON uses camelCase field names (Kratos default marshaling).
if not result.get("shouldShow"):
return {"should_show": False}, 200
notification = result.get("notification") or {}
contents: dict = notification.get("contents") or {}
lang = _resolve_lang(current_user.interface_language)
lang_content = _pick_lang_content(contents, lang)
return {
"should_show": True,
"notification": {
"notification_id": notification.get("notificationId"),
"frequency": notification.get("frequency"),
"lang": lang_content.get("lang", lang),
"title": lang_content.get("title", ""),
"body": lang_content.get("body", ""),
"cta_label": lang_content.get("ctaLabel", ""),
"cta_url": lang_content.get("ctaUrl", ""),
},
}, 200

View File

@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
id=self._message_id,
mode=self._conversation_mode,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
mode=self._conversation_mode,
conversation_id=self._conversation_id,
message_id=self._message_id,
answer=cast(str, self._task_state.llm_result.message.content),
answer=self._task_state.llm_result.message.get_text_content(),
created_at=self._message_created_at,
**extras,
),
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
# handle output moderation
output_moderation_answer = self.handle_output_moderation_when_task_finished(
cast(str, self._task_state.llm_result.message.content)
self._task_state.llm_result.message.get_text_content()
)
if output_moderation_answer:
self._task_state.llm_result.message.content = output_moderation_answer
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_unit_price = usage.prompt_unit_price
message.message_price_unit = usage.prompt_price_unit
message.answer = (
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
if llm_result.message.content
else ""
)

View File

@@ -155,6 +155,26 @@ def wrap_span_metadata(metadata, **kwargs):
return metadata
# Mapping from NodeType string values to OpenInference span kinds.
# NodeType values not listed here default to CHAIN.
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
"llm": OpenInferenceSpanKindValues.LLM,
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
"tool": OpenInferenceSpanKindValues.TOOL,
"agent": OpenInferenceSpanKindValues.AGENT,
}
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
"""Return the OpenInference span kind for a given workflow node type.
Covers every ``NodeType`` enum value. Nodes that do not have a
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
"""
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
class ArizePhoenixDataTrace(BaseTraceInstance):
def __init__(
self,
@@ -289,9 +309,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
)
# Determine the correct span kind based on node type
span_kind = OpenInferenceSpanKindValues.CHAIN
span_kind = _get_node_span_kind(node_execution.node_type)
if node_execution.node_type == "llm":
span_kind = OpenInferenceSpanKindValues.LLM
provider = process_data.get("model_provider")
model = process_data.get("model_name")
if provider:
@@ -306,12 +325,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
elif node_execution.node_type == "dataset_retrieval":
span_kind = OpenInferenceSpanKindValues.RETRIEVER
elif node_execution.node_type == "tool":
span_kind = OpenInferenceSpanKindValues.TOOL
else:
span_kind = OpenInferenceSpanKindValues.CHAIN
workflow_span_context = set_span_in_context(workflow_span)
node_span = self.tracer.start_span(

View File

@@ -1,9 +1,7 @@
project-includes = ["."]
project-excludes = [
"tests/",
".venv",
"migrations/",
"core/rag",
]
python-platform = "linux"
python-version = "3.11.0"

View File

@@ -131,33 +131,54 @@ class AppGenerateService:
elif app_model.mode == AppMode.ADVANCED_CHAT:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
payload_json = payload.model_dump_json()
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
if streaming:
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
with rate_limit_context(rate_limit, request_id):
payload = AppExecutionParams.new(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=True,
call_depth=0,
)
payload_json = payload.model_dump_json()
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
def on_subscribe():
workflow_based_app_execution_task.delay(payload_json)
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
generator = AdvancedChatAppGenerator()
return rate_limit.generate(
generator.convert_to_event_stream(
generator.retrieve_events(
AppMode.ADVANCED_CHAT,
payload.workflow_run_id,
on_subscribe=on_subscribe,
),
),
),
request_id=request_id,
)
request_id=request_id,
)
else:
# Blocking mode: run synchronously and return JSON instead of SSE
# Keep behaviour consistent with WORKFLOW blocking branch.
advanced_generator = AdvancedChatAppGenerator()
return rate_limit.generate(
advanced_generator.convert_to_event_stream(
advanced_generator.generate(
app_model=app_model,
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
workflow_run_id=str(uuid.uuid4()),
streaming=False,
)
),
request_id=request_id,
)
elif app_model.mode == AppMode.WORKFLOW:
workflow_id = args.get("workflow_id")
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)

View File

@@ -393,3 +393,66 @@ class BillingService:
for item in data:
tenant_whitelist.append(item["tenant_id"])
return tenant_whitelist
@classmethod
def get_account_notification(cls, account_id: str) -> dict:
"""Return the active in-product notification for account_id, if any.
Calling this endpoint also marks the notification as seen; subsequent
calls will return should_show=false when frequency='once'.
Response shape (mirrors GetAccountNotificationReply):
{
"should_show": bool,
"notification": { # present only when should_show=true
"notification_id": str,
"contents": { # lang -> LangContent
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
...
},
"frequency": "once" | "every_page_load"
}
}
"""
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
@classmethod
def upsert_notification(
cls,
contents: list[dict],
frequency: str = "once",
status: str = "active",
notification_id: str | None = None,
start_time: str | None = None,
end_time: str | None = None,
) -> dict:
"""Create or update a notification.
contents: list of {"lang": str, "title": str, "body": str, "cta_label": str, "cta_url": str}
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
Returns {"notification_id": str}.
"""
payload: dict = {
"contents": contents,
"frequency": frequency,
"status": status,
}
if notification_id:
payload["notification_id"] = notification_id
if start_time:
payload["start_time"] = start_time
if end_time:
payload["end_time"] = end_time
return cls._send_request("POST", "/notifications", json=payload)
@classmethod
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
"""Register target account IDs for a notification (max 1000 per call).
Returns {"count": int}.
"""
return cls._send_request(
"POST",
f"/notifications/{notification_id}/accounts",
json={"account_ids": account_ids},
)

View File

@@ -1,184 +0,0 @@
import datetime
import re
from unittest.mock import MagicMock, patch
import click
import pytest
from commands import clean_expired_messages
def _mock_service() -> MagicMock:
service = MagicMock()
service.run.return_value = {
"batches": 1,
"total_messages": 10,
"filtered_messages": 5,
"total_deleted": 5,
}
return service
def test_absolute_mode_calls_from_time_range():
policy = object()
service = _mock_service()
start_from = datetime.datetime(2024, 1, 1, 0, 0, 0)
end_before = datetime.datetime(2024, 2, 1, 0, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
):
clean_expired_messages.callback(
batch_size=200,
graceful_period=21,
start_from=start_from,
end_before=end_before,
from_days_ago=None,
before_days=None,
dry_run=True,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=start_from,
end_before=end_before,
batch_size=200,
dry_run=True,
)
mock_from_days.assert_not_called()
def test_relative_mode_before_days_only_calls_from_days():
policy = object()
service = _mock_service()
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_days", return_value=service) as mock_from_days,
patch("commands.MessagesCleanService.from_time_range") as mock_from_time_range,
):
clean_expired_messages.callback(
batch_size=500,
graceful_period=14,
start_from=None,
end_before=None,
from_days_ago=None,
before_days=30,
dry_run=False,
)
mock_from_days.assert_called_once_with(
policy=policy,
days=30,
batch_size=500,
dry_run=False,
)
mock_from_time_range.assert_not_called()
def test_relative_mode_with_from_days_ago_calls_from_time_range():
policy = object()
service = _mock_service()
fixed_now = datetime.datetime(2024, 8, 20, 12, 0, 0)
with (
patch("commands.create_message_clean_policy", return_value=policy),
patch("commands.MessagesCleanService.from_time_range", return_value=service) as mock_from_time_range,
patch("commands.MessagesCleanService.from_days") as mock_from_days,
patch("commands.datetime", autospec=True) as mock_datetime,
):
mock_datetime.datetime.now.return_value = fixed_now
mock_datetime.timedelta = datetime.timedelta
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=None,
end_before=None,
from_days_ago=60,
before_days=30,
dry_run=False,
)
mock_from_time_range.assert_called_once_with(
policy=policy,
start_from=fixed_now - datetime.timedelta(days=60),
end_before=fixed_now - datetime.timedelta(days=30),
batch_size=1000,
dry_run=False,
)
mock_from_days.assert_not_called()
@pytest.mark.parametrize(
("kwargs", "message"),
[
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": datetime.datetime(2024, 2, 1),
"from_days_ago": None,
"before_days": 30,
},
"mutually exclusive",
),
(
{
"start_from": datetime.datetime(2024, 1, 1),
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"Both --start-from and --end-before are required",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 10,
"before_days": None,
},
"--from-days-ago must be used together with --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": -1,
},
"--before-days must be >= 0",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": 30,
"before_days": 30,
},
"--from-days-ago must be greater than --before-days",
),
(
{
"start_from": None,
"end_before": None,
"from_days_ago": None,
"before_days": None,
},
"You must provide either (--start-from,--end-before) or (--before-days [--from-days-ago])",
),
],
)
def test_invalid_inputs_raise_usage_error(kwargs: dict, message: str):
with pytest.raises(click.UsageError, match=re.escape(message)):
clean_expired_messages.callback(
batch_size=1000,
graceful_period=21,
start_from=kwargs["start_from"],
end_before=kwargs["end_before"],
from_days_ago=kwargs["from_days_ago"],
before_days=kwargs["before_days"],
dry_run=False,
)

View File

@@ -0,0 +1,36 @@
from openinference.semconv.trace import OpenInferenceSpanKindValues
from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
from core.workflow.enums import NodeType
class TestGetNodeSpanKind:
"""Tests for _get_node_span_kind helper."""
def test_all_node_types_are_mapped_correctly(self):
"""Ensure every NodeType enum member is mapped to the correct span kind."""
# Mappings for node types that have a specialised span kind.
special_mappings = {
NodeType.LLM: OpenInferenceSpanKindValues.LLM,
NodeType.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
NodeType.TOOL: OpenInferenceSpanKindValues.TOOL,
NodeType.AGENT: OpenInferenceSpanKindValues.AGENT,
}
# Test that every NodeType enum member is mapped to the correct span kind.
# Node types not in `special_mappings` should default to CHAIN.
for node_type in NodeType:
expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
actual_span_kind = _get_node_span_kind(node_type)
assert actual_span_kind == expected_span_kind, (
f"NodeType.{node_type.name} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
)
def test_unknown_string_defaults_to_chain(self):
"""An unrecognised node type string should still return CHAIN."""
assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN
def test_stale_dataset_retrieval_not_in_mapping(self):
"""The old 'dataset_retrieval' string was never a valid NodeType value;
make sure it is not present in the mapping dictionary."""
assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND

View File

@@ -63,3 +63,56 @@ def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
pause_state_config = call_kwargs.get("pause_state_config")
assert pause_state_config is not None
assert pause_state_config.state_owner_user_id == "owner-id"
def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch):
"""
Regression test: ADVANCED_CHAT in blocking mode should return a plain dict
(non-streaming), and must not go through the async retrieve_events path.
Keeps behavior consistent with WORKFLOW blocking branch.
"""
# Disable billing and stub RateLimit to a no-op that just passes values through
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
# Arrange a fake workflow and wire AppGenerateService._get_workflow to return it
workflow = MagicMock()
workflow.id = "workflow-id"
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
# Spy on the streaming retrieval path to ensure it's NOT called
retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
# Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False
generate_spy = mocker.patch(
"services.app_generate_service.AdvancedChatAppGenerator.generate",
return_value={"result": "ok"},
)
# Minimal app model for ADVANCED_CHAT
app_model = MagicMock()
app_model.mode = AppMode.ADVANCED_CHAT
app_model.id = "app-id"
app_model.tenant_id = "tenant-id"
app_model.max_active_requests = 0
app_model.is_agent = False
user = MagicMock()
user.id = "user-id"
# Must include query and inputs for AdvancedChatAppGenerator
args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}}
# Act: call service with streaming=False (blocking mode)
result = AppGenerateService.generate(
app_model=app_model,
user=user,
args=args,
invoke_from=MagicMock(),
streaming=False,
)
# Assert: returns the dict from generate(), and did not call retrieve_events()
assert result == {"result": "ok"}
assert generate_spy.call_args.kwargs.get("streaming") is False
retrieve_spy.assert_not_called()