mirror of
https://github.com/langgenius/dify.git
synced 2026-03-03 05:55:18 +00:00
Compare commits
6 Commits
yanli/pyda
...
deploy/dev
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6473c1419b | ||
|
|
d1a0b9695c | ||
|
|
3147e44a0b | ||
|
|
c243e91668 | ||
|
|
004fbbe52b | ||
|
|
63fb0ddde5 |
@@ -39,6 +39,7 @@ from . import (
|
|||||||
feature,
|
feature,
|
||||||
human_input_form,
|
human_input_form,
|
||||||
init_validate,
|
init_validate,
|
||||||
|
notification,
|
||||||
ping,
|
ping,
|
||||||
setup,
|
setup,
|
||||||
spec,
|
spec,
|
||||||
@@ -184,6 +185,7 @@ __all__ = [
|
|||||||
"model_config",
|
"model_config",
|
||||||
"model_providers",
|
"model_providers",
|
||||||
"models",
|
"models",
|
||||||
|
"notification",
|
||||||
"oauth",
|
"oauth",
|
||||||
"oauth_server",
|
"oauth_server",
|
||||||
"ops_trace",
|
"ops_trace",
|
||||||
|
|||||||
@@ -1,3 +1,5 @@
|
|||||||
|
import csv
|
||||||
|
import io
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
from functools import wraps
|
from functools import wraps
|
||||||
from typing import ParamSpec, TypeVar
|
from typing import ParamSpec, TypeVar
|
||||||
@@ -6,7 +8,7 @@ from flask import request
|
|||||||
from flask_restx import Resource
|
from flask_restx import Resource
|
||||||
from pydantic import BaseModel, Field, field_validator
|
from pydantic import BaseModel, Field, field_validator
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
from werkzeug.exceptions import NotFound, Unauthorized
|
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||||
|
|
||||||
from configs import dify_config
|
from configs import dify_config
|
||||||
from constants.languages import supported_language
|
from constants.languages import supported_language
|
||||||
@@ -16,6 +18,7 @@ from core.db.session_factory import session_factory
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from libs.token import extract_access_token
|
from libs.token import extract_access_token
|
||||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||||
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
P = ParamSpec("P")
|
P = ParamSpec("P")
|
||||||
R = TypeVar("R")
|
R = TypeVar("R")
|
||||||
@@ -277,3 +280,170 @@ class DeleteExploreBannerApi(Resource):
|
|||||||
db.session.commit()
|
db.session.commit()
|
||||||
|
|
||||||
return {"result": "success"}, 204
|
return {"result": "success"}, 204
|
||||||
|
|
||||||
|
|
||||||
|
class LangContentPayload(BaseModel):
|
||||||
|
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||||
|
title: str = Field(...)
|
||||||
|
body: str = Field(...)
|
||||||
|
cta_label: str = Field(...)
|
||||||
|
cta_url: str = Field(...)
|
||||||
|
|
||||||
|
|
||||||
|
class UpsertNotificationPayload(BaseModel):
|
||||||
|
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||||
|
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||||
|
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||||
|
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||||
|
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||||
|
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||||
|
|
||||||
|
|
||||||
|
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||||
|
notification_id: str = Field(...)
|
||||||
|
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||||
|
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
UpsertNotificationPayload.__name__,
|
||||||
|
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
console_ns.schema_model(
|
||||||
|
BatchAddNotificationAccountsPayload.__name__,
|
||||||
|
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/admin/upsert_notification")
|
||||||
|
class UpsertNotificationApi(Resource):
|
||||||
|
@console_ns.doc("upsert_notification")
|
||||||
|
@console_ns.doc(
|
||||||
|
description=(
|
||||||
|
"Create or update an in-product notification. "
|
||||||
|
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||||
|
"Pass at least one language variant in contents (zh / en / jp)."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||||
|
@console_ns.response(200, "Notification upserted successfully")
|
||||||
|
@only_edition_cloud
|
||||||
|
@admin_required
|
||||||
|
def post(self):
|
||||||
|
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||||
|
result = BillingService.upsert_notification(
|
||||||
|
contents=[c.model_dump() for c in payload.contents],
|
||||||
|
frequency=payload.frequency,
|
||||||
|
status=payload.status,
|
||||||
|
notification_id=payload.notification_id,
|
||||||
|
start_time=payload.start_time,
|
||||||
|
end_time=payload.end_time,
|
||||||
|
)
|
||||||
|
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||||
|
class BatchAddNotificationAccountsApi(Resource):
|
||||||
|
@console_ns.doc("batch_add_notification_accounts")
|
||||||
|
@console_ns.doc(
|
||||||
|
description=(
|
||||||
|
"Register target accounts for a notification by email address. "
|
||||||
|
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||||
|
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||||
|
"plus a 'notification_id' field. "
|
||||||
|
"Emails that do not match any account are silently skipped."
|
||||||
|
)
|
||||||
|
)
|
||||||
|
@console_ns.response(200, "Accounts added successfully")
|
||||||
|
@only_edition_cloud
|
||||||
|
@admin_required
|
||||||
|
def post(self):
|
||||||
|
from models.account import Account
|
||||||
|
|
||||||
|
if "file" in request.files:
|
||||||
|
notification_id = request.form.get("notification_id", "").strip()
|
||||||
|
if not notification_id:
|
||||||
|
raise BadRequest("notification_id is required.")
|
||||||
|
emails = self._parse_emails_from_file()
|
||||||
|
else:
|
||||||
|
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||||
|
notification_id = payload.notification_id
|
||||||
|
emails = payload.user_email
|
||||||
|
|
||||||
|
if not emails:
|
||||||
|
raise BadRequest("No valid email addresses provided.")
|
||||||
|
|
||||||
|
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||||
|
account_ids: list[str] = []
|
||||||
|
chunk_size = 500
|
||||||
|
for i in range(0, len(emails), chunk_size):
|
||||||
|
chunk = emails[i : i + chunk_size]
|
||||||
|
rows = db.session.execute(
|
||||||
|
select(Account.id, Account.email).where(Account.email.in_(chunk))
|
||||||
|
).all()
|
||||||
|
account_ids.extend(str(row.id) for row in rows)
|
||||||
|
|
||||||
|
if not account_ids:
|
||||||
|
raise BadRequest("None of the provided emails matched an existing account.")
|
||||||
|
|
||||||
|
# Send to dify-saas in batches of 1000
|
||||||
|
total_count = 0
|
||||||
|
batch_size = 1000
|
||||||
|
for i in range(0, len(account_ids), batch_size):
|
||||||
|
batch = account_ids[i : i + batch_size]
|
||||||
|
result = BillingService.batch_add_notification_accounts(
|
||||||
|
notification_id=notification_id,
|
||||||
|
account_ids=batch,
|
||||||
|
)
|
||||||
|
total_count += result.get("count", 0)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"result": "success",
|
||||||
|
"emails_provided": len(emails),
|
||||||
|
"accounts_matched": len(account_ids),
|
||||||
|
"count": total_count,
|
||||||
|
}, 200
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _parse_emails_from_file() -> list[str]:
|
||||||
|
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||||
|
file = request.files["file"]
|
||||||
|
if not file.filename:
|
||||||
|
raise BadRequest("Uploaded file has no filename.")
|
||||||
|
|
||||||
|
filename_lower = file.filename.lower()
|
||||||
|
if not filename_lower.endswith((".csv", ".txt")):
|
||||||
|
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||||
|
|
||||||
|
try:
|
||||||
|
content = file.read().decode("utf-8")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
try:
|
||||||
|
file.seek(0)
|
||||||
|
content = file.read().decode("gbk")
|
||||||
|
except UnicodeDecodeError:
|
||||||
|
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||||
|
|
||||||
|
emails: list[str] = []
|
||||||
|
if filename_lower.endswith(".csv"):
|
||||||
|
reader = csv.reader(io.StringIO(content))
|
||||||
|
for row in reader:
|
||||||
|
for cell in row:
|
||||||
|
cell = cell.strip()
|
||||||
|
if cell:
|
||||||
|
emails.append(cell)
|
||||||
|
else:
|
||||||
|
for line in content.splitlines():
|
||||||
|
line = line.strip()
|
||||||
|
if line:
|
||||||
|
emails.append(line)
|
||||||
|
|
||||||
|
# Deduplicate while preserving order
|
||||||
|
seen: set[str] = set()
|
||||||
|
unique_emails: list[str] = []
|
||||||
|
for email in emails:
|
||||||
|
if email.lower() not in seen:
|
||||||
|
seen.add(email.lower())
|
||||||
|
unique_emails.append(email)
|
||||||
|
|
||||||
|
return unique_emails
|
||||||
|
|||||||
80
api/controllers/console/notification.py
Normal file
80
api/controllers/console/notification.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
from flask_restx import Resource
|
||||||
|
|
||||||
|
from controllers.console import console_ns
|
||||||
|
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||||
|
from libs.login import current_account_with_tenant, login_required
|
||||||
|
from services.billing_service import BillingService
|
||||||
|
|
||||||
|
# Notification content is stored under three lang tags.
|
||||||
|
_FALLBACK_LANG = "en"
|
||||||
|
|
||||||
|
# Maps dify interface_language prefixes to notification lang tags.
|
||||||
|
# Any unrecognised prefix falls back to _FALLBACK_LANG.
|
||||||
|
_LANG_MAP: dict[str, str] = {
|
||||||
|
"zh": "zh",
|
||||||
|
"ja": "jp",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_lang(interface_language: str | None) -> str:
|
||||||
|
"""Derive the notification lang tag from the user's interface_language.
|
||||||
|
|
||||||
|
e.g. "zh-Hans" → "zh", "ja-JP" → "jp", "en-US" / None → "en"
|
||||||
|
"""
|
||||||
|
if not interface_language:
|
||||||
|
return _FALLBACK_LANG
|
||||||
|
prefix = interface_language.split("-")[0].lower()
|
||||||
|
return _LANG_MAP.get(prefix, _FALLBACK_LANG)
|
||||||
|
|
||||||
|
|
||||||
|
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||||
|
"""Return the single LangContent for *lang*, falling back to English."""
|
||||||
|
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||||
|
|
||||||
|
|
||||||
|
@console_ns.route("/notification")
|
||||||
|
class NotificationApi(Resource):
|
||||||
|
@console_ns.doc("get_notification")
|
||||||
|
@console_ns.doc(
|
||||||
|
description=(
|
||||||
|
"Return the active in-product notification for the current user "
|
||||||
|
"in their interface language (falls back to English if unavailable). "
|
||||||
|
"Calling this endpoint also marks the notification as seen; subsequent "
|
||||||
|
"calls return should_show=false when frequency='once'."
|
||||||
|
),
|
||||||
|
responses={
|
||||||
|
200: "Success — inspect should_show to decide whether to render the modal",
|
||||||
|
401: "Unauthorized",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
@setup_required
|
||||||
|
@login_required
|
||||||
|
@account_initialization_required
|
||||||
|
@only_edition_cloud
|
||||||
|
def get(self):
|
||||||
|
current_user, _ = current_account_with_tenant()
|
||||||
|
|
||||||
|
result = BillingService.get_account_notification(str(current_user.id))
|
||||||
|
|
||||||
|
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||||
|
if not result.get("shouldShow"):
|
||||||
|
return {"should_show": False}, 200
|
||||||
|
|
||||||
|
notification = result.get("notification") or {}
|
||||||
|
contents: dict = notification.get("contents") or {}
|
||||||
|
|
||||||
|
lang = _resolve_lang(current_user.interface_language)
|
||||||
|
lang_content = _pick_lang_content(contents, lang)
|
||||||
|
|
||||||
|
return {
|
||||||
|
"should_show": True,
|
||||||
|
"notification": {
|
||||||
|
"notification_id": notification.get("notificationId"),
|
||||||
|
"frequency": notification.get("frequency"),
|
||||||
|
"lang": lang_content.get("lang", lang),
|
||||||
|
"title": lang_content.get("title", ""),
|
||||||
|
"body": lang_content.get("body", ""),
|
||||||
|
"cta_label": lang_content.get("ctaLabel", ""),
|
||||||
|
"cta_url": lang_content.get("ctaUrl", ""),
|
||||||
|
},
|
||||||
|
}, 200
|
||||||
@@ -393,3 +393,66 @@ class BillingService:
|
|||||||
for item in data:
|
for item in data:
|
||||||
tenant_whitelist.append(item["tenant_id"])
|
tenant_whitelist.append(item["tenant_id"])
|
||||||
return tenant_whitelist
|
return tenant_whitelist
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def get_account_notification(cls, account_id: str) -> dict:
|
||||||
|
"""Return the active in-product notification for account_id, if any.
|
||||||
|
|
||||||
|
Calling this endpoint also marks the notification as seen; subsequent
|
||||||
|
calls will return should_show=false when frequency='once'.
|
||||||
|
|
||||||
|
Response shape (mirrors GetAccountNotificationReply):
|
||||||
|
{
|
||||||
|
"should_show": bool,
|
||||||
|
"notification": { # present only when should_show=true
|
||||||
|
"notification_id": str,
|
||||||
|
"contents": { # lang -> LangContent
|
||||||
|
"en": {"lang": "en", "title": ..., "body": ..., "cta_label": ..., "cta_url": ...},
|
||||||
|
...
|
||||||
|
},
|
||||||
|
"frequency": "once" | "every_page_load"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
"""
|
||||||
|
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def upsert_notification(
|
||||||
|
cls,
|
||||||
|
contents: list[dict],
|
||||||
|
frequency: str = "once",
|
||||||
|
status: str = "active",
|
||||||
|
notification_id: str | None = None,
|
||||||
|
start_time: str | None = None,
|
||||||
|
end_time: str | None = None,
|
||||||
|
) -> dict:
|
||||||
|
"""Create or update a notification.
|
||||||
|
|
||||||
|
contents: list of {"lang": str, "title": str, "body": str, "cta_label": str, "cta_url": str}
|
||||||
|
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||||
|
Returns {"notification_id": str}.
|
||||||
|
"""
|
||||||
|
payload: dict = {
|
||||||
|
"contents": contents,
|
||||||
|
"frequency": frequency,
|
||||||
|
"status": status,
|
||||||
|
}
|
||||||
|
if notification_id:
|
||||||
|
payload["notification_id"] = notification_id
|
||||||
|
if start_time:
|
||||||
|
payload["start_time"] = start_time
|
||||||
|
if end_time:
|
||||||
|
payload["end_time"] = end_time
|
||||||
|
return cls._send_request("POST", "/notifications", json=payload)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||||
|
"""Register target account IDs for a notification (max 1000 per call).
|
||||||
|
|
||||||
|
Returns {"count": int}.
|
||||||
|
"""
|
||||||
|
return cls._send_request(
|
||||||
|
"POST",
|
||||||
|
f"/notifications/{notification_id}/accounts",
|
||||||
|
json={"account_ids": account_ids},
|
||||||
|
)
|
||||||
|
|||||||
Reference in New Issue
Block a user