mirror of
https://github.com/langgenius/dify.git
synced 2026-03-02 13:35:10 +00:00
Compare commits
14 Commits
refactor/b
...
feat/notif
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9589bba713 | ||
|
|
6473c1419b | ||
|
|
d1a0b9695c | ||
|
|
3147e44a0b | ||
|
|
9ddbc1c0fb | ||
|
|
42a8d962a0 | ||
|
|
8af110a87e | ||
|
|
cc127f5b62 | ||
|
|
aca3d1900e | ||
|
|
ddc47c2f39 | ||
|
|
2dc9bc00d6 | ||
|
|
c243e91668 | ||
|
|
004fbbe52b | ||
|
|
63fb0ddde5 |
2
.github/workflows/pyrefly-diff.yml
vendored
2
.github/workflows/pyrefly-diff.yml
vendored
@@ -48,7 +48,7 @@ jobs:
|
||||
|
||||
- name: Compute diff
|
||||
run: |
|
||||
diff /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
diff -u /tmp/pyrefly_base.txt /tmp/pyrefly_pr.txt > pyrefly_diff.txt || true
|
||||
|
||||
- name: Save PR number
|
||||
run: |
|
||||
|
||||
@@ -39,6 +39,7 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -184,6 +185,7 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -6,7 +8,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -277,3 +280,168 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
body: str = Field(...)
|
||||
cta_label: str = Field(...)
|
||||
cta_url: str = Field(...)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
80
api/controllers/console/notification.py
Normal file
80
api/controllers/console/notification.py
Normal file
@@ -0,0 +1,80 @@
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en"
|
||||
|
||||
# Maps dify interface_language prefixes to notification lang tags.
|
||||
# Any unrecognised prefix falls back to _FALLBACK_LANG.
|
||||
_LANG_MAP: dict[str, str] = {
|
||||
"zh": "zh",
|
||||
"ja": "jp",
|
||||
}
|
||||
|
||||
|
||||
def _resolve_lang(interface_language: str | None) -> str:
|
||||
"""Derive the notification lang tag from the user's interface_language.
|
||||
|
||||
e.g. "zh-Hans" → "zh", "ja-JP" → "jp", "en-US" / None → "en"
|
||||
"""
|
||||
if not interface_language:
|
||||
return _FALLBACK_LANG
|
||||
prefix = interface_language.split("-")[0].lower()
|
||||
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"Calling this endpoint also marks the notification as seen; subsequent "
|
||||
"calls return should_show=false when frequency='once'."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False}, 200
|
||||
|
||||
notification = result.get("notification") or {}
|
||||
contents: dict = notification.get("contents") or {}
|
||||
|
||||
lang = _resolve_lang(current_user.interface_language)
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
|
||||
return {
|
||||
"should_show": True,
|
||||
"notification": {
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"cta_label": lang_content.get("ctaLabel", ""),
|
||||
"cta_url": lang_content.get("ctaUrl", ""),
|
||||
},
|
||||
}, 200
|
||||
@@ -157,7 +157,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
id=self._message_id,
|
||||
mode=self._conversation_mode,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -170,7 +170,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
mode=self._conversation_mode,
|
||||
conversation_id=self._conversation_id,
|
||||
message_id=self._message_id,
|
||||
answer=cast(str, self._task_state.llm_result.message.content),
|
||||
answer=self._task_state.llm_result.message.get_text_content(),
|
||||
created_at=self._message_created_at,
|
||||
**extras,
|
||||
),
|
||||
@@ -283,7 +283,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
|
||||
# handle output moderation
|
||||
output_moderation_answer = self.handle_output_moderation_when_task_finished(
|
||||
cast(str, self._task_state.llm_result.message.content)
|
||||
self._task_state.llm_result.message.get_text_content()
|
||||
)
|
||||
if output_moderation_answer:
|
||||
self._task_state.llm_result.message.content = output_moderation_answer
|
||||
@@ -397,7 +397,7 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_unit_price = usage.prompt_unit_price
|
||||
message.message_price_unit = usage.prompt_price_unit
|
||||
message.answer = (
|
||||
PromptTemplateParser.remove_template_variables(cast(str, llm_result.message.content).strip())
|
||||
PromptTemplateParser.remove_template_variables(llm_result.message.get_text_content().strip())
|
||||
if llm_result.message.content
|
||||
else ""
|
||||
)
|
||||
|
||||
@@ -155,6 +155,26 @@ def wrap_span_metadata(metadata, **kwargs):
|
||||
return metadata
|
||||
|
||||
|
||||
# Mapping from NodeType string values to OpenInference span kinds.
|
||||
# NodeType values not listed here default to CHAIN.
|
||||
_NODE_TYPE_TO_SPAN_KIND: dict[str, OpenInferenceSpanKindValues] = {
|
||||
"llm": OpenInferenceSpanKindValues.LLM,
|
||||
"knowledge-retrieval": OpenInferenceSpanKindValues.RETRIEVER,
|
||||
"tool": OpenInferenceSpanKindValues.TOOL,
|
||||
"agent": OpenInferenceSpanKindValues.AGENT,
|
||||
}
|
||||
|
||||
|
||||
def _get_node_span_kind(node_type: str) -> OpenInferenceSpanKindValues:
|
||||
"""Return the OpenInference span kind for a given workflow node type.
|
||||
|
||||
Covers every ``NodeType`` enum value. Nodes that do not have a
|
||||
specialised span kind (e.g. ``start``, ``end``, ``if-else``,
|
||||
``code``, ``loop``, ``iteration``, etc.) are mapped to ``CHAIN``.
|
||||
"""
|
||||
return _NODE_TYPE_TO_SPAN_KIND.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
|
||||
|
||||
class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -289,9 +309,8 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
)
|
||||
|
||||
# Determine the correct span kind based on node type
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
span_kind = _get_node_span_kind(node_execution.node_type)
|
||||
if node_execution.node_type == "llm":
|
||||
span_kind = OpenInferenceSpanKindValues.LLM
|
||||
provider = process_data.get("model_provider")
|
||||
model = process_data.get("model_name")
|
||||
if provider:
|
||||
@@ -306,12 +325,6 @@ class ArizePhoenixDataTrace(BaseTraceInstance):
|
||||
node_metadata["total_tokens"] = usage_data.get("total_tokens", 0)
|
||||
node_metadata["prompt_tokens"] = usage_data.get("prompt_tokens", 0)
|
||||
node_metadata["completion_tokens"] = usage_data.get("completion_tokens", 0)
|
||||
elif node_execution.node_type == "dataset_retrieval":
|
||||
span_kind = OpenInferenceSpanKindValues.RETRIEVER
|
||||
elif node_execution.node_type == "tool":
|
||||
span_kind = OpenInferenceSpanKindValues.TOOL
|
||||
else:
|
||||
span_kind = OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
workflow_span_context = set_span_in_context(workflow_span)
|
||||
node_span = self.tracer.start_span(
|
||||
|
||||
@@ -68,7 +68,7 @@ dependencies = [
|
||||
"pydantic~=2.12.5",
|
||||
"pydantic-extra-types~=2.10.3",
|
||||
"pydantic-settings~=2.12.0",
|
||||
"pyjwt~=2.10.1",
|
||||
"pyjwt~=2.11.0",
|
||||
"pypdfium2==5.2.0",
|
||||
"python-docx~=1.2.0",
|
||||
"python-dotenv==1.0.1",
|
||||
@@ -124,7 +124,7 @@ dev = [
|
||||
"pytest-env~=1.1.3",
|
||||
"pytest-mock~=3.14.0",
|
||||
"testcontainers~=4.13.2",
|
||||
"types-aiofiles~=24.1.0",
|
||||
"types-aiofiles~=25.1.0",
|
||||
"types-beautifulsoup4~=4.12.0",
|
||||
"types-cachetools~=5.5.0",
|
||||
"types-colorama~=0.4.15",
|
||||
|
||||
@@ -1,9 +1,7 @@
|
||||
project-includes = ["."]
|
||||
project-excludes = [
|
||||
"tests/",
|
||||
".venv",
|
||||
"migrations/",
|
||||
"core/rag",
|
||||
]
|
||||
python-platform = "linux"
|
||||
python-version = "3.11.0"
|
||||
|
||||
@@ -131,33 +131,54 @@ class AppGenerateService:
|
||||
elif app_model.mode == AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
if streaming:
|
||||
# Streaming mode: subscribe to SSE and enqueue the execution on first subscriber
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
AppMode.ADVANCED_CHAT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
AppMode.ADVANCED_CHAT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
# Blocking mode: run synchronously and return JSON instead of SSE
|
||||
# Keep behaviour consistent with WORKFLOW blocking branch.
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
advanced_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
elif app_model.mode == AppMode.WORKFLOW:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
|
||||
@@ -393,3 +393,66 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def get_account_notification(cls, account_id: str) -> dict:
|
||||
"""Return the active in-product notification for account_id, if any.
|
||||
|
||||
Calling this endpoint also marks the notification as seen; subsequent
|
||||
calls will return should_show=false when frequency='once'.
|
||||
|
||||
Response shape (mirrors GetAccountNotificationReply):
|
||||
{
|
||||
"should_show": bool,
|
||||
"notification": { # present only when should_show=true
|
||||
"notification_id": str,
|
||||
"contents": { # lang -> LangContent
|
||||
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
|
||||
...
|
||||
},
|
||||
"frequency": "once" | "every_page_load"
|
||||
}
|
||||
}
|
||||
"""
|
||||
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||
|
||||
@classmethod
|
||||
def upsert_notification(
|
||||
cls,
|
||||
contents: list[dict],
|
||||
frequency: str = "once",
|
||||
status: str = "active",
|
||||
notification_id: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update a notification.
|
||||
|
||||
contents: list of {"lang": str, "title": str, "body": str, "cta_label": str, "cta_url": str}
|
||||
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||
Returns {"notification_id": str}.
|
||||
"""
|
||||
payload: dict = {
|
||||
"contents": contents,
|
||||
"frequency": frequency,
|
||||
"status": status,
|
||||
}
|
||||
if notification_id:
|
||||
payload["notification_id"] = notification_id
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
return cls._send_request("POST", "/notifications", json=payload)
|
||||
|
||||
@classmethod
|
||||
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||
"""Register target account IDs for a notification (max 1000 per call).
|
||||
|
||||
Returns {"count": int}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/accounts",
|
||||
json={"account_ids": account_ids},
|
||||
)
|
||||
|
||||
36
api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
Normal file
36
api/tests/unit_tests/core/ops/test_arize_phoenix_trace.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from openinference.semconv.trace import OpenInferenceSpanKindValues
|
||||
|
||||
from core.ops.arize_phoenix_trace.arize_phoenix_trace import _NODE_TYPE_TO_SPAN_KIND, _get_node_span_kind
|
||||
from core.workflow.enums import NodeType
|
||||
|
||||
|
||||
class TestGetNodeSpanKind:
|
||||
"""Tests for _get_node_span_kind helper."""
|
||||
|
||||
def test_all_node_types_are_mapped_correctly(self):
|
||||
"""Ensure every NodeType enum member is mapped to the correct span kind."""
|
||||
# Mappings for node types that have a specialised span kind.
|
||||
special_mappings = {
|
||||
NodeType.LLM: OpenInferenceSpanKindValues.LLM,
|
||||
NodeType.KNOWLEDGE_RETRIEVAL: OpenInferenceSpanKindValues.RETRIEVER,
|
||||
NodeType.TOOL: OpenInferenceSpanKindValues.TOOL,
|
||||
NodeType.AGENT: OpenInferenceSpanKindValues.AGENT,
|
||||
}
|
||||
|
||||
# Test that every NodeType enum member is mapped to the correct span kind.
|
||||
# Node types not in `special_mappings` should default to CHAIN.
|
||||
for node_type in NodeType:
|
||||
expected_span_kind = special_mappings.get(node_type, OpenInferenceSpanKindValues.CHAIN)
|
||||
actual_span_kind = _get_node_span_kind(node_type)
|
||||
assert actual_span_kind == expected_span_kind, (
|
||||
f"NodeType.{node_type.name} was mapped to {actual_span_kind}, but {expected_span_kind} was expected."
|
||||
)
|
||||
|
||||
def test_unknown_string_defaults_to_chain(self):
|
||||
"""An unrecognised node type string should still return CHAIN."""
|
||||
assert _get_node_span_kind("some-future-node-type") == OpenInferenceSpanKindValues.CHAIN
|
||||
|
||||
def test_stale_dataset_retrieval_not_in_mapping(self):
|
||||
"""The old 'dataset_retrieval' string was never a valid NodeType value;
|
||||
make sure it is not present in the mapping dictionary."""
|
||||
assert "dataset_retrieval" not in _NODE_TYPE_TO_SPAN_KIND
|
||||
@@ -63,3 +63,56 @@ def test_workflow_blocking_injects_pause_state_config(mocker, monkeypatch):
|
||||
pause_state_config = call_kwargs.get("pause_state_config")
|
||||
assert pause_state_config is not None
|
||||
assert pause_state_config.state_owner_user_id == "owner-id"
|
||||
|
||||
|
||||
def test_advanced_chat_blocking_returns_dict_and_does_not_use_event_retrieval(mocker, monkeypatch):
|
||||
"""
|
||||
Regression test: ADVANCED_CHAT in blocking mode should return a plain dict
|
||||
(non-streaming), and must not go through the async retrieve_events path.
|
||||
Keeps behavior consistent with WORKFLOW blocking branch.
|
||||
"""
|
||||
# Disable billing and stub RateLimit to a no-op that just passes values through
|
||||
monkeypatch.setattr(app_generate_service_module.dify_config, "BILLING_ENABLED", False)
|
||||
mocker.patch("services.app_generate_service.RateLimit", _DummyRateLimit)
|
||||
|
||||
# Arrange a fake workflow and wire AppGenerateService._get_workflow to return it
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-id"
|
||||
mocker.patch.object(AppGenerateService, "_get_workflow", return_value=workflow)
|
||||
|
||||
# Spy on the streaming retrieval path to ensure it's NOT called
|
||||
retrieve_spy = mocker.patch("services.app_generate_service.AdvancedChatAppGenerator.retrieve_events")
|
||||
|
||||
# Make AdvancedChatAppGenerator.generate return a plain dict when streaming=False
|
||||
generate_spy = mocker.patch(
|
||||
"services.app_generate_service.AdvancedChatAppGenerator.generate",
|
||||
return_value={"result": "ok"},
|
||||
)
|
||||
|
||||
# Minimal app model for ADVANCED_CHAT
|
||||
app_model = MagicMock()
|
||||
app_model.mode = AppMode.ADVANCED_CHAT
|
||||
app_model.id = "app-id"
|
||||
app_model.tenant_id = "tenant-id"
|
||||
app_model.max_active_requests = 0
|
||||
app_model.is_agent = False
|
||||
|
||||
user = MagicMock()
|
||||
user.id = "user-id"
|
||||
|
||||
# Must include query and inputs for AdvancedChatAppGenerator
|
||||
args = {"workflow_id": "wf-1", "query": "hello", "inputs": {}}
|
||||
|
||||
# Act: call service with streaming=False (blocking mode)
|
||||
result = AppGenerateService.generate(
|
||||
app_model=app_model,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=MagicMock(),
|
||||
streaming=False,
|
||||
)
|
||||
|
||||
# Assert: returns the dict from generate(), and did not call retrieve_events()
|
||||
assert result == {"result": "ok"}
|
||||
assert generate_spy.call_args.kwargs.get("streaming") is False
|
||||
retrieve_spy.assert_not_called()
|
||||
|
||||
16
api/uv.lock
generated
16
api/uv.lock
generated
@@ -1636,7 +1636,7 @@ requires-dist = [
|
||||
{ name = "pydantic", specifier = "~=2.12.5" },
|
||||
{ name = "pydantic-extra-types", specifier = "~=2.10.3" },
|
||||
{ name = "pydantic-settings", specifier = "~=2.12.0" },
|
||||
{ name = "pyjwt", specifier = "~=2.10.1" },
|
||||
{ name = "pyjwt", specifier = "~=2.11.0" },
|
||||
{ name = "pypdfium2", specifier = "==5.2.0" },
|
||||
{ name = "python-docx", specifier = "~=1.2.0" },
|
||||
{ name = "python-dotenv", specifier = "==1.0.1" },
|
||||
@@ -1683,7 +1683,7 @@ dev = [
|
||||
{ name = "scipy-stubs", specifier = ">=1.15.3.0" },
|
||||
{ name = "sseclient-py", specifier = ">=1.8.0" },
|
||||
{ name = "testcontainers", specifier = "~=4.13.2" },
|
||||
{ name = "types-aiofiles", specifier = "~=24.1.0" },
|
||||
{ name = "types-aiofiles", specifier = "~=25.1.0" },
|
||||
{ name = "types-beautifulsoup4", specifier = "~=4.12.0" },
|
||||
{ name = "types-cachetools", specifier = "~=5.5.0" },
|
||||
{ name = "types-cffi", specifier = ">=1.17.0" },
|
||||
@@ -4957,11 +4957,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pyjwt"
|
||||
version = "2.10.1"
|
||||
version = "2.11.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/e7/46/bd74733ff231675599650d3e47f361794b22ef3e3770998dda30d3b63726/pyjwt-2.10.1.tar.gz", hash = "sha256:3cc5772eb20009233caf06e9d8a0577824723b44e6648ee0a2aedb6cf9381953", size = 87785, upload-time = "2024-11-28T03:43:29.933Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/5c/5a/b46fa56bf322901eee5b0454a34343cdbdae202cd421775a8ee4e42fd519/pyjwt-2.11.0.tar.gz", hash = "sha256:35f95c1f0fbe5d5ba6e43f00271c275f7a1a4db1dab27bf708073b75318ea623", size = 98019, upload-time = "2026-01-30T19:59:55.694Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/61/ad/689f02752eeec26aed679477e80e632ef1b682313be70793d798c1d5fc8f/PyJWT-2.10.1-py3-none-any.whl", hash = "sha256:dcdd193e30abefd5debf142f9adfcdd2b58004e644f25406ffaebd50bd98dacb", size = 22997, upload-time = "2024-11-28T03:43:27.893Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/6f/01/c26ce75ba460d5cd503da9e13b21a33804d38c2165dec7b716d06b13010c/pyjwt-2.11.0-py3-none-any.whl", hash = "sha256:94a6bde30eb5c8e04fee991062b534071fd1439ef58d2adc9ccb823e7bcd0469", size = 28224, upload-time = "2026-01-30T19:59:54.539Z" },
|
||||
]
|
||||
|
||||
[package.optional-dependencies]
|
||||
@@ -6293,11 +6293,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "types-aiofiles"
|
||||
version = "24.1.0.20250822"
|
||||
version = "25.1.0.20251011"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/19/48/c64471adac9206cc844afb33ed311ac5a65d2f59df3d861e0f2d0cad7414/types_aiofiles-24.1.0.20250822.tar.gz", hash = "sha256:9ab90d8e0c307fe97a7cf09338301e3f01a163e39f3b529ace82466355c84a7b", size = 14484, upload-time = "2025-08-22T03:02:23.039Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/84/6c/6d23908a8217e36704aa9c79d99a620f2fdd388b66a4b7f72fbc6b6ff6c6/types_aiofiles-25.1.0.20251011.tar.gz", hash = "sha256:1c2b8ab260cb3cd40c15f9d10efdc05a6e1e6b02899304d80dfa0410e028d3ff", size = 14535, upload-time = "2025-10-11T02:44:51.237Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/bc/8e/5e6d2215e1d8f7c2a94c6e9d0059ae8109ce0f5681956d11bb0a228cef04/types_aiofiles-24.1.0.20250822-py3-none-any.whl", hash = "sha256:0ec8f8909e1a85a5a79aed0573af7901f53120dd2a29771dd0b3ef48e12328b0", size = 14322, upload-time = "2025-08-22T03:02:21.918Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/71/0f/76917bab27e270bb6c32addd5968d69e558e5b6f7fb4ac4cbfa282996a96/types_aiofiles-25.1.0.20251011-py3-none-any.whl", hash = "sha256:8ff8de7f9d42739d8f0dadcceeb781ce27cd8d8c4152d4a7c52f6b20edb8149c", size = 14338, upload-time = "2025-10-11T02:44:50.054Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -77,11 +77,11 @@ const Toast = ({
|
||||
</div>
|
||||
<div className={cn('flex grow flex-col items-start gap-1 py-1', size === 'md' ? 'px-1' : 'px-0.5')}>
|
||||
<div className="flex items-center gap-1">
|
||||
<div className="text-text-primary system-sm-semibold [word-break:break-word]">{message}</div>
|
||||
<div className="system-sm-semibold text-text-primary [word-break:break-word]">{message}</div>
|
||||
{customComponent}
|
||||
</div>
|
||||
{!!children && (
|
||||
<div className="text-text-secondary system-xs-regular">
|
||||
<div className="system-xs-regular text-text-secondary">
|
||||
{children}
|
||||
</div>
|
||||
)}
|
||||
@@ -149,26 +149,25 @@ Toast.notify = ({
|
||||
if (typeof window === 'object') {
|
||||
const holder = document.createElement('div')
|
||||
const root = createRoot(holder)
|
||||
let timerId: ReturnType<typeof setTimeout> | undefined
|
||||
|
||||
const unmountAndRemove = () => {
|
||||
if (timerId) {
|
||||
clearTimeout(timerId)
|
||||
timerId = undefined
|
||||
}
|
||||
if (typeof window !== 'undefined' && holder) {
|
||||
toastHandler.clear = () => {
|
||||
if (holder) {
|
||||
root.unmount()
|
||||
holder.remove()
|
||||
}
|
||||
onClose?.()
|
||||
}
|
||||
|
||||
toastHandler.clear = unmountAndRemove
|
||||
|
||||
root.render(
|
||||
<ToastContext.Provider value={{
|
||||
notify: noop,
|
||||
close: unmountAndRemove,
|
||||
close: () => {
|
||||
if (holder) {
|
||||
root.unmount()
|
||||
holder.remove()
|
||||
}
|
||||
onClose?.()
|
||||
},
|
||||
}}
|
||||
>
|
||||
<Toast type={type} size={size} message={message} duration={duration} className={className} customComponent={customComponent} />
|
||||
@@ -177,7 +176,7 @@ Toast.notify = ({
|
||||
document.body.appendChild(holder)
|
||||
const d = duration ?? defaultDuring
|
||||
if (d > 0)
|
||||
timerId = setTimeout(unmountAndRemove, d)
|
||||
setTimeout(toastHandler.clear, d)
|
||||
}
|
||||
|
||||
return toastHandler
|
||||
|
||||
@@ -1856,6 +1856,11 @@
|
||||
"count": 4
|
||||
}
|
||||
},
|
||||
"app/components/base/file-uploader/utils.spec.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/file-uploader/utils.ts": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 3
|
||||
@@ -2028,6 +2033,11 @@
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/base/input/index.spec.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 1
|
||||
}
|
||||
},
|
||||
"app/components/base/input/index.stories.tsx": {
|
||||
"no-console": {
|
||||
"count": 2
|
||||
@@ -2580,6 +2590,9 @@
|
||||
"app/components/base/toast/index.tsx": {
|
||||
"react-refresh/only-export-components": {
|
||||
"count": 2
|
||||
},
|
||||
"tailwindcss/enforce-consistent-class-order": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/tooltip/index.tsx": {
|
||||
@@ -2605,6 +2618,11 @@
|
||||
"count": 4
|
||||
}
|
||||
},
|
||||
"app/components/base/with-input-validation/index.spec.tsx": {
|
||||
"ts/no-explicit-any": {
|
||||
"count": 2
|
||||
}
|
||||
},
|
||||
"app/components/base/with-input-validation/index.stories.tsx": {
|
||||
"no-console": {
|
||||
"count": 1
|
||||
|
||||
Reference in New Issue
Block a user