mirror of
https://github.com/langgenius/dify.git
synced 2026-03-11 10:07:05 +00:00
Compare commits
1 Commits
deploy/dev
...
fix/vinext
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a0e778258a |
35
.github/dependabot.yml
vendored
35
.github/dependabot.yml
vendored
@@ -19,3 +19,38 @@ updates:
|
||||
uv-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
- package-ecosystem: "npm"
|
||||
directory: "/web"
|
||||
schedule:
|
||||
interval: "weekly"
|
||||
open-pull-requests-limit: 2
|
||||
ignore:
|
||||
- dependency-name: "tailwind-merge"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "tailwindcss"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-syntax-highlighter"
|
||||
update-types: ["version-update:semver-major"]
|
||||
- dependency-name: "react-window"
|
||||
update-types: ["version-update:semver-major"]
|
||||
groups:
|
||||
lexical:
|
||||
patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
storybook:
|
||||
patterns:
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
eslint-group:
|
||||
patterns:
|
||||
- "*eslint*"
|
||||
npm-dependencies:
|
||||
patterns:
|
||||
- "*"
|
||||
exclude-patterns:
|
||||
- "lexical"
|
||||
- "@lexical/*"
|
||||
- "storybook"
|
||||
- "@storybook/*"
|
||||
- "*eslint*"
|
||||
|
||||
2
.github/workflows/anti-slop.yml
vendored
2
.github/workflows/anti-slop.yml
vendored
@@ -15,5 +15,3 @@ jobs:
|
||||
- uses: peakoss/anti-slop@v0
|
||||
with:
|
||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
||||
close-pr: false
|
||||
failure-add-pr-labels: "needs-revision"
|
||||
|
||||
@@ -39,7 +39,6 @@ from . import (
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
notification,
|
||||
ping,
|
||||
setup,
|
||||
spec,
|
||||
@@ -185,7 +184,6 @@ __all__ = [
|
||||
"model_config",
|
||||
"model_providers",
|
||||
"models",
|
||||
"notification",
|
||||
"oauth",
|
||||
"oauth_server",
|
||||
"ops_trace",
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import csv
|
||||
import io
|
||||
from collections.abc import Callable
|
||||
from functools import wraps
|
||||
from typing import ParamSpec, TypeVar
|
||||
@@ -8,7 +6,7 @@ from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from werkzeug.exceptions import BadRequest, NotFound, Unauthorized
|
||||
from werkzeug.exceptions import NotFound, Unauthorized
|
||||
|
||||
from configs import dify_config
|
||||
from constants.languages import supported_language
|
||||
@@ -18,7 +16,6 @@ from core.db.session_factory import session_factory
|
||||
from extensions.ext_database import db
|
||||
from libs.token import extract_access_token
|
||||
from models.model import App, ExporleBanner, InstalledApp, RecommendedApp, TrialApp
|
||||
from services.billing_service import BillingService
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
@@ -280,168 +277,3 @@ class DeleteExploreBannerApi(Resource):
|
||||
db.session.commit()
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
class LangContentPayload(BaseModel):
|
||||
lang: str = Field(..., description="Language tag: 'zh' | 'en' | 'jp'")
|
||||
title: str = Field(...)
|
||||
subtitle: str | None = Field(default=None)
|
||||
body: str = Field(...)
|
||||
title_pic_url: str | None = Field(default=None)
|
||||
|
||||
|
||||
class UpsertNotificationPayload(BaseModel):
|
||||
notification_id: str | None = Field(default=None, description="Omit to create; supply UUID to update")
|
||||
contents: list[LangContentPayload] = Field(..., min_length=1)
|
||||
start_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-01T00:00:00Z")
|
||||
end_time: str | None = Field(default=None, description="RFC3339, e.g. 2026-03-20T23:59:59Z")
|
||||
frequency: str = Field(default="once", description="'once' | 'every_page_load'")
|
||||
status: str = Field(default="active", description="'active' | 'inactive'")
|
||||
|
||||
|
||||
class BatchAddNotificationAccountsPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
user_email: list[str] = Field(..., description="List of account email addresses")
|
||||
|
||||
|
||||
console_ns.schema_model(
|
||||
UpsertNotificationPayload.__name__,
|
||||
UpsertNotificationPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
console_ns.schema_model(
|
||||
BatchAddNotificationAccountsPayload.__name__,
|
||||
BatchAddNotificationAccountsPayload.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0),
|
||||
)
|
||||
|
||||
|
||||
@console_ns.route("/admin/upsert_notification")
|
||||
class UpsertNotificationApi(Resource):
|
||||
@console_ns.doc("upsert_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Create or update an in-product notification. "
|
||||
"Supply notification_id to update an existing one; omit it to create a new one. "
|
||||
"Pass at least one language variant in contents (zh / en / jp)."
|
||||
)
|
||||
)
|
||||
@console_ns.expect(console_ns.models[UpsertNotificationPayload.__name__])
|
||||
@console_ns.response(200, "Notification upserted successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
payload = UpsertNotificationPayload.model_validate(console_ns.payload)
|
||||
result = BillingService.upsert_notification(
|
||||
contents=[c.model_dump() for c in payload.contents],
|
||||
frequency=payload.frequency,
|
||||
status=payload.status,
|
||||
notification_id=payload.notification_id,
|
||||
start_time=payload.start_time,
|
||||
end_time=payload.end_time,
|
||||
)
|
||||
return {"result": "success", "notification_id": result.get("notificationId")}, 200
|
||||
|
||||
|
||||
@console_ns.route("/admin/batch_add_notification_accounts")
|
||||
class BatchAddNotificationAccountsApi(Resource):
|
||||
@console_ns.doc("batch_add_notification_accounts")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Register target accounts for a notification by email address. "
|
||||
'JSON body: {"notification_id": "...", "user_email": ["a@example.com", ...]}. '
|
||||
"File upload: multipart/form-data with a 'file' field (CSV or TXT, one email per line) "
|
||||
"plus a 'notification_id' field. "
|
||||
"Emails that do not match any account are silently skipped."
|
||||
)
|
||||
)
|
||||
@console_ns.response(200, "Accounts added successfully")
|
||||
@only_edition_cloud
|
||||
@admin_required
|
||||
def post(self):
|
||||
from models.account import Account
|
||||
|
||||
if "file" in request.files:
|
||||
notification_id = request.form.get("notification_id", "").strip()
|
||||
if not notification_id:
|
||||
raise BadRequest("notification_id is required.")
|
||||
emails = self._parse_emails_from_file()
|
||||
else:
|
||||
payload = BatchAddNotificationAccountsPayload.model_validate(console_ns.payload)
|
||||
notification_id = payload.notification_id
|
||||
emails = payload.user_email
|
||||
|
||||
if not emails:
|
||||
raise BadRequest("No valid email addresses provided.")
|
||||
|
||||
# Resolve emails → account IDs in chunks to avoid large IN-clause
|
||||
account_ids: list[str] = []
|
||||
chunk_size = 500
|
||||
for i in range(0, len(emails), chunk_size):
|
||||
chunk = emails[i : i + chunk_size]
|
||||
rows = db.session.execute(select(Account.id, Account.email).where(Account.email.in_(chunk))).all()
|
||||
account_ids.extend(str(row.id) for row in rows)
|
||||
|
||||
if not account_ids:
|
||||
raise BadRequest("None of the provided emails matched an existing account.")
|
||||
|
||||
# Send to dify-saas in batches of 1000
|
||||
total_count = 0
|
||||
batch_size = 1000
|
||||
for i in range(0, len(account_ids), batch_size):
|
||||
batch = account_ids[i : i + batch_size]
|
||||
result = BillingService.batch_add_notification_accounts(
|
||||
notification_id=notification_id,
|
||||
account_ids=batch,
|
||||
)
|
||||
total_count += result.get("count", 0)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"emails_provided": len(emails),
|
||||
"accounts_matched": len(account_ids),
|
||||
"count": total_count,
|
||||
}, 200
|
||||
|
||||
@staticmethod
|
||||
def _parse_emails_from_file() -> list[str]:
|
||||
"""Parse email addresses from an uploaded CSV or TXT file."""
|
||||
file = request.files["file"]
|
||||
if not file.filename:
|
||||
raise BadRequest("Uploaded file has no filename.")
|
||||
|
||||
filename_lower = file.filename.lower()
|
||||
if not filename_lower.endswith((".csv", ".txt")):
|
||||
raise BadRequest("Invalid file type. Only CSV (.csv) and TXT (.txt) files are allowed.")
|
||||
|
||||
try:
|
||||
content = file.read().decode("utf-8")
|
||||
except UnicodeDecodeError:
|
||||
try:
|
||||
file.seek(0)
|
||||
content = file.read().decode("gbk")
|
||||
except UnicodeDecodeError:
|
||||
raise BadRequest("Unable to decode the file. Please use UTF-8 or GBK encoding.")
|
||||
|
||||
emails: list[str] = []
|
||||
if filename_lower.endswith(".csv"):
|
||||
reader = csv.reader(io.StringIO(content))
|
||||
for row in reader:
|
||||
for cell in row:
|
||||
cell = cell.strip()
|
||||
if cell:
|
||||
emails.append(cell)
|
||||
else:
|
||||
for line in content.splitlines():
|
||||
line = line.strip()
|
||||
if line:
|
||||
emails.append(line)
|
||||
|
||||
# Deduplicate while preserving order
|
||||
seen: set[str] = set()
|
||||
unique_emails: list[str] = []
|
||||
for email in emails:
|
||||
if email.lower() not in seen:
|
||||
seen.add(email.lower())
|
||||
unique_emails.append(email)
|
||||
|
||||
return unique_emails
|
||||
|
||||
@@ -1,90 +0,0 @@
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, only_edition_cloud, setup_required
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from services.billing_service import BillingService
|
||||
|
||||
# Notification content is stored under three lang tags.
|
||||
_FALLBACK_LANG = "en-US"
|
||||
|
||||
|
||||
def _pick_lang_content(contents: dict, lang: str) -> dict:
|
||||
"""Return the single LangContent for *lang*, falling back to English."""
|
||||
return contents.get(lang) or contents.get(_FALLBACK_LANG) or next(iter(contents.values()), {})
|
||||
|
||||
|
||||
class DismissNotificationPayload(BaseModel):
|
||||
notification_id: str = Field(...)
|
||||
|
||||
|
||||
@console_ns.route("/notification")
|
||||
class NotificationApi(Resource):
|
||||
@console_ns.doc("get_notification")
|
||||
@console_ns.doc(
|
||||
description=(
|
||||
"Return the active in-product notification for the current user "
|
||||
"in their interface language (falls back to English if unavailable). "
|
||||
"The notification is NOT marked as seen here; call POST /notification/dismiss "
|
||||
"when the user explicitly closes the modal."
|
||||
),
|
||||
responses={
|
||||
200: "Success — inspect should_show to decide whether to render the modal",
|
||||
401: "Unauthorized",
|
||||
},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def get(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
result = BillingService.get_account_notification(str(current_user.id))
|
||||
|
||||
# Proto JSON uses camelCase field names (Kratos default marshaling).
|
||||
if not result.get("shouldShow"):
|
||||
return {"should_show": False, "notifications": []}, 200
|
||||
|
||||
lang = current_user.interface_language or _FALLBACK_LANG
|
||||
|
||||
notifications = []
|
||||
for notification in result.get("notifications") or []:
|
||||
contents: dict = notification.get("contents") or {}
|
||||
lang_content = _pick_lang_content(contents, lang)
|
||||
notifications.append(
|
||||
{
|
||||
"notification_id": notification.get("notificationId"),
|
||||
"frequency": notification.get("frequency"),
|
||||
"lang": lang_content.get("lang", lang),
|
||||
"title": lang_content.get("title", ""),
|
||||
"subtitle": lang_content.get("subtitle", ""),
|
||||
"body": lang_content.get("body", ""),
|
||||
"title_pic_url": lang_content.get("titlePicUrl", ""),
|
||||
}
|
||||
)
|
||||
|
||||
return {"should_show": bool(notifications), "notifications": notifications}, 200
|
||||
|
||||
|
||||
@console_ns.route("/notification/dismiss")
|
||||
class NotificationDismissApi(Resource):
|
||||
@console_ns.doc("dismiss_notification")
|
||||
@console_ns.doc(
|
||||
description="Mark a notification as dismissed for the current user.",
|
||||
responses={200: "Success", 401: "Unauthorized"},
|
||||
)
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@only_edition_cloud
|
||||
def post(self):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
payload = DismissNotificationPayload.model_validate(request.get_json())
|
||||
BillingService.dismiss_notification(
|
||||
notification_id=payload.notification_id,
|
||||
account_id=str(current_user.id),
|
||||
)
|
||||
return {"result": "success"}, 200
|
||||
@@ -138,25 +138,20 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
query = self.application_generate_entity.query
|
||||
|
||||
# moderation
|
||||
stop, new_inputs, new_query = self.handle_input_moderation(
|
||||
if self.handle_input_moderation(
|
||||
app_record=self._app,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
inputs=inputs,
|
||||
query=query,
|
||||
message_id=self.message.id,
|
||||
)
|
||||
if stop:
|
||||
):
|
||||
return
|
||||
|
||||
self.application_generate_entity.inputs = new_inputs
|
||||
self.application_generate_entity.query = new_query
|
||||
system_inputs.query = new_query
|
||||
|
||||
# annotation reply
|
||||
if self.handle_annotation_reply(
|
||||
app_record=self._app,
|
||||
message=self.message,
|
||||
query=new_query,
|
||||
query=query,
|
||||
app_generate_entity=self.application_generate_entity,
|
||||
):
|
||||
return
|
||||
@@ -168,7 +163,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
# init variable pool
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=new_inputs,
|
||||
user_inputs=inputs,
|
||||
environment_variables=self._workflow.environment_variables,
|
||||
# Based on the definition of `Variable`,
|
||||
# `VariableBase` instances can be safely used as `Variable` since they are compatible.
|
||||
@@ -245,10 +240,10 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
inputs: Mapping[str, Any],
|
||||
query: str,
|
||||
message_id: str,
|
||||
) -> tuple[bool, Mapping[str, Any], str]:
|
||||
) -> bool:
|
||||
try:
|
||||
# process sensitive_word_avoidance
|
||||
_, new_inputs, new_query = self.moderation_for_inputs(
|
||||
_, inputs, query = self.moderation_for_inputs(
|
||||
app_id=app_record.id,
|
||||
tenant_id=app_generate_entity.app_config.tenant_id,
|
||||
app_generate_entity=app_generate_entity,
|
||||
@@ -258,9 +253,9 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
except ModerationError as e:
|
||||
self._complete_with_stream_output(text=str(e), stopped_by=QueueStopEvent.StopBy.INPUT_MODERATION)
|
||||
return True, inputs, query
|
||||
return True
|
||||
|
||||
return False, new_inputs, new_query
|
||||
return False
|
||||
|
||||
def handle_annotation_reply(
|
||||
self, app_record: App, message: Message, query: str, app_generate_entity: AdvancedChatAppGenerateEntity
|
||||
|
||||
@@ -59,6 +59,8 @@ class DatasourcePluginProviderController(ABC):
|
||||
:param credentials: the credentials of the tool
|
||||
"""
|
||||
credentials_schema = dict[str, ProviderConfig]()
|
||||
if credentials_schema is None:
|
||||
return
|
||||
|
||||
for credential in self.entity.credentials_schema:
|
||||
credentials_schema[credential.name] = credential
|
||||
|
||||
@@ -157,7 +157,6 @@ class PluginInstallTaskPluginStatus(BaseModel):
|
||||
message: str = Field(description="The message of the install task.")
|
||||
icon: str = Field(description="The icon of the plugin.")
|
||||
labels: I18nObject = Field(description="The labels of the plugin.")
|
||||
source: str | None = Field(default=None, description="The installation source of the plugin")
|
||||
|
||||
|
||||
class PluginInstallTask(BasePluginEntity):
|
||||
|
||||
@@ -74,8 +74,7 @@ class ExtractProcessor:
|
||||
else:
|
||||
suffix = ""
|
||||
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
|
||||
# Generate a temporary filename under the created temp_dir and ensure the directory exists
|
||||
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
|
||||
file_path = f"{temp_dir}/{tempfile.gettempdir()}{suffix}"
|
||||
Path(file_path).write_bytes(response.content)
|
||||
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
|
||||
if return_text:
|
||||
|
||||
@@ -204,61 +204,26 @@ class WordExtractor(BaseExtractor):
|
||||
return " ".join(unique_content)
|
||||
|
||||
def _parse_cell_paragraph(self, paragraph, image_map):
|
||||
paragraph_content: list[str] = []
|
||||
|
||||
for child in paragraph._element:
|
||||
tag = child.tag
|
||||
if tag == qn("w:hyperlink"):
|
||||
# Note: w:hyperlink elements may also use w:anchor for internal bookmarks.
|
||||
# This extractor intentionally only converts external links (HTTP/mailto, etc.)
|
||||
# that are backed by a relationship id (r:id) with rel.is_external == True.
|
||||
# Hyperlinks without such an external rel (including anchor-only bookmarks)
|
||||
# are left as plain text link_text.
|
||||
r_id = child.get(qn("r:id"))
|
||||
link_text_parts: list[str] = []
|
||||
for run_elem in child.findall(qn("w:r")):
|
||||
run = Run(run_elem, paragraph)
|
||||
if run.text:
|
||||
link_text_parts.append(run.text)
|
||||
link_text = "".join(link_text_parts).strip()
|
||||
if r_id:
|
||||
try:
|
||||
rel = paragraph.part.rels.get(r_id)
|
||||
if rel:
|
||||
target_ref = getattr(rel, "target_ref", None)
|
||||
if target_ref:
|
||||
parsed_target = urlparse(str(target_ref))
|
||||
if rel.is_external or parsed_target.scheme in ("http", "https", "mailto"):
|
||||
display_text = link_text or str(target_ref)
|
||||
link_text = f"[{display_text}]({target_ref})"
|
||||
except Exception:
|
||||
logger.exception("Failed to resolve URL for hyperlink with r:id: %s", r_id)
|
||||
if link_text:
|
||||
paragraph_content.append(link_text)
|
||||
|
||||
elif tag == qn("w:r"):
|
||||
run = Run(child, paragraph)
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get(
|
||||
"{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed"
|
||||
)
|
||||
if not image_id:
|
||||
continue
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
if run.text:
|
||||
paragraph_content.append(run.text)
|
||||
|
||||
paragraph_content = []
|
||||
for run in paragraph.runs:
|
||||
if run.element.xpath(".//a:blip"):
|
||||
for blip in run.element.xpath(".//a:blip"):
|
||||
image_id = blip.get("{http://schemas.openxmlformats.org/officeDocument/2006/relationships}embed")
|
||||
if not image_id:
|
||||
continue
|
||||
rel = paragraph.part.rels.get(image_id)
|
||||
if rel is None:
|
||||
continue
|
||||
# For external images, use image_id as key; for internal, use target_part
|
||||
if rel.is_external:
|
||||
if image_id in image_map:
|
||||
paragraph_content.append(image_map[image_id])
|
||||
else:
|
||||
image_part = rel.target_part
|
||||
if image_part in image_map:
|
||||
paragraph_content.append(image_map[image_part])
|
||||
else:
|
||||
paragraph_content.append(run.text)
|
||||
return "".join(paragraph_content).strip()
|
||||
|
||||
def parse_docx(self, docx_path):
|
||||
|
||||
@@ -393,78 +393,3 @@ class BillingService:
|
||||
for item in data:
|
||||
tenant_whitelist.append(item["tenant_id"])
|
||||
return tenant_whitelist
|
||||
|
||||
@classmethod
|
||||
def get_account_notification(cls, account_id: str) -> dict:
|
||||
"""Return the active in-product notification for account_id, if any.
|
||||
|
||||
Calling this endpoint also marks the notification as seen; subsequent
|
||||
calls will return should_show=false when frequency='once'.
|
||||
|
||||
Response shape (mirrors GetAccountNotificationReply):
|
||||
{
|
||||
"should_show": bool,
|
||||
"notification": { # present only when should_show=true
|
||||
"notification_id": str,
|
||||
"contents": { # lang -> LangContent
|
||||
"en": {"lang": "en", "title": ..., "subtitle": ..., "body": ..., "title_pic_url": ...},
|
||||
...
|
||||
},
|
||||
"frequency": "once" | "every_page_load"
|
||||
}
|
||||
}
|
||||
"""
|
||||
return cls._send_request("GET", "/notifications/active", params={"account_id": account_id})
|
||||
|
||||
@classmethod
|
||||
def upsert_notification(
|
||||
cls,
|
||||
contents: list[dict],
|
||||
frequency: str = "once",
|
||||
status: str = "active",
|
||||
notification_id: str | None = None,
|
||||
start_time: str | None = None,
|
||||
end_time: str | None = None,
|
||||
) -> dict:
|
||||
"""Create or update a notification.
|
||||
|
||||
contents: list of {"lang": str, "title": str, "subtitle": str, "body": str, "title_pic_url": str}
|
||||
start_time / end_time: RFC3339 strings (e.g. "2026-03-01T00:00:00Z"), optional.
|
||||
Returns {"notification_id": str}.
|
||||
"""
|
||||
payload: dict = {
|
||||
"contents": contents,
|
||||
"frequency": frequency,
|
||||
"status": status,
|
||||
}
|
||||
if notification_id:
|
||||
payload["notification_id"] = notification_id
|
||||
if start_time:
|
||||
payload["start_time"] = start_time
|
||||
if end_time:
|
||||
payload["end_time"] = end_time
|
||||
return cls._send_request("POST", "/notifications", json=payload)
|
||||
|
||||
@classmethod
|
||||
def batch_add_notification_accounts(cls, notification_id: str, account_ids: list[str]) -> dict:
|
||||
"""Register target account IDs for a notification (max 1000 per call).
|
||||
|
||||
Returns {"count": int}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/accounts",
|
||||
json={"account_ids": account_ids},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def dismiss_notification(cls, notification_id: str, account_id: str) -> dict:
|
||||
"""Mark a notification as dismissed for an account.
|
||||
|
||||
Returns {"success": bool}.
|
||||
"""
|
||||
return cls._send_request(
|
||||
"POST",
|
||||
f"/notifications/{notification_id}/dismiss",
|
||||
json={"account_id": account_id},
|
||||
)
|
||||
|
||||
@@ -245,6 +245,5 @@ class EmailDeliveryTestHandler:
|
||||
)
|
||||
if token:
|
||||
substitutions["form_token"] = token
|
||||
link = _build_form_link(token)
|
||||
substitutions["form_link"] = link if link is not None else f"/form/{token}"
|
||||
substitutions["form_link"] = _build_form_link(token) or ""
|
||||
return substitutions
|
||||
|
||||
@@ -33,8 +33,6 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ToolTransformService:
|
||||
_MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH = 10
|
||||
|
||||
@classmethod
|
||||
def get_tool_provider_icon_url(
|
||||
cls, provider_type: str, provider_name: str, icon: str | Mapping[str, str]
|
||||
@@ -437,46 +435,6 @@ class ToolTransformService:
|
||||
:return: list of ToolParameter instances
|
||||
"""
|
||||
|
||||
def resolve_property_type(prop: dict[str, Any], depth: int = 0) -> str:
|
||||
"""
|
||||
Resolve a JSON schema property type while guarding against cyclic or deeply nested unions.
|
||||
"""
|
||||
if depth >= ToolTransformService._MCP_SCHEMA_TYPE_RESOLUTION_MAX_DEPTH:
|
||||
return "string"
|
||||
prop_type = prop.get("type")
|
||||
if isinstance(prop_type, list):
|
||||
non_null_types = [type_name for type_name in prop_type if type_name != "null"]
|
||||
if non_null_types:
|
||||
return non_null_types[0]
|
||||
if prop_type:
|
||||
return "string"
|
||||
elif isinstance(prop_type, str):
|
||||
if prop_type == "null":
|
||||
return "string"
|
||||
return prop_type
|
||||
|
||||
for union_key in ("anyOf", "oneOf"):
|
||||
union_schemas = prop.get(union_key)
|
||||
if not isinstance(union_schemas, list):
|
||||
continue
|
||||
|
||||
for union_schema in union_schemas:
|
||||
if not isinstance(union_schema, dict):
|
||||
continue
|
||||
union_type = resolve_property_type(union_schema, depth + 1)
|
||||
if union_type != "null":
|
||||
return union_type
|
||||
|
||||
all_of_schemas = prop.get("allOf")
|
||||
if isinstance(all_of_schemas, list):
|
||||
for all_of_schema in all_of_schemas:
|
||||
if not isinstance(all_of_schema, dict):
|
||||
continue
|
||||
all_of_type = resolve_property_type(all_of_schema, depth + 1)
|
||||
if all_of_type != "null":
|
||||
return all_of_type
|
||||
return "string"
|
||||
|
||||
def create_parameter(
|
||||
name: str, description: str, param_type: str, required: bool, input_schema: dict[str, Any] | None = None
|
||||
) -> ToolParameter:
|
||||
@@ -503,7 +461,10 @@ class ToolTransformService:
|
||||
parameters = []
|
||||
for name, prop in props.items():
|
||||
current_description = prop.get("description", "")
|
||||
prop_type = resolve_property_type(prop)
|
||||
prop_type = prop.get("type", "string")
|
||||
|
||||
if isinstance(prop_type, list):
|
||||
prop_type = prop_type[0]
|
||||
if prop_type in TYPE_MAPPING:
|
||||
prop_type = TYPE_MAPPING[prop_type]
|
||||
input_schema = prop if prop_type in COMPLEX_TYPES else None
|
||||
|
||||
@@ -125,11 +125,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
@@ -269,11 +265,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
@@ -420,11 +412,7 @@ class TestAdvancedChatAppRunnerConversationVariables:
|
||||
patch("core.app.apps.advanced_chat.app_runner.select") as mock_select,
|
||||
patch("core.app.apps.advanced_chat.app_runner.db") as mock_db,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(False, mock_app_generate_entity.inputs, mock_app_generate_entity.query),
|
||||
),
|
||||
patch.object(runner, "handle_input_moderation", return_value=False),
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False),
|
||||
patch("core.app.apps.advanced_chat.app_runner.WorkflowEntry") as mock_workflow_entry_class,
|
||||
patch("core.app.apps.advanced_chat.app_runner.GraphRuntimeState") as mock_graph_runtime_state_class,
|
||||
|
||||
@@ -1,170 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.apps.advanced_chat.app_runner import AdvancedChatAppRunner
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.queue_entities import QueueStopEvent
|
||||
from core.moderation.base import ModerationError
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def build_runner():
|
||||
"""Construct a minimal AdvancedChatAppRunner with heavy dependencies mocked."""
|
||||
app_id = str(uuid4())
|
||||
workflow_id = str(uuid4())
|
||||
|
||||
# Mocks for constructor args
|
||||
mock_queue_manager = MagicMock()
|
||||
|
||||
mock_conversation = MagicMock()
|
||||
mock_conversation.id = str(uuid4())
|
||||
mock_conversation.app_id = app_id
|
||||
|
||||
mock_message = MagicMock()
|
||||
mock_message.id = str(uuid4())
|
||||
|
||||
mock_workflow = MagicMock()
|
||||
mock_workflow.id = workflow_id
|
||||
mock_workflow.tenant_id = str(uuid4())
|
||||
mock_workflow.app_id = app_id
|
||||
mock_workflow.type = "chat"
|
||||
mock_workflow.graph_dict = {}
|
||||
mock_workflow.environment_variables = []
|
||||
|
||||
mock_app_config = MagicMock()
|
||||
mock_app_config.app_id = app_id
|
||||
mock_app_config.workflow_id = workflow_id
|
||||
mock_app_config.tenant_id = str(uuid4())
|
||||
|
||||
gen = MagicMock(spec=AdvancedChatAppGenerateEntity)
|
||||
gen.app_config = mock_app_config
|
||||
gen.inputs = {"q": "raw"}
|
||||
gen.query = "raw-query"
|
||||
gen.files = []
|
||||
gen.user_id = str(uuid4())
|
||||
gen.invoke_from = InvokeFrom.SERVICE_API
|
||||
gen.workflow_run_id = str(uuid4())
|
||||
gen.task_id = str(uuid4())
|
||||
gen.call_depth = 0
|
||||
gen.single_iteration_run = None
|
||||
gen.single_loop_run = None
|
||||
gen.trace_manager = None
|
||||
|
||||
runner = AdvancedChatAppRunner(
|
||||
application_generate_entity=gen,
|
||||
queue_manager=mock_queue_manager,
|
||||
conversation=mock_conversation,
|
||||
message=mock_message,
|
||||
dialogue_count=1,
|
||||
variable_loader=MagicMock(),
|
||||
workflow=mock_workflow,
|
||||
system_user_id=str(uuid4()),
|
||||
app=MagicMock(),
|
||||
workflow_execution_repository=MagicMock(),
|
||||
workflow_node_execution_repository=MagicMock(),
|
||||
)
|
||||
|
||||
return runner
|
||||
|
||||
|
||||
def _patch_common_run_deps(runner: AdvancedChatAppRunner):
|
||||
"""Context manager that patches common heavy deps used by run()."""
|
||||
return patch.multiple(
|
||||
"core.app.apps.advanced_chat.app_runner",
|
||||
Session=MagicMock(
|
||||
return_value=MagicMock(
|
||||
__enter__=lambda s: s,
|
||||
__exit__=lambda *a, **k: False,
|
||||
scalar=lambda *a, **k: MagicMock(),
|
||||
),
|
||||
),
|
||||
select=MagicMock(),
|
||||
db=MagicMock(engine=MagicMock()),
|
||||
RedisChannel=MagicMock(),
|
||||
redis_client=MagicMock(),
|
||||
WorkflowEntry=MagicMock(**{"return_value.run.return_value": iter([])}),
|
||||
GraphRuntimeState=MagicMock(),
|
||||
)
|
||||
|
||||
|
||||
def test_handle_input_moderation_stops_on_moderation_error(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
# moderation_for_inputs raises ModerationError -> should stop and emit stop event
|
||||
with (
|
||||
patch.object(runner, "moderation_for_inputs", side_effect=ModerationError("blocked")),
|
||||
patch.object(runner, "_complete_with_stream_output") as mock_complete,
|
||||
):
|
||||
stop, new_inputs, new_query = runner.handle_input_moderation(
|
||||
app_record=MagicMock(),
|
||||
app_generate_entity=runner.application_generate_entity,
|
||||
inputs={"k": "v"},
|
||||
query="hello",
|
||||
message_id="mid",
|
||||
)
|
||||
|
||||
assert stop is True
|
||||
# inputs/query should be unchanged on error path
|
||||
assert new_inputs == {"k": "v"}
|
||||
assert new_query == "hello"
|
||||
# ensure stopped_by reason is INPUT_MODERATION
|
||||
assert mock_complete.called
|
||||
args, kwargs = mock_complete.call_args
|
||||
assert kwargs.get("stopped_by") == QueueStopEvent.StopBy.INPUT_MODERATION
|
||||
|
||||
|
||||
def test_run_applies_overridden_inputs_and_query_from_moderation(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
overridden_inputs = {"q": "sanitized"}
|
||||
overridden_query = "sanitized-query"
|
||||
|
||||
with (
|
||||
_patch_common_run_deps(runner),
|
||||
patch.object(
|
||||
runner,
|
||||
"moderation_for_inputs",
|
||||
return_value=(True, overridden_inputs, overridden_query),
|
||||
) as mock_moderate,
|
||||
patch.object(runner, "handle_annotation_reply", return_value=False) as mock_anno,
|
||||
patch.object(runner, "_init_graph", return_value=MagicMock()) as mock_init_graph,
|
||||
):
|
||||
runner.run()
|
||||
|
||||
# moderation called with original values
|
||||
mock_moderate.assert_called_once()
|
||||
|
||||
# application_generate_entity should be updated to overridden values
|
||||
assert runner.application_generate_entity.inputs == overridden_inputs
|
||||
assert runner.application_generate_entity.query == overridden_query
|
||||
|
||||
# annotation reply should use the new query
|
||||
mock_anno.assert_called()
|
||||
assert mock_anno.call_args.kwargs.get("query") == overridden_query
|
||||
|
||||
# since not stopped, graph initialization should proceed
|
||||
assert mock_init_graph.called
|
||||
|
||||
|
||||
def test_run_returns_early_when_direct_output_via_handle_input_moderation(build_runner):
|
||||
runner = build_runner
|
||||
|
||||
with (
|
||||
_patch_common_run_deps(runner),
|
||||
# Simulate handle_input_moderation signalling to stop
|
||||
patch.object(
|
||||
runner,
|
||||
"handle_input_moderation",
|
||||
return_value=(True, runner.application_generate_entity.inputs, runner.application_generate_entity.query),
|
||||
) as mock_handle,
|
||||
patch.object(runner, "_init_graph") as mock_init_graph,
|
||||
patch.object(runner, "handle_annotation_reply") as mock_anno,
|
||||
):
|
||||
runner.run()
|
||||
|
||||
mock_handle.assert_called_once()
|
||||
# Ensure no further steps executed
|
||||
mock_anno.assert_not_called()
|
||||
mock_init_graph.assert_not_called()
|
||||
@@ -1,90 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from configs import dify_config
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceEntity, DatasourceProviderType
|
||||
|
||||
|
||||
class ConcreteDatasourcePlugin(DatasourcePlugin):
|
||||
"""
|
||||
Concrete implementation of DatasourcePlugin for testing purposes.
|
||||
Since DatasourcePlugin is an ABC, we need a concrete class to instantiate it.
|
||||
"""
|
||||
|
||||
def datasource_provider_type(self) -> str:
|
||||
return DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
|
||||
class TestDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
|
||||
# Act
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Act
|
||||
provider_type = plugin.datasource_provider_type()
|
||||
# Call the base class method to ensure it's covered
|
||||
base_provider_type = DatasourcePlugin.datasource_provider_type(plugin)
|
||||
|
||||
# Assert
|
||||
assert provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
assert base_provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_fork_datasource_runtime(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_entity_copy = MagicMock(spec=DatasourceEntity)
|
||||
mock_entity.model_copy.return_value = mock_entity_copy
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
new_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
|
||||
plugin = ConcreteDatasourcePlugin(entity=mock_entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Act
|
||||
new_plugin = plugin.fork_datasource_runtime(new_runtime)
|
||||
|
||||
# Assert
|
||||
assert isinstance(new_plugin, ConcreteDatasourcePlugin)
|
||||
assert new_plugin.entity == mock_entity_copy
|
||||
assert new_plugin.runtime == new_runtime
|
||||
assert new_plugin.icon == icon
|
||||
mock_entity.model_copy.assert_called_once()
|
||||
|
||||
def test_get_icon_url(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon.png"
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
plugin = ConcreteDatasourcePlugin(entity=entity, runtime=runtime, icon=icon)
|
||||
|
||||
# Mocking dify_config.CONSOLE_API_URL
|
||||
with patch.object(dify_config, "CONSOLE_API_URL", "https://api.dify.ai"):
|
||||
# Act
|
||||
icon_url = plugin.get_icon_url(tenant_id)
|
||||
|
||||
# Assert
|
||||
expected_url = (
|
||||
f"https://api.dify.ai/console/api/workspaces/current/plugin/icon?tenant_id={tenant_id}&filename={icon}"
|
||||
)
|
||||
assert icon_url == expected_url
|
||||
@@ -1,265 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_plugin import DatasourcePlugin
|
||||
from core.datasource.__base.datasource_provider import DatasourcePluginProviderController
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.entities.provider_entities import ProviderConfig
|
||||
from core.tools.errors import ToolProviderCredentialValidationError
|
||||
|
||||
|
||||
class ConcreteDatasourcePluginProviderController(DatasourcePluginProviderController):
|
||||
"""
|
||||
Concrete implementation of DatasourcePluginProviderController for testing purposes.
|
||||
"""
|
||||
|
||||
def get_datasource(self, datasource_name: str) -> DatasourcePlugin:
|
||||
return MagicMock(spec=DatasourcePlugin)
|
||||
|
||||
|
||||
class TestDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Act
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_need_credentials(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
tenant_id = "test-tenant-id"
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Case 1: credentials_schema is None
|
||||
mock_entity.credentials_schema = None
|
||||
assert controller.need_credentials is False
|
||||
|
||||
# Case 2: credentials_schema is empty
|
||||
mock_entity.credentials_schema = []
|
||||
assert controller.need_credentials is False
|
||||
|
||||
# Case 3: credentials_schema has items
|
||||
mock_entity.credentials_schema = [MagicMock()]
|
||||
assert controller.need_credentials is True
|
||||
|
||||
@patch("core.datasource.__base.datasource_provider.PluginToolManager")
|
||||
def test_validate_credentials(self, mock_manager_class):
|
||||
# Arrange
|
||||
mock_manager = mock_manager_class.return_value
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.name = "test-provider"
|
||||
tenant_id = "test-tenant-id"
|
||||
user_id = "test-user-id"
|
||||
credentials = {"api_key": "secret"}
|
||||
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id=tenant_id)
|
||||
|
||||
# Act: Successful validation
|
||||
mock_manager.validate_datasource_credentials.return_value = True
|
||||
controller._validate_credentials(user_id, credentials)
|
||||
|
||||
mock_manager.validate_datasource_credentials.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
provider="test-provider",
|
||||
credentials=credentials,
|
||||
)
|
||||
|
||||
# Act: Failed validation
|
||||
mock_manager.validate_datasource_credentials.return_value = False
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="Invalid credentials"):
|
||||
controller._validate_credentials(user_id, credentials)
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_validate_credentials_format_empty_schema(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = []
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
credentials = {}
|
||||
|
||||
# Act & Assert (Should not raise anything)
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
def test_validate_credentials_format_unknown_credential(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.name = "test-provider"
|
||||
mock_entity.credentials_schema = []
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
credentials = {"unknown": "value"}
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(
|
||||
ToolProviderCredentialValidationError, match="credential unknown not found in provider test-provider"
|
||||
):
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
def test_validate_credentials_format_required_missing(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "api_key"
|
||||
mock_config.required = True
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential api_key is required"):
|
||||
controller.validate_credentials_format({})
|
||||
|
||||
def test_validate_credentials_format_not_required_null(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "optional"
|
||||
mock_config.required = False
|
||||
mock_config.default = None
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
credentials = {"optional": None}
|
||||
controller.validate_credentials_format(credentials)
|
||||
assert credentials["optional"] is None
|
||||
|
||||
def test_validate_credentials_format_type_mismatch_text(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "text_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential text_field should be string"):
|
||||
controller.validate_credentials_format({"text_field": 123})
|
||||
|
||||
def test_validate_credentials_format_select_validation(self):
|
||||
# Arrange
|
||||
mock_option = MagicMock()
|
||||
mock_option.value = "opt1"
|
||||
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "select_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.SELECT
|
||||
mock_config.options = [mock_option]
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Case 1: Value not string
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be string"):
|
||||
controller.validate_credentials_format({"select_field": 123})
|
||||
|
||||
# Case 2: Options not list
|
||||
mock_config.options = "invalid"
|
||||
with pytest.raises(
|
||||
ToolProviderCredentialValidationError, match="credential select_field options should be list"
|
||||
):
|
||||
controller.validate_credentials_format({"select_field": "opt1"})
|
||||
|
||||
# Case 3: Value not in options
|
||||
mock_config.options = [mock_option]
|
||||
with pytest.raises(ToolProviderCredentialValidationError, match="credential select_field should be one of"):
|
||||
controller.validate_credentials_format({"select_field": "invalid_opt"})
|
||||
|
||||
def test_get_datasource_base(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
result = DatasourcePluginProviderController.get_datasource(controller, "test")
|
||||
|
||||
# Assert
|
||||
assert result is None
|
||||
|
||||
def test_validate_credentials_format_hits_pop(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "valid_field"
|
||||
mock_config.required = True
|
||||
mock_config.type = ProviderConfig.Type.TEXT_INPUT
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {"valid_field": "valid_value"}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert "valid_field" in credentials
|
||||
assert credentials["valid_field"] == "valid_value"
|
||||
|
||||
def test_validate_credentials_format_hits_continue(self):
|
||||
# Arrange
|
||||
mock_config = MagicMock(spec=ProviderConfig)
|
||||
mock_config.name = "optional_field"
|
||||
mock_config.required = False
|
||||
mock_config.default = None
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {"optional_field": None}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert credentials["optional_field"] is None
|
||||
|
||||
def test_validate_credentials_format_default_values(self):
|
||||
# Arrange
|
||||
mock_config_text = MagicMock(spec=ProviderConfig)
|
||||
mock_config_text.name = "text_def"
|
||||
mock_config_text.required = False
|
||||
mock_config_text.type = ProviderConfig.Type.TEXT_INPUT
|
||||
mock_config_text.default = 123 # Int default, should be converted to str
|
||||
|
||||
mock_config_other = MagicMock(spec=ProviderConfig)
|
||||
mock_config_other.name = "other_def"
|
||||
mock_config_other.required = False
|
||||
mock_config_other.type = "OTHER"
|
||||
mock_config_other.default = "fallback"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.credentials_schema = [mock_config_text, mock_config_other]
|
||||
controller = ConcreteDatasourcePluginProviderController(entity=mock_entity, tenant_id="test")
|
||||
|
||||
# Act
|
||||
credentials = {}
|
||||
controller.validate_credentials_format(credentials)
|
||||
|
||||
# Assert
|
||||
assert credentials["text_def"] == "123"
|
||||
assert credentials["other_def"] == "fallback"
|
||||
@@ -1,26 +0,0 @@
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime, FakeDatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeFrom
|
||||
|
||||
|
||||
class TestDatasourceRuntime:
|
||||
def test_init(self):
|
||||
runtime = DatasourceRuntime(
|
||||
tenant_id="test-tenant",
|
||||
datasource_id="test-ds",
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
datasource_invoke_from=DatasourceInvokeFrom.RAG_PIPELINE,
|
||||
credentials={"key": "val"},
|
||||
runtime_parameters={"p": "v"},
|
||||
)
|
||||
assert runtime.tenant_id == "test-tenant"
|
||||
assert runtime.datasource_id == "test-ds"
|
||||
assert runtime.credentials["key"] == "val"
|
||||
|
||||
def test_fake_datasource_runtime(self):
|
||||
# This covers the FakeDatasourceRuntime class and its __init__
|
||||
runtime = FakeDatasourceRuntime()
|
||||
assert runtime.tenant_id == "fake_tenant_id"
|
||||
assert runtime.datasource_id == "fake_datasource_id"
|
||||
assert runtime.invoke_from == InvokeFrom.DEBUGGER
|
||||
assert runtime.datasource_invoke_from == DatasourceInvokeFrom.RAG_PIPELINE
|
||||
@@ -1,150 +0,0 @@
|
||||
from core.datasource.entities.api_entities import DatasourceApiEntity, DatasourceProviderApiEntity
|
||||
from core.datasource.entities.datasource_entities import DatasourceParameter
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
def test_datasource_api_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
|
||||
entity = DatasourceApiEntity(
|
||||
author="author", name="name", label=label, description=description, labels=["l1", "l2"]
|
||||
)
|
||||
|
||||
assert entity.author == "author"
|
||||
assert entity.name == "name"
|
||||
assert entity.label == label
|
||||
assert entity.description == description
|
||||
assert entity.labels == ["l1", "l2"]
|
||||
assert entity.parameters is None
|
||||
assert entity.output_schema is None
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_defaults():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
entity = DatasourceProviderApiEntity(
|
||||
id="id", author="author", name="name", description=description, icon="icon", label=label, type="type"
|
||||
)
|
||||
|
||||
assert entity.id == "id"
|
||||
assert entity.datasources == []
|
||||
assert entity.is_team_authorization is False
|
||||
assert entity.allow_delete is True
|
||||
assert entity.plugin_id == ""
|
||||
assert entity.plugin_unique_identifier == ""
|
||||
assert entity.labels == []
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_convert_none_to_empty_list():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
# Implicitly testing the field_validator "convert_none_to_empty_list"
|
||||
entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=None, # type: ignore
|
||||
)
|
||||
|
||||
assert entity.datasources == []
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
# Create a parameter that should be converted
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param", typ=DatasourceParameter.DatasourceParameterType.SYSTEM_FILES, required=True
|
||||
)
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=[param]
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
masked_credentials={"key": "masked"},
|
||||
datasources=[ds_entity],
|
||||
labels=["l1"],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
|
||||
assert result["id"] == "id"
|
||||
assert result["author"] == "author"
|
||||
assert result["name"] == "name"
|
||||
assert result["description"] == description.to_dict()
|
||||
assert result["icon"] == "icon"
|
||||
assert result["label"] == label.to_dict()
|
||||
assert result["type"] == "type"
|
||||
assert result["team_credentials"] == {"key": "masked"}
|
||||
assert result["is_team_authorization"] is False
|
||||
assert result["allow_delete"] is True
|
||||
assert result["labels"] == ["l1"]
|
||||
|
||||
# Check if parameter type was converted from SYSTEM_FILES to files
|
||||
assert result["datasources"][0]["parameters"][0]["type"] == "files"
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict_no_params():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=None
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=[ds_entity],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
assert result["datasources"][0]["parameters"] is None
|
||||
|
||||
|
||||
def test_datasource_provider_api_entity_to_dict_other_param_type():
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param", typ=DatasourceParameter.DatasourceParameterType.STRING, required=True
|
||||
)
|
||||
|
||||
ds_entity = DatasourceApiEntity(
|
||||
author="author", name="ds_name", label=label, description=description, parameters=[param]
|
||||
)
|
||||
|
||||
provider_entity = DatasourceProviderApiEntity(
|
||||
id="id",
|
||||
author="author",
|
||||
name="name",
|
||||
description=description,
|
||||
icon="icon",
|
||||
label=label,
|
||||
type="type",
|
||||
datasources=[ds_entity],
|
||||
)
|
||||
|
||||
result = provider_entity.to_dict()
|
||||
assert result["datasources"][0]["parameters"][0]["type"] == "string"
|
||||
@@ -1,31 +0,0 @@
|
||||
from core.datasource.entities.common_entities import I18nObject
|
||||
|
||||
|
||||
def test_i18n_object_fallback():
|
||||
# Only en_US provided
|
||||
obj = I18nObject(en_US="Hello")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "Hello"
|
||||
assert obj.pt_BR == "Hello"
|
||||
assert obj.ja_JP == "Hello"
|
||||
|
||||
# Some fields provided
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "你好"
|
||||
assert obj.pt_BR == "Hello"
|
||||
assert obj.ja_JP == "Hello"
|
||||
|
||||
|
||||
def test_i18n_object_all_fields():
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
|
||||
assert obj.en_US == "Hello"
|
||||
assert obj.zh_Hans == "你好"
|
||||
assert obj.pt_BR == "Olá"
|
||||
assert obj.ja_JP == "こんにちは"
|
||||
|
||||
|
||||
def test_i18n_object_to_dict():
|
||||
obj = I18nObject(en_US="Hello", zh_Hans="你好", pt_BR="Olá", ja_JP="こんにちは")
|
||||
expected_dict = {"en_US": "Hello", "zh_Hans": "你好", "pt_BR": "Olá", "ja_JP": "こんにちは"}
|
||||
assert obj.to_dict() == expected_dict
|
||||
@@ -1,275 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceInvokeMeta,
|
||||
DatasourceLabel,
|
||||
DatasourceMessage,
|
||||
DatasourceParameter,
|
||||
DatasourceProviderEntity,
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderIdentity,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
GetOnlineDocumentPageContentResponse,
|
||||
GetWebsiteCrawlRequest,
|
||||
OnlineDocumentInfo,
|
||||
OnlineDocumentPage,
|
||||
OnlineDocumentPageContent,
|
||||
OnlineDocumentPagesMessage,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveBrowseFilesResponse,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
OnlineDriveFile,
|
||||
OnlineDriveFileBucket,
|
||||
WebsiteCrawlMessage,
|
||||
WebSiteInfo,
|
||||
WebSiteInfoDetail,
|
||||
)
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolLabelEnum
|
||||
|
||||
|
||||
def test_datasource_provider_type():
|
||||
assert DatasourceProviderType.value_of("online_document") == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
assert DatasourceProviderType.value_of("local_file") == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
with pytest.raises(ValueError, match="invalid mode value invalid"):
|
||||
DatasourceProviderType.value_of("invalid")
|
||||
|
||||
|
||||
def test_datasource_parameter_type():
|
||||
param_type = DatasourceParameter.DatasourceParameterType.STRING
|
||||
assert param_type.as_normal_type() == "string"
|
||||
assert param_type.cast_value("test") == "test"
|
||||
|
||||
param_type = DatasourceParameter.DatasourceParameterType.NUMBER
|
||||
assert param_type.cast_value("123") == 123
|
||||
|
||||
|
||||
def test_datasource_parameter():
|
||||
param = DatasourceParameter.get_simple_instance(
|
||||
name="test_param",
|
||||
typ=DatasourceParameter.DatasourceParameterType.STRING,
|
||||
required=True,
|
||||
options=["opt1", "opt2"],
|
||||
)
|
||||
assert param.name == "test_param"
|
||||
assert param.type == DatasourceParameter.DatasourceParameterType.STRING
|
||||
assert param.required is True
|
||||
assert len(param.options) == 2
|
||||
assert param.options[0].value == "opt1"
|
||||
|
||||
param_no_options = DatasourceParameter.get_simple_instance(
|
||||
name="test_param_2", typ=DatasourceParameter.DatasourceParameterType.NUMBER, required=False
|
||||
)
|
||||
assert param_no_options.options == []
|
||||
|
||||
# Test init_frontend_parameter
|
||||
# For STRING, it should just return the value as is (or cast to str)
|
||||
frontend_param = param.init_frontend_parameter("val")
|
||||
assert frontend_param == "val"
|
||||
|
||||
# Test parameter type methods
|
||||
assert DatasourceParameter.DatasourceParameterType.STRING.as_normal_type() == "string"
|
||||
assert DatasourceParameter.DatasourceParameterType.NUMBER.as_normal_type() == "number"
|
||||
assert DatasourceParameter.DatasourceParameterType.SECRET_INPUT.as_normal_type() == "string"
|
||||
|
||||
assert DatasourceParameter.DatasourceParameterType.NUMBER.cast_value("10.5") == 10.5
|
||||
assert DatasourceParameter.DatasourceParameterType.BOOLEAN.cast_value("true") is True
|
||||
assert DatasourceParameter.DatasourceParameterType.FILES.cast_value(["f1", "f2"]) == ["f1", "f2"]
|
||||
|
||||
|
||||
def test_datasource_identity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider", icon="icon")
|
||||
assert identity.author == "author"
|
||||
assert identity.name == "name"
|
||||
assert identity.label == label
|
||||
assert identity.provider == "provider"
|
||||
assert identity.icon == "icon"
|
||||
|
||||
|
||||
def test_datasource_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
identity = DatasourceIdentity(author="author", name="name", label=label, provider="provider")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
|
||||
entity = DatasourceEntity(
|
||||
identity=identity,
|
||||
description=description,
|
||||
parameters=None, # Should be handled by validator
|
||||
)
|
||||
assert entity.parameters == []
|
||||
|
||||
param = DatasourceParameter.get_simple_instance("p1", DatasourceParameter.DatasourceParameterType.STRING, True)
|
||||
entity_with_params = DatasourceEntity(identity=identity, description=description, parameters=[param])
|
||||
assert entity_with_params.parameters == [param]
|
||||
|
||||
|
||||
def test_datasource_provider_identity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon.png", label=label, tags=[ToolLabelEnum.SEARCH]
|
||||
)
|
||||
|
||||
assert identity.author == "author"
|
||||
assert identity.name == "name"
|
||||
assert identity.description == description
|
||||
assert identity.icon == "icon.png"
|
||||
assert identity.label == label
|
||||
assert identity.tags == [ToolLabelEnum.SEARCH]
|
||||
|
||||
# Test generate_datasource_icon_url
|
||||
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
|
||||
mock_config.CONSOLE_API_URL = "http://api.example.com"
|
||||
url = identity.generate_datasource_icon_url("tenant123")
|
||||
assert "http://api.example.com/console/api/workspaces/current/plugin/icon" in url
|
||||
assert "tenant_id=tenant123" in url
|
||||
assert "filename=icon.png" in url
|
||||
|
||||
# Test hardcoded icon
|
||||
identity.icon = "https://assets.dify.ai/images/File%20Upload.svg"
|
||||
assert identity.generate_datasource_icon_url("tenant123") == identity.icon
|
||||
|
||||
# Test with empty CONSOLE_API_URL
|
||||
identity.icon = "test.png"
|
||||
with patch("core.datasource.entities.datasource_entities.dify_config") as mock_config:
|
||||
mock_config.CONSOLE_API_URL = None
|
||||
url = identity.generate_datasource_icon_url("tenant123")
|
||||
assert url.startswith("/console/api/workspaces/current/plugin/icon")
|
||||
|
||||
|
||||
def test_datasource_provider_entity():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon", label=label
|
||||
)
|
||||
|
||||
entity = DatasourceProviderEntity(
|
||||
identity=identity,
|
||||
provider_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
credentials_schema=[],
|
||||
oauth_schema=None,
|
||||
)
|
||||
assert entity.identity == identity
|
||||
assert entity.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
assert entity.credentials_schema == []
|
||||
|
||||
|
||||
def test_datasource_provider_entity_with_plugin():
|
||||
label = I18nObject(en_US="label", zh_Hans="标签")
|
||||
description = I18nObject(en_US="desc", zh_Hans="描述")
|
||||
identity = DatasourceProviderIdentity(
|
||||
author="author", name="name", description=description, icon="icon", label=label
|
||||
)
|
||||
|
||||
entity = DatasourceProviderEntityWithPlugin(
|
||||
identity=identity, provider_type=DatasourceProviderType.ONLINE_DOCUMENT, datasources=[]
|
||||
)
|
||||
assert entity.datasources == []
|
||||
|
||||
|
||||
def test_datasource_invoke_meta():
|
||||
meta = DatasourceInvokeMeta(time_cost=1.5, error="some error", tool_config={"k": "v"})
|
||||
assert meta.time_cost == 1.5
|
||||
assert meta.error == "some error"
|
||||
assert meta.tool_config == {"k": "v"}
|
||||
|
||||
d = meta.to_dict()
|
||||
assert d == {"time_cost": 1.5, "error": "some error", "tool_config": {"k": "v"}}
|
||||
|
||||
empty_meta = DatasourceInvokeMeta.empty()
|
||||
assert empty_meta.time_cost == 0.0
|
||||
assert empty_meta.error is None
|
||||
assert empty_meta.tool_config == {}
|
||||
|
||||
error_meta = DatasourceInvokeMeta.error_instance("fatal error")
|
||||
assert error_meta.time_cost == 0.0
|
||||
assert error_meta.error == "fatal error"
|
||||
assert error_meta.tool_config == {}
|
||||
|
||||
|
||||
def test_datasource_label():
|
||||
label_obj = I18nObject(en_US="label", zh_Hans="标签")
|
||||
ds_label = DatasourceLabel(name="name", label=label_obj, icon="icon")
|
||||
assert ds_label.name == "name"
|
||||
assert ds_label.label == label_obj
|
||||
assert ds_label.icon == "icon"
|
||||
|
||||
|
||||
def test_online_document_models():
|
||||
page = OnlineDocumentPage(
|
||||
page_id="p1",
|
||||
page_name="name",
|
||||
page_icon={"type": "emoji"},
|
||||
type="page",
|
||||
last_edited_time="2023-01-01",
|
||||
parent_id=None,
|
||||
)
|
||||
assert page.page_id == "p1"
|
||||
|
||||
info = OnlineDocumentInfo(workspace_id="w1", workspace_name="name", workspace_icon="icon", total=1, pages=[page])
|
||||
assert info.total == 1
|
||||
|
||||
msg = OnlineDocumentPagesMessage(result=[info])
|
||||
assert msg.result == [info]
|
||||
|
||||
req = GetOnlineDocumentPageContentRequest(workspace_id="w1", page_id="p1", type="page")
|
||||
assert req.workspace_id == "w1"
|
||||
|
||||
content = OnlineDocumentPageContent(workspace_id="w1", page_id="p1", content="hello")
|
||||
assert content.content == "hello"
|
||||
|
||||
resp = GetOnlineDocumentPageContentResponse(result=content)
|
||||
assert resp.result == content
|
||||
|
||||
|
||||
def test_website_crawl_models():
|
||||
req = GetWebsiteCrawlRequest(crawl_parameters={"url": "http://test.com"})
|
||||
assert req.crawl_parameters == {"url": "http://test.com"}
|
||||
|
||||
detail = WebSiteInfoDetail(source_url="http://test.com", content="content", title="title", description="desc")
|
||||
assert detail.title == "title"
|
||||
|
||||
info = WebSiteInfo(status="completed", web_info_list=[detail], total=1, completed=1)
|
||||
assert info.status == "completed"
|
||||
|
||||
msg = WebsiteCrawlMessage(result=info)
|
||||
assert msg.result == info
|
||||
|
||||
# Test default values
|
||||
msg_default = WebsiteCrawlMessage()
|
||||
assert msg_default.result.status == ""
|
||||
assert msg_default.result.web_info_list == []
|
||||
|
||||
|
||||
def test_online_drive_models():
|
||||
file = OnlineDriveFile(id="f1", name="file.txt", size=100, type="file")
|
||||
assert file.name == "file.txt"
|
||||
|
||||
bucket = OnlineDriveFileBucket(bucket="b1", files=[file], is_truncated=False, next_page_parameters=None)
|
||||
assert bucket.bucket == "b1"
|
||||
|
||||
req = OnlineDriveBrowseFilesRequest(bucket="b1", prefix="folder1", max_keys=10, next_page_parameters=None)
|
||||
assert req.prefix == "folder1"
|
||||
|
||||
resp = OnlineDriveBrowseFilesResponse(result=[bucket])
|
||||
assert resp.result == [bucket]
|
||||
|
||||
dl_req = OnlineDriveDownloadFileRequest(id="f1", bucket="b1")
|
||||
assert dl_req.id == "f1"
|
||||
|
||||
|
||||
def test_datasource_message():
|
||||
# Use proper dict for message to avoid Pydantic Union validation ambiguity/crashes
|
||||
msg = DatasourceMessage(type="text", message={"text": "hello"})
|
||||
assert msg.message.text == "hello"
|
||||
|
||||
msg_json = DatasourceMessage(type="json", message={"json_object": {"k": "v"}})
|
||||
assert msg_json.message.json_object == {"k": "v"}
|
||||
@@ -1,57 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
|
||||
|
||||
class TestLocalFileDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test-tenant-id"
|
||||
icon = "test-icon"
|
||||
plugin_unique_identifier = "test-plugin-id"
|
||||
|
||||
# Act
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert plugin.entity == mock_entity
|
||||
assert plugin.runtime == mock_runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.datasource_provider_type() == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_get_icon_url(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_runtime = MagicMock(spec=DatasourceRuntime)
|
||||
icon = "test-icon"
|
||||
plugin = LocalFileDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon=icon, plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.get_icon_url("any-tenant-id") == icon
|
||||
@@ -1,96 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.local_file.local_file_plugin import LocalFileDatasourcePlugin
|
||||
from core.datasource.local_file.local_file_provider import LocalFileDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestLocalFileDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.LOCAL_FILE
|
||||
|
||||
def test_validate_credentials(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
# Should not raise any exception
|
||||
controller._validate_credentials("user_id", {"key": "value"})
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "test_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasource = controller.get_datasource("test_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(datasource, LocalFileDatasourcePlugin)
|
||||
assert datasource.entity == mock_datasource_entity
|
||||
assert datasource.tenant_id == tenant_id
|
||||
assert datasource.icon == "test_icon"
|
||||
assert datasource.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = LocalFileDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
|
||||
controller.get_datasource("test_datasource")
|
||||
@@ -1,151 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceProviderType,
|
||||
GetOnlineDocumentPageContentRequest,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
|
||||
|
||||
class TestOnlineDocumentDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
# Act
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.icon == icon
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_get_online_document_pages(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test_key"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
datasource_parameters = {"param": "value"}
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
# Patch PluginDatasourceManager to isolate plugin behavior from external dependencies
|
||||
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.get_online_document_pages.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.get_online_document_pages(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.get_online_document_pages.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_get_online_document_page_content(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test_key"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
datasource_parameters = MagicMock(spec=GetOnlineDocumentPageContentRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_document.online_document_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.get_online_document_page_content.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.get_online_document_page_content(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.get_online_document_page_content.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = OnlineDocumentDatasourcePlugin(
|
||||
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = plugin.datasource_provider_type()
|
||||
|
||||
# Assert
|
||||
assert result == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
@@ -1,100 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.online_document.online_document_plugin import OnlineDocumentDatasourcePlugin
|
||||
from core.datasource.online_document.online_document_provider import OnlineDocumentDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestOnlineDocumentDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_uid"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.provider_type == DatasourceProviderType.ONLINE_DOCUMENT
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
from core.datasource.entities.datasource_entities import DatasourceIdentity
|
||||
|
||||
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
|
||||
mock_datasource_entity.identity.name = "target_datasource"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity = MagicMock()
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_uid"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id="test_plugin_id",
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Act
|
||||
result = controller.get_datasource("target_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, OnlineDocumentDatasourcePlugin)
|
||||
assert result.entity == mock_datasource_entity
|
||||
assert result.tenant_id == tenant_id
|
||||
assert result.icon == "test_icon"
|
||||
assert result.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert result.runtime.tenant_id == tenant_id
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
from core.datasource.entities.datasource_entities import DatasourceIdentity
|
||||
|
||||
mock_datasource_entity = MagicMock(spec=DatasourceEntity)
|
||||
mock_datasource_entity.identity = MagicMock(spec=DatasourceIdentity)
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = OnlineDocumentDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id="test_plugin_id",
|
||||
plugin_unique_identifier="test_plugin_uid",
|
||||
tenant_id="test_tenant_id",
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name missing_datasource not found"):
|
||||
controller.get_datasource("missing_datasource")
|
||||
@@ -1,147 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceIdentity,
|
||||
DatasourceProviderType,
|
||||
OnlineDriveBrowseFilesRequest,
|
||||
OnlineDriveDownloadFileRequest,
|
||||
)
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
|
||||
|
||||
class TestOnlineDriveDatasourcePlugin:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
# Act
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.entity == entity
|
||||
assert plugin.runtime == runtime
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.icon == icon
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
|
||||
def test_online_drive_browse_files(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"token": "test_token"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
request = MagicMock(spec=OnlineDriveBrowseFilesRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.online_drive_browse_files.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.online_drive_browse_files(user_id=user_id, request=request, provider_type=provider_type)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.online_drive_browse_files.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_online_drive_download_file(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
identity = MagicMock(spec=DatasourceIdentity)
|
||||
entity.identity = identity
|
||||
identity.provider = "test_provider"
|
||||
identity.name = "test_name"
|
||||
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"token": "test_token"}
|
||||
|
||||
tenant_id = "test_tenant"
|
||||
icon = "test_icon"
|
||||
plugin_unique_identifier = "test_plugin_id"
|
||||
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity,
|
||||
runtime=runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
user_id = "test_user"
|
||||
request = MagicMock(spec=OnlineDriveDownloadFileRequest)
|
||||
provider_type = "test_type"
|
||||
|
||||
mock_generator = MagicMock()
|
||||
|
||||
with patch("core.datasource.online_drive.online_drive_plugin.PluginDatasourceManager") as MockManager:
|
||||
mock_manager_instance = MockManager.return_value
|
||||
mock_manager_instance.online_drive_download_file.return_value = mock_generator
|
||||
|
||||
# Act
|
||||
result = plugin.online_drive_download_file(user_id=user_id, request=request, provider_type=provider_type)
|
||||
|
||||
# Assert
|
||||
assert result == mock_generator
|
||||
mock_manager_instance.online_drive_download_file.assert_called_once_with(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
datasource_provider="test_provider",
|
||||
datasource_name="test_name",
|
||||
credentials=runtime.credentials,
|
||||
request=request,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
|
||||
def test_datasource_provider_type(self):
|
||||
# Arrange
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
plugin = OnlineDriveDatasourcePlugin(
|
||||
entity=entity, runtime=runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act
|
||||
result = plugin.datasource_provider_type()
|
||||
|
||||
# Assert
|
||||
assert result == DatasourceProviderType.ONLINE_DRIVE
|
||||
@@ -1,83 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceProviderEntityWithPlugin, DatasourceProviderType
|
||||
from core.datasource.online_drive.online_drive_plugin import OnlineDriveDatasourcePlugin
|
||||
from core.datasource.online_drive.online_drive_provider import OnlineDriveDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestOnlineDriveDatasourcePluginProviderController:
|
||||
def test_init(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
plugin_id = "test_plugin_id"
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
# Act
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self):
|
||||
# Arrange
|
||||
mock_entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
def test_get_datasource_success(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "test_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
mock_entity.identity.icon = "test_icon"
|
||||
|
||||
plugin_unique_identifier = "test_plugin_unique_identifier"
|
||||
tenant_id = "test_tenant_id"
|
||||
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
datasource = controller.get_datasource("test_datasource")
|
||||
|
||||
# Assert
|
||||
assert isinstance(datasource, OnlineDriveDatasourcePlugin)
|
||||
assert datasource.entity == mock_datasource_entity
|
||||
assert datasource.tenant_id == tenant_id
|
||||
assert datasource.icon == "test_icon"
|
||||
assert datasource.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert datasource.runtime.tenant_id == tenant_id
|
||||
|
||||
def test_get_datasource_not_found(self):
|
||||
# Arrange
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity.name = "other_datasource"
|
||||
|
||||
mock_entity = MagicMock()
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = OnlineDriveDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="id", plugin_unique_identifier="unique_id", tenant_id="tenant"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match="Datasource with name test_datasource not found"):
|
||||
controller.get_datasource("test_datasource")
|
||||
@@ -1,409 +0,0 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import hmac
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from core.datasource.datasource_file_manager import DatasourceFileManager
|
||||
from models.model import MessageFile, UploadFile
|
||||
from models.tools import ToolFile
|
||||
|
||||
|
||||
class TestDatasourceFileManager:
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.os.urandom")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_sign_file(self, mock_config, mock_urandom, mock_time):
|
||||
# Setup
|
||||
mock_config.FILES_URL = "http://localhost:5001"
|
||||
mock_config.SECRET_KEY = "test_secret"
|
||||
mock_time.return_value = 1700000000
|
||||
mock_urandom.return_value = b"1234567890abcdef" # 16 bytes
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
extension = ".png"
|
||||
|
||||
# Execute
|
||||
signed_url = DatasourceFileManager.sign_file(datasource_file_id, extension)
|
||||
|
||||
# Verify
|
||||
assert signed_url.startswith("http://localhost:5001/files/datasources/file_id_123.png?")
|
||||
assert "timestamp=1700000000" in signed_url
|
||||
assert f"nonce={mock_urandom.return_value.hex()}" in signed_url
|
||||
assert "sign=" in signed_url
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.os.urandom")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_sign_file_empty_secret(self, mock_config, mock_urandom, mock_time):
|
||||
# Setup
|
||||
mock_config.FILES_URL = "http://localhost:5001"
|
||||
mock_config.SECRET_KEY = None # Empty secret
|
||||
mock_time.return_value = 1700000000
|
||||
mock_urandom.return_value = b"1234567890abcdef"
|
||||
|
||||
# Execute
|
||||
signed_url = DatasourceFileManager.sign_file("file_id", ".png")
|
||||
assert "sign=" in signed_url
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_verify_file(self, mock_config, mock_time):
|
||||
# Setup
|
||||
mock_config.SECRET_KEY = "test_secret"
|
||||
mock_config.FILES_ACCESS_TIMEOUT = 300
|
||||
mock_time.return_value = 1700000000
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
timestamp = "1699999800" # 200 seconds ago
|
||||
nonce = "some_nonce"
|
||||
|
||||
# Manually calculate sign
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
secret_key = b"test_secret"
|
||||
sign = hmac.new(secret_key, data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
# Execute & Verify Success
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
|
||||
|
||||
# Verify Failure - Wrong Sign
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, "wrong_sign") is False
|
||||
|
||||
# Verify Failure - Timeout
|
||||
mock_time.return_value = 1700000500 # 700 seconds after timestamp (300 is timeout)
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is False
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.time.time")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_verify_file_empty_secret(self, mock_config, mock_time):
|
||||
# Setup
|
||||
mock_config.SECRET_KEY = "" # Empty string secret
|
||||
mock_config.FILES_ACCESS_TIMEOUT = 300
|
||||
mock_time.return_value = 1700000000
|
||||
|
||||
datasource_file_id = "file_id_123"
|
||||
timestamp = "1699999800"
|
||||
nonce = "some_nonce"
|
||||
|
||||
# Calculate with empty secret
|
||||
data_to_sign = f"file-preview|{datasource_file_id}|{timestamp}|{nonce}"
|
||||
sign = hmac.new(b"", data_to_sign.encode(), hashlib.sha256).digest()
|
||||
encoded_sign = base64.urlsafe_b64encode(sign).decode()
|
||||
|
||||
assert DatasourceFileManager.verify_file(datasource_file_id, timestamp, nonce, encoded_sign) is True
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
user_id = "user_123"
|
||||
tenant_id = "tenant_456"
|
||||
file_binary = b"fake binary data"
|
||||
mimetype = "image/png"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=file_binary,
|
||||
mimetype=mimetype,
|
||||
filename="test.png",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.tenant_id == tenant_id
|
||||
assert upload_file.name == "test.png"
|
||||
assert upload_file.size == len(file_binary)
|
||||
assert upload_file.mime_type == mimetype
|
||||
assert upload_file.key == f"datasources/{tenant_id}/unique_hex.png"
|
||||
|
||||
mock_storage.save.assert_called_once_with(upload_file.key, file_binary)
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw_filename_no_extension(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
user_id = "user_123"
|
||||
tenant_id = "tenant_456"
|
||||
file_binary = b"fake binary data"
|
||||
mimetype = "image/png"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
file_binary=file_binary,
|
||||
mimetype=mimetype,
|
||||
filename="test", # No extension
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.name == "test.png" # Should append extension
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
@patch("core.datasource.datasource_file_manager.guess_extension")
|
||||
def test_create_file_by_raw_unknown_extension(self, mock_guess_ext, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_guess_ext.return_value = None # Cannot guess
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id="user",
|
||||
tenant_id="tenant",
|
||||
conversation_id=None,
|
||||
file_binary=b"data",
|
||||
mimetype="application/x-unknown",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.extension == ".bin"
|
||||
assert upload_file.name == "unique_hex.bin"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
@patch("core.datasource.datasource_file_manager.dify_config")
|
||||
def test_create_file_by_raw_no_filename(self, mock_config, mock_uuid, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_config.STORAGE_TYPE = "local"
|
||||
|
||||
# Execute
|
||||
upload_file = DatasourceFileManager.create_file_by_raw(
|
||||
user_id="user_123",
|
||||
tenant_id="tenant_456",
|
||||
conversation_id=None,
|
||||
file_binary=b"data",
|
||||
mimetype="application/pdf",
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert upload_file.name == "unique_hex.pdf"
|
||||
assert upload_file.extension == ".pdf"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_mimetype_from_guess(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"bits"
|
||||
mock_response.headers = {} # No content-type in headers
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.png"
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "image/png" # Guessed from .png in URL
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_mimetype_default(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"bits"
|
||||
mock_response.headers = {}
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123",
|
||||
tenant_id="tenant_456",
|
||||
file_url="https://example.com/unknown", # No extension, no headers
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "application/octet-stream"
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
@patch("core.datasource.datasource_file_manager.uuid4")
|
||||
def test_create_file_by_url_success(self, mock_uuid, mock_storage, mock_db, mock_ssrf):
|
||||
# Setup
|
||||
mock_uuid.return_value = MagicMock(hex="unique_hex")
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = b"downloaded bits"
|
||||
mock_response.headers = {"Content-Type": "image/jpeg"}
|
||||
mock_ssrf.get.return_value = mock_response
|
||||
|
||||
# Execute
|
||||
tool_file = DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/photo.jpg"
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert tool_file.mimetype == "image/jpeg"
|
||||
assert tool_file.size == len(b"downloaded bits")
|
||||
assert tool_file.file_key == "tools/tenant_456/unique_hex.jpg"
|
||||
mock_storage.save.assert_called_once()
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.ssrf_proxy")
|
||||
def test_create_file_by_url_timeout(self, mock_ssrf):
|
||||
# Setup
|
||||
mock_ssrf.get.side_effect = httpx.TimeoutException("Timeout")
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match="timeout when downloading file"):
|
||||
DatasourceFileManager.create_file_by_url(
|
||||
user_id="user_123", tenant_id="tenant_456", file_url="https://example.com/large.file"
|
||||
)
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_upload_file = MagicMock(spec=UploadFile)
|
||||
mock_upload_file.key = "some_key"
|
||||
mock_upload_file.mime_type = "image/png"
|
||||
|
||||
mock_query = mock_db.session.query.return_value
|
||||
mock_where = mock_query.where.return_value
|
||||
mock_where.first.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_once.return_value = b"file content"
|
||||
|
||||
# Execute
|
||||
result = DatasourceFileManager.get_file_binary("file_id")
|
||||
|
||||
# Verify
|
||||
assert result == (b"file content", "image/png")
|
||||
|
||||
# Case: Not found
|
||||
mock_where.first.return_value = None
|
||||
assert DatasourceFileManager.get_file_binary("unknown") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = "http://localhost/files/tools/tool_id.png"
|
||||
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.file_key = "tool_key"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
# Mock query sequence
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
elif model == ToolFile:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_storage.load_once.return_value = b"tool content"
|
||||
|
||||
# Execute
|
||||
result = DatasourceFileManager.get_file_binary_by_message_file_id("msg_file_id")
|
||||
|
||||
# Verify
|
||||
assert result == (b"tool content", "image/png")
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id_with_extension(self, mock_storage, mock_db):
|
||||
# Test that it correctly parses tool_id even with extension in URL
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = "http://localhost/files/tools/abcdef.png"
|
||||
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "abcdef"
|
||||
mock_tool_file.file_key = "tk"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
|
||||
def mock_query(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = mock_tool_file
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query
|
||||
mock_storage.load_once.return_value = b"bits"
|
||||
|
||||
result = DatasourceFileManager.get_file_binary_by_message_file_id("m")
|
||||
assert result == (b"bits", "image/png")
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_binary_by_message_file_id_failures(self, mock_storage, mock_db):
|
||||
# Setup common mock
|
||||
mock_query_obj = MagicMock()
|
||||
mock_db.session.query.return_value = mock_query_obj
|
||||
mock_query_obj.where.return_value.first.return_value = None
|
||||
|
||||
# Case 1: Message file not found
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("none") is None
|
||||
|
||||
# Case 2: Message file found but tool file not found
|
||||
mock_message_file = MagicMock(spec=MessageFile)
|
||||
mock_message_file.url = None
|
||||
|
||||
def mock_query_v2(model):
|
||||
m = MagicMock()
|
||||
if model == MessageFile:
|
||||
m.where.return_value.first.return_value = mock_message_file
|
||||
else:
|
||||
m.where.return_value.first.return_value = None
|
||||
return m
|
||||
|
||||
mock_db.session.query.side_effect = mock_query_v2
|
||||
assert DatasourceFileManager.get_file_binary_by_message_file_id("msg_id") is None
|
||||
|
||||
@patch("core.datasource.datasource_file_manager.db")
|
||||
@patch("core.datasource.datasource_file_manager.storage")
|
||||
def test_get_file_generator_by_upload_file_id(self, mock_storage, mock_db):
|
||||
# Setup
|
||||
mock_upload_file = MagicMock(spec=UploadFile)
|
||||
mock_upload_file.key = "upload_key"
|
||||
mock_upload_file.mime_type = "text/plain"
|
||||
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = mock_upload_file
|
||||
|
||||
mock_storage.load_stream.return_value = iter([b"chunk1", b"chunk2"])
|
||||
|
||||
# Execute
|
||||
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("upload_id")
|
||||
|
||||
# Verify
|
||||
assert mimetype == "text/plain"
|
||||
assert list(stream) == [b"chunk1", b"chunk2"]
|
||||
|
||||
# Case: Not found
|
||||
mock_db.session.query.return_value.where.return_value.first.return_value = None
|
||||
stream, mimetype = DatasourceFileManager.get_file_generator_by_upload_file_id("none")
|
||||
assert stream is None
|
||||
assert mimetype is None
|
||||
@@ -1,15 +1,9 @@
|
||||
import types
|
||||
from collections.abc import Generator
|
||||
|
||||
import pytest
|
||||
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from core.datasource.datasource_manager import DatasourceManager
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage, DatasourceProviderType
|
||||
from core.datasource.errors import DatasourceProviderNotFoundError
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from dify_graph.entities.workflow_node_execution import WorkflowNodeExecutionStatus
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from dify_graph.node_events import StreamChunkEvent, StreamCompletedEvent
|
||||
|
||||
|
||||
@@ -21,22 +15,6 @@ def _gen_messages_text_only(text: str) -> Generator[DatasourceMessage, None, Non
|
||||
)
|
||||
|
||||
|
||||
def _drain_generator(gen: Generator[DatasourceMessage, None, object]) -> tuple[list[DatasourceMessage], object | None]:
|
||||
messages: list[DatasourceMessage] = []
|
||||
try:
|
||||
while True:
|
||||
messages.append(next(gen))
|
||||
except StopIteration as e:
|
||||
return messages, e.value
|
||||
|
||||
|
||||
def _invalidate_recyclable_contextvars() -> None:
|
||||
"""
|
||||
Ensure RecyclableContextVar.get() raises LookupError until reset by code under test.
|
||||
"""
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
|
||||
def test_get_icon_url_calls_runtime(mocker):
|
||||
fake_runtime = mocker.Mock()
|
||||
fake_runtime.get_icon_url.return_value = "https://icon"
|
||||
@@ -52,119 +30,6 @@ def test_get_icon_url_calls_runtime(mocker):
|
||||
DatasourceManager.get_datasource_runtime.assert_called_once()
|
||||
|
||||
|
||||
def test_get_datasource_runtime_delegates_to_provider_controller(mocker):
|
||||
provider_controller = mocker.Mock()
|
||||
provider_controller.get_datasource.return_value = object()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_plugin_provider", return_value=provider_controller)
|
||||
|
||||
runtime = DatasourceManager.get_datasource_runtime(
|
||||
provider_id="prov/x",
|
||||
datasource_name="ds",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
assert runtime is provider_controller.get_datasource.return_value
|
||||
provider_controller.get_datasource.assert_called_once_with("ds")
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("datasource_type", "controller_path"),
|
||||
[
|
||||
(
|
||||
DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.ONLINE_DRIVE,
|
||||
"core.datasource.datasource_manager.OnlineDriveDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.WEBSITE_CRAWL,
|
||||
"core.datasource.datasource_manager.WebsiteCrawlDatasourcePluginProviderController",
|
||||
),
|
||||
(
|
||||
DatasourceProviderType.LOCAL_FILE,
|
||||
"core.datasource.datasource_manager.LocalFileDatasourcePluginProviderController",
|
||||
),
|
||||
],
|
||||
)
|
||||
def test_get_datasource_plugin_provider_creates_controller_and_caches(mocker, datasource_type, controller_path):
|
||||
_invalidate_recyclable_contextvars()
|
||||
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
fetch = mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
ctrl_cls = mocker.patch(controller_path)
|
||||
|
||||
first = DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id=f"prov/{datasource_type.value}",
|
||||
tenant_id="t1",
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
second = DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id=f"prov/{datasource_type.value}",
|
||||
tenant_id="t1",
|
||||
datasource_type=datasource_type,
|
||||
)
|
||||
|
||||
assert first is second
|
||||
assert fetch.call_count == 1
|
||||
assert ctrl_cls.call_count == 1
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_when_provider_entity_missing(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(DatasourceProviderNotFoundError, match="plugin provider prov/notfound not found"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/notfound",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_for_unsupported_type(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported datasource type"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/x",
|
||||
tenant_id="t1",
|
||||
datasource_type=types.SimpleNamespace(), # not a DatasourceProviderType at runtime
|
||||
)
|
||||
|
||||
|
||||
def test_get_datasource_plugin_provider_raises_when_controller_none(mocker):
|
||||
_invalidate_recyclable_contextvars()
|
||||
provider_entity = types.SimpleNamespace(declaration=object(), plugin_id="plugin", plugin_unique_identifier="uniq")
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.PluginDatasourceManager.fetch_datasource_provider",
|
||||
return_value=provider_entity,
|
||||
)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.OnlineDocumentDatasourcePluginProviderController",
|
||||
return_value=None,
|
||||
)
|
||||
|
||||
with pytest.raises(DatasourceProviderNotFoundError, match="Datasource provider prov/x not found"):
|
||||
DatasourceManager.get_datasource_plugin_provider(
|
||||
provider_id="prov/x",
|
||||
tenant_id="t1",
|
||||
datasource_type=DatasourceProviderType.ONLINE_DOCUMENT,
|
||||
)
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_online_document(mocker):
|
||||
# stub runtime to yield a text message
|
||||
def _doc_messages(**_):
|
||||
@@ -195,148 +60,6 @@ def test_stream_online_results_yields_messages_online_document(mocker):
|
||||
assert msgs[0].message.text == "hello"
|
||||
|
||||
|
||||
def test_stream_online_results_sets_credentials_and_returns_empty_dict_online_document(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def get_online_document_page_content(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("hello")
|
||||
|
||||
runtime = _Runtime()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="cred",
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
messages, final_value = _drain_generator(gen)
|
||||
|
||||
assert runtime.runtime.credentials == {"token": "t"}
|
||||
assert [m.message.text for m in messages] == ["hello"]
|
||||
assert final_value == {}
|
||||
|
||||
|
||||
def test_stream_online_results_raises_when_missing_params(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def get_online_document_page_content(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("never")
|
||||
|
||||
def online_drive_download_file(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("never")
|
||||
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=_Runtime())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="datasource_param is required for ONLINE_DOCUMENT streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="online_drive_request is required for ONLINE_DRIVE streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_online_results_yields_messages_and_returns_empty_dict_online_drive(mocker):
|
||||
class _Runtime:
|
||||
def __init__(self) -> None:
|
||||
self.runtime = types.SimpleNamespace(credentials=None)
|
||||
|
||||
def online_drive_download_file(self, **_kwargs):
|
||||
yield from _gen_messages_text_only("drive")
|
||||
|
||||
runtime = _Runtime()
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=runtime)
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={"token": "t"},
|
||||
)
|
||||
|
||||
gen = DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="cred",
|
||||
datasource_param=None,
|
||||
online_drive_request=types.SimpleNamespace(id="fid", bucket="b"),
|
||||
)
|
||||
messages, final_value = _drain_generator(gen)
|
||||
|
||||
assert runtime.runtime.credentials == {"token": "t"}
|
||||
assert [m.message.text for m in messages] == ["drive"]
|
||||
assert final_value == {}
|
||||
|
||||
|
||||
def test_stream_online_results_raises_for_unsupported_stream_type(mocker):
|
||||
mocker.patch.object(DatasourceManager, "get_datasource_runtime", return_value=mocker.Mock())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceProviderService.get_datasource_credentials",
|
||||
return_value={},
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported datasource type for streaming"):
|
||||
list(
|
||||
DatasourceManager.stream_online_results(
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="website_crawl",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_emits_events_online_document(mocker):
|
||||
# make manager's low-level stream produce TEXT only
|
||||
mocker.patch.object(
|
||||
@@ -370,260 +93,6 @@ def test_stream_node_events_emits_events_online_document(mocker):
|
||||
assert events[-1].node_run_result.status == WorkflowNodeExecutionStatus.SUCCEEDED
|
||||
|
||||
|
||||
def test_stream_node_events_builds_file_and_variables_from_messages(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
|
||||
meta={},
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT,
|
||||
message=DatasourceMessage.TextMessage(text="hello"),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text="http://example.com"),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="a", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="v", variable_value="b", stream=True),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.VARIABLE,
|
||||
message=DatasourceMessage.VariableMessage(variable_name="x", variable_value=1, stream=False),
|
||||
meta=None,
|
||||
)
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.JSON,
|
||||
message=DatasourceMessage.JsonMessage(json_object={"k": "v"}),
|
||||
meta=None,
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
fake_tool_file = types.SimpleNamespace(mimetype="image/png")
|
||||
|
||||
class _Session:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return fake_tool_file
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.file_factory.get_file_type_by_mime_type", return_value=FileType.IMAGE
|
||||
)
|
||||
built = File(
|
||||
tenant_id="t1",
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tool_file_1",
|
||||
extension=".png",
|
||||
mime_type="image/png",
|
||||
storage_key="k",
|
||||
)
|
||||
build_from_mapping = mocker.patch(
|
||||
"core.datasource.datasource_manager.file_factory.build_from_mapping",
|
||||
return_value=built,
|
||||
)
|
||||
|
||||
variable_pool = mocker.Mock()
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={"k": "v"},
|
||||
datasource_info={"info": "x"},
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
build_from_mapping.assert_called_once()
|
||||
variable_pool.add.assert_not_called()
|
||||
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.chunk == "hello" for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.chunk.startswith("Link: http") for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "a" for e in events)
|
||||
assert any(isinstance(e, StreamChunkEvent) and e.selector == ["nodeA", "v"] and e.chunk == "b" for e in events)
|
||||
assert isinstance(events[-2], StreamChunkEvent)
|
||||
assert events[-2].is_final is True
|
||||
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.outputs["v"] == "ab"
|
||||
assert events[-1].node_run_result.outputs["x"] == 1
|
||||
|
||||
|
||||
def test_stream_node_events_raises_when_toolfile_missing(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/missing.png"),
|
||||
meta={},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
class _Session:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def scalar(self, _stmt):
|
||||
return None
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_Session())
|
||||
|
||||
with pytest.raises(ValueError, match="ToolFile not found for file_id=missing, tenant_id=t1"):
|
||||
list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_document",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=types.SimpleNamespace(workspace_id="w", page_id="pg", type="t"),
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def test_stream_node_events_online_drive_sets_variable_pool_file_and_outputs(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
file_in = File(
|
||||
tenant_id="t1",
|
||||
type=FileType.DOCUMENT,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
related_id="tf",
|
||||
extension=".pdf",
|
||||
mime_type="application/pdf",
|
||||
storage_key="k",
|
||||
)
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.FileMessage(file_marker="file_marker"),
|
||||
meta={"file": file_in},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
|
||||
variable_pool = mocker.Mock()
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="online_drive",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={"k": "v"},
|
||||
variable_pool=variable_pool,
|
||||
datasource_param=None,
|
||||
online_drive_request=types.SimpleNamespace(id="id", bucket="b"),
|
||||
)
|
||||
)
|
||||
|
||||
variable_pool.add.assert_called_once()
|
||||
assert variable_pool.add.call_args[0][0] == ["nodeA", "file"]
|
||||
assert variable_pool.add.call_args[0][1] == file_in
|
||||
|
||||
completed = events[-1]
|
||||
assert isinstance(completed, StreamCompletedEvent)
|
||||
assert completed.node_run_result.outputs["file"] == file_in
|
||||
assert completed.node_run_result.outputs["datasource_type"] == DatasourceProviderType.ONLINE_DRIVE
|
||||
|
||||
|
||||
def test_stream_node_events_skips_file_build_for_non_online_types(mocker):
|
||||
mocker.patch.object(DatasourceManager, "stream_online_results", return_value=_gen_messages_text_only("ignored"))
|
||||
|
||||
def _transformed(**_kwargs):
|
||||
yield DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE_LINK,
|
||||
message=DatasourceMessage.TextMessage(text="/files/datasources/tool_file_1.png"),
|
||||
meta={},
|
||||
)
|
||||
|
||||
mocker.patch(
|
||||
"core.datasource.datasource_manager.DatasourceFileMessageTransformer.transform_datasource_invoke_messages",
|
||||
side_effect=_transformed,
|
||||
)
|
||||
build_from_mapping = mocker.patch("core.datasource.datasource_manager.file_factory.build_from_mapping")
|
||||
|
||||
events = list(
|
||||
DatasourceManager.stream_node_events(
|
||||
node_id="nodeA",
|
||||
user_id="u1",
|
||||
datasource_name="ds",
|
||||
datasource_type="website_crawl",
|
||||
provider_id="p/x",
|
||||
tenant_id="t1",
|
||||
provider="prov",
|
||||
plugin_id="plug",
|
||||
credential_id="",
|
||||
parameters_for_log={},
|
||||
datasource_info={},
|
||||
variable_pool=mocker.Mock(),
|
||||
datasource_param=None,
|
||||
online_drive_request=None,
|
||||
)
|
||||
)
|
||||
|
||||
build_from_mapping.assert_not_called()
|
||||
assert isinstance(events[-1], StreamCompletedEvent)
|
||||
assert events[-1].node_run_result.outputs["file"] is None
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_builds_file(mocker):
|
||||
# fake UploadFile row
|
||||
fake_row = types.SimpleNamespace(
|
||||
@@ -664,27 +133,3 @@ def test_get_upload_file_by_id_builds_file(mocker):
|
||||
f = DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
assert f.related_id == "fid"
|
||||
assert f.extension == ".txt"
|
||||
|
||||
|
||||
def test_get_upload_file_by_id_raises_when_missing(mocker):
|
||||
class _Q:
|
||||
def where(self, *_args, **_kwargs):
|
||||
return self
|
||||
|
||||
def first(self):
|
||||
return None
|
||||
|
||||
class _S:
|
||||
def __enter__(self):
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
return False
|
||||
|
||||
def query(self, *_):
|
||||
return _Q()
|
||||
|
||||
mocker.patch("core.datasource.datasource_manager.session_factory.create_session", return_value=_S())
|
||||
|
||||
with pytest.raises(ValueError, match="UploadFile not found for file_id=fid, tenant_id=t1"):
|
||||
DatasourceManager.get_upload_file_by_id(file_id="fid", tenant_id="t1")
|
||||
|
||||
@@ -1,64 +0,0 @@
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceInvokeMeta
|
||||
from core.datasource.errors import (
|
||||
DatasourceApiSchemaError,
|
||||
DatasourceEngineInvokeError,
|
||||
DatasourceInvokeError,
|
||||
DatasourceNotFoundError,
|
||||
DatasourceNotSupportedError,
|
||||
DatasourceParameterValidationError,
|
||||
DatasourceProviderCredentialValidationError,
|
||||
DatasourceProviderNotFoundError,
|
||||
)
|
||||
|
||||
|
||||
class TestErrors:
|
||||
def test_datasource_provider_not_found_error(self):
|
||||
error = DatasourceProviderNotFoundError("Provider not found")
|
||||
assert str(error) == "Provider not found"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_not_found_error(self):
|
||||
error = DatasourceNotFoundError("Datasource not found")
|
||||
assert str(error) == "Datasource not found"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_parameter_validation_error(self):
|
||||
error = DatasourceParameterValidationError("Validation failed")
|
||||
assert str(error) == "Validation failed"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_provider_credential_validation_error(self):
|
||||
error = DatasourceProviderCredentialValidationError("Credential validation failed")
|
||||
assert str(error) == "Credential validation failed"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_not_supported_error(self):
|
||||
error = DatasourceNotSupportedError("Not supported")
|
||||
assert str(error) == "Not supported"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_invoke_error(self):
|
||||
error = DatasourceInvokeError("Invoke error")
|
||||
assert str(error) == "Invoke error"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_api_schema_error(self):
|
||||
error = DatasourceApiSchemaError("API schema error")
|
||||
assert str(error) == "API schema error"
|
||||
assert isinstance(error, ValueError)
|
||||
|
||||
def test_datasource_engine_invoke_error(self):
|
||||
mock_meta = MagicMock(spec=DatasourceInvokeMeta)
|
||||
error = DatasourceEngineInvokeError(meta=mock_meta)
|
||||
assert error.meta == mock_meta
|
||||
assert isinstance(error, Exception)
|
||||
|
||||
def test_datasource_engine_invoke_error_init(self):
|
||||
# Test initialization with meta
|
||||
meta = DatasourceInvokeMeta(time_cost=1.5, error="Engine failed")
|
||||
error = DatasourceEngineInvokeError(meta=meta)
|
||||
assert error.meta == meta
|
||||
assert error.meta.time_cost == 1.5
|
||||
assert error.meta.error == "Engine failed"
|
||||
@@ -1,337 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.entities.datasource_entities import DatasourceMessage
|
||||
from core.datasource.utils.message_transformer import DatasourceFileMessageTransformer
|
||||
from dify_graph.file import File
|
||||
from dify_graph.file.enums import FileTransferMethod, FileType
|
||||
from models.tools import ToolFile
|
||||
|
||||
|
||||
class TestDatasourceFileMessageTransformer:
|
||||
def test_transform_text_and_link_messages(self):
|
||||
# Setup
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.TEXT, message=DatasourceMessage.TextMessage(text="hello")
|
||||
),
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.LINK,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com"),
|
||||
),
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 2
|
||||
assert result[0].type == DatasourceMessage.MessageType.TEXT
|
||||
assert result[0].message.text == "hello"
|
||||
assert result[1].type == DatasourceMessage.MessageType.LINK
|
||||
assert result[1].message.text == "https://example.com"
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
def test_transform_image_message_success(self, mock_guess_ext, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "file_id_123"
|
||||
mock_tool_file.mimetype = "image/png"
|
||||
mock_manager.create_file_by_url.return_value = mock_tool_file
|
||||
mock_guess_ext.return_value = ".png"
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
|
||||
meta={"some": "meta"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1", conversation_id="conv1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/file_id_123.png"
|
||||
assert result[0].meta == {"some": "meta"}
|
||||
mock_manager.create_file_by_url.assert_called_once_with(
|
||||
user_id="user1", tenant_id="tenant1", file_url="https://example.com/image.png", conversation_id="conv1"
|
||||
)
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
def test_transform_image_message_failure(self, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_manager.create_file_by_url.side_effect = Exception("Download failed")
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE,
|
||||
message=DatasourceMessage.TextMessage(text="https://example.com/image.png"),
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.TEXT
|
||||
assert "Failed to download image" in result[0].message.text
|
||||
assert "Download failed" in result[0].message.text
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
def test_transform_blob_message_image(self, mock_guess_ext, mock_tool_file_manager_cls):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_456"
|
||||
mock_tool_file.mimetype = "image/jpeg"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_guess_ext.return_value = ".jpg"
|
||||
|
||||
blob_data = b"fake-image-bits"
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=blob_data),
|
||||
meta={"mime_type": "image/jpeg", "file_name": "test.jpg"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_456.jpg"
|
||||
mock_manager.create_file_by_raw.assert_called_once()
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
@patch("core.datasource.utils.message_transformer.guess_extension")
|
||||
@patch("core.datasource.utils.message_transformer.guess_type")
|
||||
def test_transform_blob_message_binary_guess_mimetype(
|
||||
self, mock_guess_type, mock_guess_ext, mock_tool_file_manager_cls
|
||||
):
|
||||
# Setup
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_789"
|
||||
mock_tool_file.mimetype = "application/pdf"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
mock_guess_type.return_value = ("application/pdf", None)
|
||||
mock_guess_ext.return_value = ".pdf"
|
||||
|
||||
blob_data = b"fake-pdf-bits"
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=blob_data),
|
||||
meta={"file_name": "test.pdf"},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_789.pdf"
|
||||
|
||||
def test_transform_blob_message_invalid_type(self):
|
||||
# Setup
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB, message=DatasourceMessage.TextMessage(text="not a blob")
|
||||
)
|
||||
]
|
||||
|
||||
# Execute & Verify
|
||||
with pytest.raises(ValueError, match="unexpected message type"):
|
||||
list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
def test_transform_file_tool_file_image(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
mock_file.related_id = "related_123"
|
||||
mock_file.extension = ".png"
|
||||
mock_file.type = FileType.IMAGE
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="ignored"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE_LINK
|
||||
assert result[0].message.text == "/files/datasources/related_123.png"
|
||||
|
||||
def test_transform_file_tool_file_binary(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
mock_file.related_id = "related_456"
|
||||
mock_file.extension = ".txt"
|
||||
mock_file.type = FileType.DOCUMENT
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="ignored"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.LINK
|
||||
assert result[0].message.text == "/files/datasources/related_456.txt"
|
||||
|
||||
def test_transform_file_other_transfer_method(self):
|
||||
# Setup
|
||||
mock_file = MagicMock(spec=File)
|
||||
mock_file.transfer_method = FileTransferMethod.REMOTE_URL
|
||||
|
||||
msg = DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.FILE,
|
||||
message=DatasourceMessage.TextMessage(text="remote image"),
|
||||
meta={"file": mock_file},
|
||||
)
|
||||
messages = [msg]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0] == msg
|
||||
|
||||
def test_transform_other_message_type(self):
|
||||
# JSON type is yielded by the default 'else' block or the 'yield message' at the end
|
||||
msg = DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.JSON, message=DatasourceMessage.JsonMessage(json_object={"k": "v"})
|
||||
)
|
||||
messages = [msg]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify
|
||||
assert len(result) == 1
|
||||
assert result[0] == msg
|
||||
|
||||
def test_get_datasource_file_url(self):
|
||||
# Test with extension
|
||||
url = DatasourceFileMessageTransformer.get_datasource_file_url("file1", ".jpg")
|
||||
assert url == "/files/datasources/file1.jpg"
|
||||
|
||||
# Test without extension
|
||||
url = DatasourceFileMessageTransformer.get_datasource_file_url("file2", None)
|
||||
assert url == "/files/datasources/file2.bin"
|
||||
|
||||
def test_transform_blob_message_no_meta_filename(self):
|
||||
# This tests line 70 where filename might be None
|
||||
with patch("core.datasource.utils.message_transformer.ToolFileManager") as mock_tool_file_manager_cls:
|
||||
mock_manager = mock_tool_file_manager_cls.return_value
|
||||
mock_tool_file = MagicMock(spec=ToolFile)
|
||||
mock_tool_file.id = "blob_id_no_name"
|
||||
mock_tool_file.mimetype = "application/octet-stream"
|
||||
mock_manager.create_file_by_raw.return_value = mock_tool_file
|
||||
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.BLOB,
|
||||
message=DatasourceMessage.BlobMessage(blob=b"data"),
|
||||
meta={}, # No mime_type, no file_name
|
||||
)
|
||||
]
|
||||
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.BINARY_LINK
|
||||
assert result[0].message.text == "/files/datasources/blob_id_no_name.bin"
|
||||
|
||||
@patch("core.datasource.utils.message_transformer.ToolFileManager")
|
||||
def test_transform_image_message_not_text_message(self, mock_tool_file_manager_cls):
|
||||
# This tests line 24-26 where it checks if message is instance of TextMessage
|
||||
messages = [
|
||||
DatasourceMessage(
|
||||
type=DatasourceMessage.MessageType.IMAGE, message=DatasourceMessage.BlobMessage(blob=b"not-text")
|
||||
)
|
||||
]
|
||||
|
||||
# Execute
|
||||
result = list(
|
||||
DatasourceFileMessageTransformer.transform_datasource_invoke_messages(
|
||||
messages=iter(messages), user_id="user1", tenant_id="tenant1"
|
||||
)
|
||||
)
|
||||
|
||||
# Verify - should yield unchanged if it's not a TextMessage
|
||||
assert len(result) == 1
|
||||
assert result[0].type == DatasourceMessage.MessageType.IMAGE
|
||||
assert isinstance(result[0].message, DatasourceMessage.BlobMessage)
|
||||
@@ -1,101 +0,0 @@
|
||||
from collections.abc import Generator
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceEntity,
|
||||
DatasourceProviderType,
|
||||
WebsiteCrawlMessage,
|
||||
)
|
||||
from core.datasource.website_crawl.website_crawl_plugin import WebsiteCrawlDatasourcePlugin
|
||||
|
||||
|
||||
class TestWebsiteCrawlDatasourcePlugin:
|
||||
@pytest.fixture
|
||||
def mock_entity(self):
|
||||
entity = MagicMock(spec=DatasourceEntity)
|
||||
entity.identity = MagicMock()
|
||||
entity.identity.provider = "test-provider"
|
||||
entity.identity.name = "test-name"
|
||||
return entity
|
||||
|
||||
@pytest.fixture
|
||||
def mock_runtime(self):
|
||||
runtime = MagicMock(spec=DatasourceRuntime)
|
||||
runtime.credentials = {"api_key": "test-key"}
|
||||
return runtime
|
||||
|
||||
def test_init(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
tenant_id = "test-tenant-id"
|
||||
icon = "test-icon"
|
||||
plugin_unique_identifier = "test-plugin-id"
|
||||
|
||||
# Act
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id=tenant_id,
|
||||
icon=icon,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert plugin.tenant_id == tenant_id
|
||||
assert plugin.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert plugin.entity == mock_entity
|
||||
assert plugin.runtime == mock_runtime
|
||||
assert plugin.icon == icon
|
||||
|
||||
def test_datasource_provider_type(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity, runtime=mock_runtime, tenant_id="test", icon="test", plugin_unique_identifier="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert plugin.datasource_provider_type() == DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def test_get_website_crawl(self, mock_entity, mock_runtime):
|
||||
# Arrange
|
||||
plugin = WebsiteCrawlDatasourcePlugin(
|
||||
entity=mock_entity,
|
||||
runtime=mock_runtime,
|
||||
tenant_id="test-tenant-id",
|
||||
icon="test-icon",
|
||||
plugin_unique_identifier="test-plugin-id",
|
||||
)
|
||||
|
||||
user_id = "test-user-id"
|
||||
datasource_parameters = {"url": "https://example.com"}
|
||||
provider_type = "firecrawl"
|
||||
|
||||
mock_message = MagicMock(spec=WebsiteCrawlMessage)
|
||||
|
||||
# Mock PluginDatasourceManager
|
||||
with patch("core.datasource.website_crawl.website_crawl_plugin.PluginDatasourceManager") as mock_manager_class:
|
||||
mock_manager = mock_manager_class.return_value
|
||||
mock_manager.get_website_crawl.return_value = (msg for msg in [mock_message])
|
||||
|
||||
# Act
|
||||
result = plugin.get_website_crawl(
|
||||
user_id=user_id, datasource_parameters=datasource_parameters, provider_type=provider_type
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, Generator)
|
||||
messages = list(result)
|
||||
assert len(messages) == 1
|
||||
assert messages[0] == mock_message
|
||||
|
||||
mock_manager.get_website_crawl.assert_called_once_with(
|
||||
tenant_id="test-tenant-id",
|
||||
user_id=user_id,
|
||||
datasource_provider="test-provider",
|
||||
datasource_name="test-name",
|
||||
credentials={"api_key": "test-key"},
|
||||
datasource_parameters=datasource_parameters,
|
||||
provider_type=provider_type,
|
||||
)
|
||||
@@ -1,95 +0,0 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.datasource.__base.datasource_runtime import DatasourceRuntime
|
||||
from core.datasource.entities.datasource_entities import (
|
||||
DatasourceProviderEntityWithPlugin,
|
||||
DatasourceProviderType,
|
||||
)
|
||||
from core.datasource.website_crawl.website_crawl_provider import WebsiteCrawlDatasourcePluginProviderController
|
||||
|
||||
|
||||
class TestWebsiteCrawlDatasourcePluginProviderController:
|
||||
@pytest.fixture
|
||||
def mock_entity(self):
|
||||
entity = MagicMock(spec=DatasourceProviderEntityWithPlugin)
|
||||
entity.datasources = []
|
||||
entity.identity = MagicMock()
|
||||
entity.identity.icon = "test-icon"
|
||||
return entity
|
||||
|
||||
def test_init(self, mock_entity):
|
||||
# Arrange
|
||||
plugin_id = "test-plugin-id"
|
||||
plugin_unique_identifier = "test-unique-id"
|
||||
tenant_id = "test-tenant-id"
|
||||
|
||||
# Act
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity,
|
||||
plugin_id=plugin_id,
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert controller.entity == mock_entity
|
||||
assert controller.plugin_id == plugin_id
|
||||
assert controller.plugin_unique_identifier == plugin_unique_identifier
|
||||
assert controller.tenant_id == tenant_id
|
||||
|
||||
def test_provider_type(self, mock_entity):
|
||||
# Arrange
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
assert controller.provider_type == DatasourceProviderType.WEBSITE_CRAWL
|
||||
|
||||
def test_get_datasource_success(self, mock_entity):
|
||||
# Arrange
|
||||
datasource_name = "test-datasource"
|
||||
tenant_id = "test-tenant-id"
|
||||
plugin_unique_identifier = "test-unique-id"
|
||||
|
||||
mock_datasource_entity = MagicMock()
|
||||
mock_datasource_entity.identity = MagicMock()
|
||||
mock_datasource_entity.identity.name = datasource_name
|
||||
mock_entity.datasources = [mock_datasource_entity]
|
||||
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier=plugin_unique_identifier, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
# Act
|
||||
with patch(
|
||||
"core.datasource.website_crawl.website_crawl_provider.WebsiteCrawlDatasourcePlugin"
|
||||
) as mock_plugin_class:
|
||||
mock_plugin_instance = mock_plugin_class.return_value
|
||||
result = controller.get_datasource(datasource_name)
|
||||
|
||||
# Assert
|
||||
assert result == mock_plugin_instance
|
||||
mock_plugin_class.assert_called_once()
|
||||
args, kwargs = mock_plugin_class.call_args
|
||||
assert kwargs["entity"] == mock_datasource_entity
|
||||
assert isinstance(kwargs["runtime"], DatasourceRuntime)
|
||||
assert kwargs["runtime"].tenant_id == tenant_id
|
||||
assert kwargs["tenant_id"] == tenant_id
|
||||
assert kwargs["icon"] == "test-icon"
|
||||
assert kwargs["plugin_unique_identifier"] == plugin_unique_identifier
|
||||
|
||||
def test_get_datasource_not_found(self, mock_entity):
|
||||
# Arrange
|
||||
datasource_name = "non-existent"
|
||||
mock_entity.datasources = []
|
||||
|
||||
controller = WebsiteCrawlDatasourcePluginProviderController(
|
||||
entity=mock_entity, plugin_id="test", plugin_unique_identifier="test", tenant_id="test"
|
||||
)
|
||||
|
||||
# Act & Assert
|
||||
with pytest.raises(ValueError, match=f"Datasource with name {datasource_name} not found"):
|
||||
controller.get_datasource(datasource_name)
|
||||
@@ -423,6 +423,15 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch):
|
||||
markdown = extractor._table_to_markdown(table, {})
|
||||
assert markdown == "| H1 | H2 |\n| --- | --- |\n| A | B |"
|
||||
|
||||
class FakeRunElement:
|
||||
def __init__(self, blips):
|
||||
self._blips = blips
|
||||
|
||||
def xpath(self, pattern):
|
||||
if pattern == ".//a:blip":
|
||||
return self._blips
|
||||
return []
|
||||
|
||||
class FakeBlip:
|
||||
def __init__(self, image_id):
|
||||
self.image_id = image_id
|
||||
@@ -430,31 +439,11 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch):
|
||||
def get(self, key):
|
||||
return self.image_id
|
||||
|
||||
class FakeRunChild:
|
||||
def __init__(self, blips, text=""):
|
||||
self._blips = blips
|
||||
self.text = text
|
||||
self.tag = qn("w:r")
|
||||
|
||||
def xpath(self, pattern):
|
||||
if pattern == ".//a:blip":
|
||||
return self._blips
|
||||
return []
|
||||
|
||||
class FakeRun:
|
||||
def __init__(self, element, paragraph):
|
||||
# Mirror the subset used by _parse_cell_paragraph
|
||||
self.element = element
|
||||
self.text = getattr(element, "text", "")
|
||||
|
||||
# Patch we.Run so our lightweight child objects work with the extractor
|
||||
monkeypatch.setattr(we, "Run", FakeRun)
|
||||
|
||||
image_part = object()
|
||||
paragraph = SimpleNamespace(
|
||||
_element=[
|
||||
FakeRunChild([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")], text=""),
|
||||
FakeRunChild([], text="plain"),
|
||||
runs=[
|
||||
SimpleNamespace(element=FakeRunElement([FakeBlip(None), FakeBlip("ext"), FakeBlip("int")]), text=""),
|
||||
SimpleNamespace(element=FakeRunElement([]), text="plain"),
|
||||
],
|
||||
part=SimpleNamespace(
|
||||
rels={
|
||||
@@ -463,7 +452,6 @@ def test_table_to_markdown_and_parse_helpers(monkeypatch):
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
image_map = {"ext": "EXT-IMG", image_part: "INT-IMG"}
|
||||
assert extractor._parse_cell_paragraph(paragraph, image_map) == "EXT-IMGINT-IMGplain"
|
||||
|
||||
@@ -637,83 +625,3 @@ def test_parse_docx_covers_drawing_shapes_hyperlink_error_and_table_branch(monke
|
||||
assert "BrokenLink" in content
|
||||
assert "TABLE-MARKDOWN" in content
|
||||
logger_exception.assert_called_once()
|
||||
|
||||
|
||||
def test_parse_cell_paragraph_hyperlink_in_table_cell_http():
|
||||
doc = Document()
|
||||
table = doc.add_table(rows=1, cols=1)
|
||||
cell = table.cell(0, 0)
|
||||
p = cell.paragraphs[0]
|
||||
|
||||
# Build modern hyperlink inside table cell
|
||||
r_id = "rIdHttp1"
|
||||
hyperlink = OxmlElement("w:hyperlink")
|
||||
hyperlink.set(qn("r:id"), r_id)
|
||||
|
||||
run_elem = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "Dify"
|
||||
run_elem.append(t)
|
||||
hyperlink.append(run_elem)
|
||||
p._p.append(hyperlink)
|
||||
|
||||
# Relationship for external http link
|
||||
doc.part.rels.add_relationship(
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
|
||||
"https://dify.ai",
|
||||
r_id,
|
||||
is_external=True,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
reopened = Document(tmp_path)
|
||||
para = reopened.tables[0].cell(0, 0).paragraphs[0]
|
||||
extractor = object.__new__(WordExtractor)
|
||||
out = extractor._parse_cell_paragraph(para, {})
|
||||
assert out == "[Dify](https://dify.ai)"
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
|
||||
def test_parse_cell_paragraph_hyperlink_in_table_cell_mailto():
|
||||
doc = Document()
|
||||
table = doc.add_table(rows=1, cols=1)
|
||||
cell = table.cell(0, 0)
|
||||
p = cell.paragraphs[0]
|
||||
|
||||
r_id = "rIdMail1"
|
||||
hyperlink = OxmlElement("w:hyperlink")
|
||||
hyperlink.set(qn("r:id"), r_id)
|
||||
|
||||
run_elem = OxmlElement("w:r")
|
||||
t = OxmlElement("w:t")
|
||||
t.text = "john@test.com"
|
||||
run_elem.append(t)
|
||||
hyperlink.append(run_elem)
|
||||
p._p.append(hyperlink)
|
||||
|
||||
doc.part.rels.add_relationship(
|
||||
"http://schemas.openxmlformats.org/officeDocument/2006/relationships/hyperlink",
|
||||
"mailto:john@test.com",
|
||||
r_id,
|
||||
is_external=True,
|
||||
)
|
||||
|
||||
with tempfile.NamedTemporaryFile(suffix=".docx", delete=False) as tmp:
|
||||
doc.save(tmp.name)
|
||||
tmp_path = tmp.name
|
||||
|
||||
try:
|
||||
reopened = Document(tmp_path)
|
||||
para = reopened.tables[0].cell(0, 0).paragraphs[0]
|
||||
extractor = object.__new__(WordExtractor)
|
||||
out = extractor._parse_cell_paragraph(para, {})
|
||||
assert out == "[john@test.com](mailto:john@test.com)"
|
||||
finally:
|
||||
if os.path.exists(tmp_path):
|
||||
os.remove(tmp_path)
|
||||
|
||||
@@ -1,466 +0,0 @@
|
||||
from datetime import datetime
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services.api_token_service as api_token_service_module
|
||||
from services.api_token_service import ApiTokenCache, CachedApiToken
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_db_session():
|
||||
"""Fixture providing common DB session mocking for query_token_from_db tests."""
|
||||
fake_engine = MagicMock()
|
||||
|
||||
session = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||
patch.object(api_token_service_module, "Session", return_value=session_context) as mock_session_class,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "set") as mock_cache_set,
|
||||
patch.object(api_token_service_module, "record_token_usage") as mock_record_usage,
|
||||
):
|
||||
yield {
|
||||
"session": session,
|
||||
"mock_session_class": mock_session_class,
|
||||
"mock_cache_set": mock_cache_set,
|
||||
"mock_record_usage": mock_record_usage,
|
||||
"fake_engine": fake_engine,
|
||||
}
|
||||
|
||||
|
||||
class TestQueryTokenFromDb:
|
||||
def test_should_return_api_token_and_cache_when_token_exists(self, mock_db_session):
|
||||
"""Test DB lookup success path caches token and records usage."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
api_token = MagicMock()
|
||||
|
||||
mock_db_session["session"].scalar.return_value = api_token
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.query_token_from_db(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == api_token
|
||||
mock_db_session["mock_session_class"].assert_called_once_with(
|
||||
mock_db_session["fake_engine"], expire_on_commit=False
|
||||
)
|
||||
mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, api_token)
|
||||
mock_db_session["mock_record_usage"].assert_called_once_with(auth_token, scope)
|
||||
|
||||
def test_should_cache_null_and_raise_unauthorized_when_token_not_found(self, mock_db_session):
|
||||
"""Test DB lookup miss path caches null marker and raises Unauthorized."""
|
||||
# Arrange
|
||||
auth_token = "missing-token"
|
||||
scope = "app"
|
||||
|
||||
mock_db_session["session"].scalar.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(Unauthorized, match="Access token is invalid"):
|
||||
api_token_service_module.query_token_from_db(auth_token, scope)
|
||||
|
||||
mock_db_session["mock_cache_set"].assert_called_once_with(auth_token, scope, None)
|
||||
mock_db_session["mock_record_usage"].assert_not_called()
|
||||
|
||||
|
||||
class TestRecordTokenUsage:
|
||||
def test_should_write_active_key_with_iso_timestamp_and_ttl(self):
|
||||
"""Test record_token_usage writes usage timestamp with one-hour TTL."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "dataset"
|
||||
fixed_time = datetime(2026, 2, 24, 12, 0, 0)
|
||||
expected_key = ApiTokenCache.make_active_key(auth_token, scope)
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "naive_utc_now", return_value=fixed_time),
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
):
|
||||
# Act
|
||||
api_token_service_module.record_token_usage(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
mock_redis.set.assert_called_once_with(expected_key, fixed_time.isoformat(), ex=3600)
|
||||
|
||||
def test_should_not_raise_when_redis_write_fails(self):
|
||||
"""Test record_token_usage swallows Redis errors."""
|
||||
# Arrange
|
||||
with patch.object(api_token_service_module, "redis_client") as mock_redis:
|
||||
mock_redis.set.side_effect = Exception("redis unavailable")
|
||||
|
||||
# Act / Assert
|
||||
api_token_service_module.record_token_usage("token-123", "app")
|
||||
|
||||
|
||||
class TestFetchTokenWithSingleFlight:
|
||||
def test_should_return_cached_token_when_lock_acquired_and_cache_filled(self):
|
||||
"""Test single-flight returns cache when another request already populated it."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
cached_token = CachedApiToken(
|
||||
id="id-1",
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
type="app",
|
||||
token=auth_token,
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=cached_token) as mock_cache_get,
|
||||
patch.object(api_token_service_module, "query_token_from_db") as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == cached_token
|
||||
mock_redis.lock.assert_called_once_with(
|
||||
f"api_token_query_lock:{scope}:{auth_token}",
|
||||
timeout=10,
|
||||
blocking_timeout=5,
|
||||
)
|
||||
lock.acquire.assert_called_once_with(blocking=True)
|
||||
lock.release.assert_called_once()
|
||||
mock_cache_get.assert_called_once_with(auth_token, scope)
|
||||
mock_query_db.assert_not_called()
|
||||
|
||||
def test_should_query_db_when_lock_acquired_and_cache_missed(self):
|
||||
"""Test single-flight queries DB when cache remains empty after lock acquisition."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None),
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_should_query_db_directly_when_lock_not_acquired(self):
|
||||
"""Test lock timeout branch falls back to direct DB query."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = False
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get") as mock_cache_get,
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_cache_get.assert_not_called()
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
lock.release.assert_not_called()
|
||||
|
||||
def test_should_reraise_unauthorized_from_db_query(self):
|
||||
"""Test Unauthorized from DB query is propagated unchanged."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
lock = MagicMock()
|
||||
lock.acquire.return_value = True
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module.ApiTokenCache, "get", return_value=None),
|
||||
patch.object(
|
||||
api_token_service_module,
|
||||
"query_token_from_db",
|
||||
side_effect=Unauthorized("Access token is invalid"),
|
||||
),
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(Unauthorized, match="Access token is invalid"):
|
||||
api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
lock.release.assert_called_once()
|
||||
|
||||
def test_should_fallback_to_db_query_when_lock_raises_exception(self):
|
||||
"""Test Redis lock errors fall back to direct DB query."""
|
||||
# Arrange
|
||||
auth_token = "token-123"
|
||||
scope = "app"
|
||||
db_token = MagicMock()
|
||||
|
||||
lock = MagicMock()
|
||||
lock.acquire.side_effect = RuntimeError("redis lock error")
|
||||
|
||||
with (
|
||||
patch.object(api_token_service_module, "redis_client") as mock_redis,
|
||||
patch.object(api_token_service_module, "query_token_from_db", return_value=db_token) as mock_query_db,
|
||||
):
|
||||
mock_redis.lock.return_value = lock
|
||||
|
||||
# Act
|
||||
result = api_token_service_module.fetch_token_with_single_flight(auth_token, scope)
|
||||
|
||||
# Assert
|
||||
assert result == db_token
|
||||
mock_query_db.assert_called_once_with(auth_token, scope)
|
||||
|
||||
|
||||
class TestApiTokenCacheTenantBranches:
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_remove_from_tenant_index_when_tenant_found(self, mock_redis):
|
||||
"""Test scoped delete removes cache key and tenant index membership."""
|
||||
# Arrange
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
cache_key = ApiTokenCache._make_cache_key(token, scope)
|
||||
cached_token = CachedApiToken(
|
||||
id="id-1",
|
||||
app_id="app-1",
|
||||
tenant_id="tenant-1",
|
||||
type="app",
|
||||
token=token,
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
mock_redis.get.return_value = cached_token.model_dump_json().encode("utf-8")
|
||||
|
||||
with patch.object(ApiTokenCache, "_remove_from_tenant_index") as mock_remove_index:
|
||||
# Act
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with(cache_key)
|
||||
mock_remove_index.assert_called_once_with("tenant-1", cache_key)
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_delete_all_indexed_cache_keys(self, mock_redis):
|
||||
"""Test tenant invalidation deletes indexed cache entries and index key."""
|
||||
# Arrange
|
||||
tenant_id = "tenant-1"
|
||||
index_key = ApiTokenCache._make_tenant_index_key(tenant_id)
|
||||
mock_redis.smembers.return_value = {
|
||||
b"api_token:app:token-1",
|
||||
b"api_token:any:token-2",
|
||||
}
|
||||
|
||||
# Act
|
||||
result = ApiTokenCache.invalidate_by_tenant(tenant_id)
|
||||
|
||||
# Assert
|
||||
assert result is True
|
||||
mock_redis.smembers.assert_called_once_with(index_key)
|
||||
mock_redis.delete.assert_any_call("api_token:app:token-1")
|
||||
mock_redis.delete.assert_any_call("api_token:any:token-2")
|
||||
mock_redis.delete.assert_any_call(index_key)
|
||||
|
||||
|
||||
class TestApiTokenCacheCoreBranches:
|
||||
def test_cached_api_token_repr_should_include_id_and_type(self):
|
||||
"""Test CachedApiToken __repr__ includes key identity fields."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
assert repr(token) == "<CachedApiToken id=id-123 type=app>"
|
||||
|
||||
def test_serialize_token_should_handle_cached_api_token_instances(self):
|
||||
"""Test serialization path when input is already a CachedApiToken."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
|
||||
serialized = ApiTokenCache._serialize_token(token)
|
||||
|
||||
assert isinstance(serialized, bytes)
|
||||
assert b'"id":"id-123"' in serialized
|
||||
assert b'"token":"token-123"' in serialized
|
||||
|
||||
def test_deserialize_token_should_return_none_for_null_markers(self):
|
||||
"""Test null cache marker deserializes to None."""
|
||||
assert ApiTokenCache._deserialize_token("null") is None
|
||||
assert ApiTokenCache._deserialize_token(b"null") is None
|
||||
|
||||
def test_deserialize_token_should_return_none_for_invalid_payload(self):
|
||||
"""Test invalid serialized payload returns None."""
|
||||
assert ApiTokenCache._deserialize_token("not-json") is None
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_get_should_return_none_on_cache_miss(self, mock_redis):
|
||||
"""Test cache miss branch in ApiTokenCache.get."""
|
||||
mock_redis.get.return_value = None
|
||||
|
||||
result = ApiTokenCache.get("token-123", "app")
|
||||
|
||||
assert result is None
|
||||
mock_redis.get.assert_called_once_with("api_token:app:token-123")
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_get_should_deserialize_cached_payload_on_cache_hit(self, mock_redis):
|
||||
"""Test cache hit branch in ApiTokenCache.get."""
|
||||
token = CachedApiToken(
|
||||
id="id-123",
|
||||
app_id="app-123",
|
||||
tenant_id="tenant-123",
|
||||
type="app",
|
||||
token="token-123",
|
||||
last_used_at=None,
|
||||
created_at=None,
|
||||
)
|
||||
mock_redis.get.return_value = token.model_dump_json().encode("utf-8")
|
||||
|
||||
result = ApiTokenCache.get("token-123", "app")
|
||||
|
||||
assert isinstance(result, CachedApiToken)
|
||||
assert result.id == "id-123"
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_add_to_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis):
|
||||
"""Test tenant index update exits early for missing tenant id."""
|
||||
ApiTokenCache._add_to_tenant_index(None, "api_token:app:token-123")
|
||||
|
||||
mock_redis.sadd.assert_not_called()
|
||||
mock_redis.expire.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_add_to_tenant_index_should_swallow_index_update_errors(self, mock_redis):
|
||||
"""Test tenant index update handles Redis write errors gracefully."""
|
||||
mock_redis.sadd.side_effect = Exception("redis down")
|
||||
|
||||
ApiTokenCache._add_to_tenant_index("tenant-123", "api_token:app:token-123")
|
||||
|
||||
mock_redis.sadd.assert_called_once()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_remove_from_tenant_index_should_skip_when_tenant_id_missing(self, mock_redis):
|
||||
"""Test tenant index removal exits early for missing tenant id."""
|
||||
ApiTokenCache._remove_from_tenant_index(None, "api_token:app:token-123")
|
||||
|
||||
mock_redis.srem.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_remove_from_tenant_index_should_swallow_redis_errors(self, mock_redis):
|
||||
"""Test tenant index removal handles Redis errors gracefully."""
|
||||
mock_redis.srem.side_effect = Exception("redis down")
|
||||
|
||||
ApiTokenCache._remove_from_tenant_index("tenant-123", "api_token:app:token-123")
|
||||
|
||||
mock_redis.srem.assert_called_once()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_set_should_return_false_when_cache_write_raises_exception(self, mock_redis):
|
||||
"""Test set returns False when Redis setex fails."""
|
||||
mock_redis.setex.side_effect = Exception("redis write failed")
|
||||
api_token = MagicMock()
|
||||
api_token.id = "id-123"
|
||||
api_token.app_id = "app-123"
|
||||
api_token.tenant_id = "tenant-123"
|
||||
api_token.type = "app"
|
||||
api_token.token = "token-123"
|
||||
api_token.last_used_at = None
|
||||
api_token.created_at = None
|
||||
|
||||
result = ApiTokenCache.set("token-123", "app", api_token)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_without_scope_should_return_false_when_scan_fails(self, mock_redis):
|
||||
"""Test delete(scope=None) returns False when scan_iter raises."""
|
||||
mock_redis.scan_iter.side_effect = Exception("scan failed")
|
||||
|
||||
result = ApiTokenCache.delete("token-123", None)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_continue_when_tenant_lookup_raises(self, mock_redis):
|
||||
"""Test scoped delete still succeeds when tenant lookup from cache fails."""
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
cache_key = ApiTokenCache._make_cache_key(token, scope)
|
||||
mock_redis.get.side_effect = Exception("get failed")
|
||||
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_called_once_with(cache_key)
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_delete_with_scope_should_return_false_when_delete_raises(self, mock_redis):
|
||||
"""Test scoped delete returns False when delete operation fails."""
|
||||
token = "token-123"
|
||||
scope = "app"
|
||||
mock_redis.get.return_value = None
|
||||
mock_redis.delete.side_effect = Exception("delete failed")
|
||||
|
||||
result = ApiTokenCache.delete(token, scope)
|
||||
|
||||
assert result is False
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_return_true_when_index_not_found(self, mock_redis):
|
||||
"""Test tenant invalidation returns True when tenant index is empty."""
|
||||
mock_redis.smembers.return_value = set()
|
||||
|
||||
result = ApiTokenCache.invalidate_by_tenant("tenant-123")
|
||||
|
||||
assert result is True
|
||||
mock_redis.delete.assert_not_called()
|
||||
|
||||
@patch("services.api_token_service.redis_client")
|
||||
def test_invalidate_by_tenant_should_return_false_when_redis_raises(self, mock_redis):
|
||||
"""Test tenant invalidation returns False when Redis operation fails."""
|
||||
mock_redis.smembers.side_effect = Exception("redis failed")
|
||||
|
||||
result = ApiTokenCache.invalidate_by_tenant("tenant-123")
|
||||
|
||||
assert result is False
|
||||
@@ -1,88 +0,0 @@
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from models.model import AppMode
|
||||
from services.app_model_config_service import AppModelConfigService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_config_managers():
|
||||
"""Fixture that patches all app config manager validate methods.
|
||||
|
||||
Returns a dictionary containing the mocked config_validate methods for each manager.
|
||||
"""
|
||||
with (
|
||||
patch("services.app_model_config_service.ChatAppConfigManager.config_validate") as mock_chat_validate,
|
||||
patch("services.app_model_config_service.AgentChatAppConfigManager.config_validate") as mock_agent_validate,
|
||||
patch(
|
||||
"services.app_model_config_service.CompletionAppConfigManager.config_validate"
|
||||
) as mock_completion_validate,
|
||||
):
|
||||
mock_chat_validate.return_value = {"manager": "chat"}
|
||||
mock_agent_validate.return_value = {"manager": "agent"}
|
||||
mock_completion_validate.return_value = {"manager": "completion"}
|
||||
|
||||
yield {
|
||||
"chat": mock_chat_validate,
|
||||
"agent": mock_agent_validate,
|
||||
"completion": mock_completion_validate,
|
||||
}
|
||||
|
||||
|
||||
class TestAppModelConfigService:
|
||||
@pytest.mark.parametrize(
|
||||
("app_mode", "selected_manager"),
|
||||
[
|
||||
(AppMode.CHAT, "chat"),
|
||||
(AppMode.AGENT_CHAT, "agent"),
|
||||
(AppMode.COMPLETION, "completion"),
|
||||
],
|
||||
)
|
||||
def test_should_route_validation_to_correct_manager_based_on_app_mode(
|
||||
self, app_mode, selected_manager, mock_config_managers
|
||||
):
|
||||
"""Test configuration validation is delegated to the expected manager for each supported app mode."""
|
||||
tenant_id = "tenant-123"
|
||||
config = {"temperature": 0.5}
|
||||
|
||||
mock_chat_validate = mock_config_managers["chat"]
|
||||
mock_agent_validate = mock_config_managers["agent"]
|
||||
mock_completion_validate = mock_config_managers["completion"]
|
||||
|
||||
result = AppModelConfigService.validate_configuration(tenant_id=tenant_id, config=config, app_mode=app_mode)
|
||||
|
||||
assert result == {"manager": selected_manager}
|
||||
|
||||
if selected_manager == "chat":
|
||||
mock_chat_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_agent_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
elif selected_manager == "agent":
|
||||
mock_agent_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
else:
|
||||
mock_completion_validate.assert_called_once_with(tenant_id, config)
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_agent_validate.assert_not_called()
|
||||
|
||||
def test_should_raise_value_error_when_app_mode_is_not_supported(self, mock_config_managers):
|
||||
"""Test unsupported app modes raise ValueError with the invalid mode in the message."""
|
||||
tenant_id = "tenant-123"
|
||||
config = {"temperature": 0.5}
|
||||
|
||||
mock_chat_validate = mock_config_managers["chat"]
|
||||
mock_agent_validate = mock_config_managers["agent"]
|
||||
mock_completion_validate = mock_config_managers["completion"]
|
||||
|
||||
with pytest.raises(ValueError, match=f"Invalid app mode: {AppMode.WORKFLOW}"):
|
||||
AppModelConfigService.validate_configuration(
|
||||
tenant_id=tenant_id,
|
||||
config=config,
|
||||
app_mode=AppMode.WORKFLOW,
|
||||
)
|
||||
|
||||
mock_chat_validate.assert_not_called()
|
||||
mock_agent_validate.assert_not_called()
|
||||
mock_completion_validate.assert_not_called()
|
||||
@@ -1,507 +0,0 @@
|
||||
import json
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.async_workflow_service as async_workflow_service_module
|
||||
from models.enums import AppTriggerType, CreatorUserRole, WorkflowRunTriggeredFrom, WorkflowTriggerStatus
|
||||
from services.async_workflow_service import AsyncWorkflowService
|
||||
from services.errors.app import QuotaExceededError, WorkflowNotFoundError, WorkflowQuotaLimitError
|
||||
from services.workflow.entities import AsyncTriggerResponse, TriggerData
|
||||
from services.workflow.queue_dispatcher import QueuePriority
|
||||
|
||||
|
||||
class AsyncWorkflowServiceTestDataFactory:
|
||||
"""Factory helpers for async workflow service unit tests."""
|
||||
|
||||
@staticmethod
|
||||
def create_trigger_data(
|
||||
app_id: str = "app-123",
|
||||
tenant_id: str = "tenant-123",
|
||||
workflow_id: str | None = "workflow-123",
|
||||
root_node_id: str = "root-node-123",
|
||||
) -> TriggerData:
|
||||
"""Create valid trigger data for async workflow execution tests."""
|
||||
return TriggerData(
|
||||
app_id=app_id,
|
||||
tenant_id=tenant_id,
|
||||
workflow_id=workflow_id,
|
||||
root_node_id=root_node_id,
|
||||
inputs={"name": "dify"},
|
||||
files=[],
|
||||
trigger_type=AppTriggerType.UNKNOWN,
|
||||
trigger_from=WorkflowRunTriggeredFrom.APP_RUN,
|
||||
trigger_metadata=None,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_trigger_log_with_data(trigger_data: TriggerData, retry_count: int = 0) -> MagicMock:
|
||||
"""Create a mock trigger log with serialized trigger data."""
|
||||
trigger_log = MagicMock()
|
||||
trigger_log.id = "trigger-log-123"
|
||||
trigger_log.trigger_data = trigger_data.model_dump_json()
|
||||
trigger_log.retry_count = retry_count
|
||||
trigger_log.error = "previous-error"
|
||||
trigger_log.status = WorkflowTriggerStatus.FAILED
|
||||
trigger_log.to_dict.return_value = {"id": trigger_log.id}
|
||||
return trigger_log
|
||||
|
||||
|
||||
class TestAsyncWorkflowService:
|
||||
@pytest.fixture
|
||||
def async_workflow_trigger_mocks(self):
|
||||
"""Shared fixture for async workflow trigger tests.
|
||||
|
||||
Yields mocks for:
|
||||
- repo: SQLAlchemyWorkflowTriggerLogRepository
|
||||
- dispatcher_manager_class: QueueDispatcherManager class
|
||||
- dispatcher: dispatcher instance
|
||||
- quota_workflow: QuotaType.WORKFLOW
|
||||
- get_workflow: AsyncWorkflowService._get_workflow method
|
||||
- professional_task: execute_workflow_professional
|
||||
- team_task: execute_workflow_team
|
||||
- sandbox_task: execute_workflow_sandbox
|
||||
"""
|
||||
mock_repo = MagicMock()
|
||||
|
||||
def _create_side_effect(new_log):
|
||||
new_log.id = "trigger-log-123"
|
||||
return new_log
|
||||
|
||||
mock_repo.create.side_effect = _create_side_effect
|
||||
|
||||
mock_dispatcher = MagicMock()
|
||||
quota_workflow = MagicMock()
|
||||
mock_get_workflow = MagicMock()
|
||||
|
||||
mock_professional_task = MagicMock()
|
||||
mock_team_task = MagicMock()
|
||||
mock_sandbox_task = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
patch.object(async_workflow_service_module, "QueueDispatcherManager") as mock_dispatcher_manager_class,
|
||||
patch.object(async_workflow_service_module, "WorkflowService"),
|
||||
patch.object(
|
||||
async_workflow_service_module.AsyncWorkflowService,
|
||||
"_get_workflow",
|
||||
) as mock_get_workflow,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"QuotaType",
|
||||
new=SimpleNamespace(WORKFLOW=quota_workflow),
|
||||
),
|
||||
patch.object(async_workflow_service_module, "execute_workflow_professional") as mock_professional_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_team") as mock_team_task,
|
||||
patch.object(async_workflow_service_module, "execute_workflow_sandbox") as mock_sandbox_task,
|
||||
):
|
||||
# Configure dispatcher_manager to return our mock_dispatcher
|
||||
mock_dispatcher_manager_class.return_value.get_dispatcher.return_value = mock_dispatcher
|
||||
|
||||
yield {
|
||||
"repo": mock_repo,
|
||||
"dispatcher_manager_class": mock_dispatcher_manager_class,
|
||||
"dispatcher": mock_dispatcher,
|
||||
"quota_workflow": quota_workflow,
|
||||
"get_workflow": mock_get_workflow,
|
||||
"professional_task": mock_professional_task,
|
||||
"team_task": mock_team_task,
|
||||
"sandbox_task": mock_sandbox_task,
|
||||
}
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("queue_name", "selected_task_attr"),
|
||||
[
|
||||
(QueuePriority.PROFESSIONAL, "execute_workflow_professional"),
|
||||
(QueuePriority.TEAM, "execute_workflow_team"),
|
||||
(QueuePriority.SANDBOX, "execute_workflow_sandbox"),
|
||||
],
|
||||
)
|
||||
def test_should_dispatch_to_matching_celery_task_when_triggering_workflow(
|
||||
self, queue_name, selected_task_attr, async_workflow_trigger_mocks
|
||||
):
|
||||
"""Test queue-based task routing and successful async trigger response."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = queue_name
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
|
||||
task_result = MagicMock()
|
||||
task_result.id = "task-123"
|
||||
mocks["professional_task"].delay.return_value = task_result
|
||||
mocks["team_task"].delay.return_value = task_result
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
class DummyAccount:
|
||||
def __init__(self, user_id: str):
|
||||
self.id = user_id
|
||||
|
||||
with patch.object(async_workflow_service_module, "Account", DummyAccount):
|
||||
user = DummyAccount("account-123")
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data)
|
||||
|
||||
# Assert
|
||||
assert isinstance(result, AsyncTriggerResponse)
|
||||
assert result.workflow_trigger_log_id == "trigger-log-123"
|
||||
assert result.task_id == "task-123"
|
||||
assert result.status == "queued"
|
||||
assert result.queue == queue_name
|
||||
|
||||
mocks["quota_workflow"].consume.assert_called_once_with("tenant-123")
|
||||
assert session.commit.call_count == 2
|
||||
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
assert created_log.status == WorkflowTriggerStatus.QUEUED
|
||||
assert created_log.queue_name == queue_name
|
||||
assert created_log.created_by_role == CreatorUserRole.ACCOUNT
|
||||
assert created_log.created_by == "account-123"
|
||||
assert created_log.trigger_data == trigger_data.model_dump_json()
|
||||
assert created_log.inputs == json.dumps(dict(trigger_data.inputs))
|
||||
assert created_log.celery_task_id == "task-123"
|
||||
|
||||
task_mocks = {
|
||||
"execute_workflow_professional": mocks["professional_task"],
|
||||
"execute_workflow_team": mocks["team_task"],
|
||||
"execute_workflow_sandbox": mocks["sandbox_task"],
|
||||
}
|
||||
for task_attr, task_mock in task_mocks.items():
|
||||
if task_attr == selected_task_attr:
|
||||
task_mock.delay.assert_called_once_with({"workflow_trigger_log_id": "trigger-log-123"})
|
||||
else:
|
||||
task_mock.delay.assert_not_called()
|
||||
|
||||
def test_should_set_end_user_role_when_triggered_by_end_user(self, async_workflow_trigger_mocks):
|
||||
"""Test that non-account users are tracked as END_USER in trigger logs."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.SANDBOX
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
|
||||
task_result = MagicMock(id="task-123")
|
||||
mocks["sandbox_task"].delay.return_value = task_result
|
||||
|
||||
user = SimpleNamespace(id="end-user-123")
|
||||
|
||||
# Act
|
||||
AsyncWorkflowService.trigger_workflow_async(session=session, user=user, trigger_data=trigger_data)
|
||||
|
||||
# Assert
|
||||
created_log = mocks["repo"].create.call_args[0][0]
|
||||
assert created_log.created_by_role == CreatorUserRole.END_USER
|
||||
assert created_log.created_by == "end-user-123"
|
||||
|
||||
def test_should_raise_workflow_not_found_when_app_does_not_exist(self):
|
||||
"""Test trigger failure when app lookup returns no result."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data(app_id="missing-app")
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository"),
|
||||
patch.object(async_workflow_service_module, "QueueDispatcherManager"),
|
||||
patch.object(async_workflow_service_module, "WorkflowService"),
|
||||
):
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="App not found: missing-app"):
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
def test_should_mark_log_rate_limited_and_raise_when_quota_exceeded(self, async_workflow_trigger_mocks):
|
||||
"""Test quota-exceeded path updates trigger log and raises WorkflowQuotaLimitError."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
session.commit = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
session.scalar.return_value = app_model
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
workflow = MagicMock()
|
||||
workflow.id = "workflow-123"
|
||||
|
||||
mocks = async_workflow_trigger_mocks
|
||||
mocks["dispatcher"].get_queue_name.return_value = QueuePriority.TEAM
|
||||
mocks["get_workflow"].return_value = workflow
|
||||
mocks["quota_workflow"].consume.side_effect = QuotaExceededError(
|
||||
feature="workflow",
|
||||
tenant_id="tenant-123",
|
||||
required=1,
|
||||
)
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(
|
||||
WorkflowQuotaLimitError,
|
||||
match="Workflow execution quota limit reached for tenant tenant-123",
|
||||
):
|
||||
AsyncWorkflowService.trigger_workflow_async(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
trigger_data=trigger_data,
|
||||
)
|
||||
|
||||
assert session.commit.call_count == 2
|
||||
updated_log = mocks["repo"].update.call_args[0][0]
|
||||
assert updated_log.status == WorkflowTriggerStatus.RATE_LIMITED
|
||||
assert "Quota limit reached" in updated_log.error
|
||||
mocks["professional_task"].delay.assert_not_called()
|
||||
mocks["team_task"].delay.assert_not_called()
|
||||
mocks["sandbox_task"].delay.assert_not_called()
|
||||
|
||||
def test_should_raise_when_reinvoke_target_log_does_not_exist(self):
|
||||
"""Test reinvoke_trigger error path when original trigger log is missing."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
repo = MagicMock()
|
||||
repo.get_by_id.return_value = None
|
||||
|
||||
with patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo):
|
||||
# Act / Assert
|
||||
with pytest.raises(ValueError, match="Trigger log not found: missing-log"):
|
||||
AsyncWorkflowService.reinvoke_trigger(
|
||||
session=session,
|
||||
user=SimpleNamespace(id="user-123"),
|
||||
workflow_trigger_log_id="missing-log",
|
||||
)
|
||||
|
||||
def test_should_update_original_log_and_requeue_when_reinvoking(self):
|
||||
"""Test reinvoke flow updates original log state and triggers a new async run."""
|
||||
# Arrange
|
||||
session = MagicMock()
|
||||
trigger_data = AsyncWorkflowServiceTestDataFactory.create_trigger_data()
|
||||
trigger_log = AsyncWorkflowServiceTestDataFactory.create_trigger_log_with_data(trigger_data, retry_count=1)
|
||||
repo = MagicMock()
|
||||
repo.get_by_id.return_value = trigger_log
|
||||
|
||||
expected_response = AsyncTriggerResponse(
|
||||
workflow_trigger_log_id="new-trigger-log-456",
|
||||
task_id="task-456",
|
||||
status="queued",
|
||||
queue=QueuePriority.TEAM,
|
||||
)
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "SQLAlchemyWorkflowTriggerLogRepository", return_value=repo),
|
||||
patch.object(
|
||||
async_workflow_service_module.AsyncWorkflowService,
|
||||
"trigger_workflow_async",
|
||||
return_value=expected_response,
|
||||
) as mock_trigger_workflow_async,
|
||||
):
|
||||
user = SimpleNamespace(id="user-123")
|
||||
|
||||
# Act
|
||||
response = AsyncWorkflowService.reinvoke_trigger(
|
||||
session=session,
|
||||
user=user,
|
||||
workflow_trigger_log_id="trigger-log-123",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert response == expected_response
|
||||
assert trigger_log.status == WorkflowTriggerStatus.RETRYING
|
||||
assert trigger_log.retry_count == 2
|
||||
assert trigger_log.error is None
|
||||
assert trigger_log.triggered_at is not None
|
||||
repo.update.assert_called_once_with(trigger_log)
|
||||
session.commit.assert_called_once()
|
||||
called_trigger_data = mock_trigger_workflow_async.call_args[0][2]
|
||||
assert isinstance(called_trigger_data, TriggerData)
|
||||
assert called_trigger_data.app_id == "app-123"
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("repo_result", "expected"),
|
||||
[
|
||||
(None, None),
|
||||
(MagicMock(), {"id": "trigger-log-123"}),
|
||||
],
|
||||
)
|
||||
def test_should_return_trigger_log_dict_or_none(self, repo_result, expected):
|
||||
"""Test get_trigger_log returns serialized log data or None."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
fake_engine = MagicMock()
|
||||
mock_repo.get_by_id.return_value = repo_result
|
||||
if repo_result:
|
||||
repo_result.to_dict.return_value = expected
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=fake_engine)),
|
||||
patch.object(
|
||||
async_workflow_service_module, "Session", return_value=mock_session_context
|
||||
) as mock_session_class,
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_trigger_log("trigger-log-123", tenant_id="tenant-123")
|
||||
|
||||
# Assert
|
||||
assert result == expected
|
||||
mock_session_class.assert_called_once_with(fake_engine)
|
||||
mock_repo.get_by_id.assert_called_once_with("trigger-log-123", "tenant-123")
|
||||
|
||||
def test_should_return_recent_logs_as_dict_list(self):
|
||||
"""Test get_recent_logs converts repository models into dictionaries."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
log1 = MagicMock()
|
||||
log1.to_dict.return_value = {"id": "log-1"}
|
||||
log2 = MagicMock()
|
||||
log2.to_dict.return_value = {"id": "log-2"}
|
||||
mock_repo.get_recent_logs.return_value = [log1, log2]
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_recent_logs(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
hours=12,
|
||||
limit=50,
|
||||
offset=10,
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == [{"id": "log-1"}, {"id": "log-2"}]
|
||||
mock_repo.get_recent_logs.assert_called_once_with(
|
||||
tenant_id="tenant-123",
|
||||
app_id="app-123",
|
||||
hours=12,
|
||||
limit=50,
|
||||
offset=10,
|
||||
)
|
||||
|
||||
def test_should_return_failed_logs_for_retry_as_dict_list(self):
|
||||
"""Test get_failed_logs_for_retry serializes repository logs into dicts."""
|
||||
# Arrange
|
||||
mock_session = MagicMock()
|
||||
mock_repo = MagicMock()
|
||||
log = MagicMock()
|
||||
log.to_dict.return_value = {"id": "failed-log-1"}
|
||||
mock_repo.get_failed_for_retry.return_value = [log]
|
||||
|
||||
mock_session_context = MagicMock()
|
||||
mock_session_context.__enter__.return_value = mock_session
|
||||
mock_session_context.__exit__.return_value = None
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_service_module, "db", new=SimpleNamespace(engine=MagicMock())),
|
||||
patch.object(async_workflow_service_module, "Session", return_value=mock_session_context),
|
||||
patch.object(
|
||||
async_workflow_service_module,
|
||||
"SQLAlchemyWorkflowTriggerLogRepository",
|
||||
return_value=mock_repo,
|
||||
),
|
||||
):
|
||||
# Act
|
||||
result = AsyncWorkflowService.get_failed_logs_for_retry(tenant_id="tenant-123", max_retry_count=4, limit=20)
|
||||
|
||||
# Assert
|
||||
assert result == [{"id": "failed-log-1"}]
|
||||
mock_repo.get_failed_for_retry.assert_called_once_with(tenant_id="tenant-123", max_retry_count=4, limit=20)
|
||||
|
||||
|
||||
class TestAsyncWorkflowServiceGetWorkflow:
|
||||
def test_should_return_specific_workflow_when_workflow_id_exists(self):
|
||||
"""Test _get_workflow returns published workflow by id when provided."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
workflow = MagicMock()
|
||||
workflow_service.get_published_workflow_by_id.return_value = workflow
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-123")
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow_by_id.assert_called_once_with(app_model, "workflow-123")
|
||||
workflow_service.get_published_workflow.assert_not_called()
|
||||
|
||||
def test_should_raise_when_specific_workflow_id_not_found(self):
|
||||
"""Test _get_workflow raises WorkflowNotFoundError for unknown workflow id."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
workflow_service.get_published_workflow_by_id.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="Published workflow not found: workflow-404"):
|
||||
AsyncWorkflowService._get_workflow(workflow_service, app_model, workflow_id="workflow-404")
|
||||
|
||||
def test_should_return_default_published_workflow_when_workflow_id_not_provided(self):
|
||||
"""Test _get_workflow returns default published workflow when no id is provided."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
workflow = MagicMock()
|
||||
workflow_service.get_published_workflow.return_value = workflow
|
||||
|
||||
# Act
|
||||
result = AsyncWorkflowService._get_workflow(workflow_service, app_model)
|
||||
|
||||
# Assert
|
||||
assert result == workflow
|
||||
workflow_service.get_published_workflow.assert_called_once_with(app_model)
|
||||
workflow_service.get_published_workflow_by_id.assert_not_called()
|
||||
|
||||
def test_should_raise_when_default_published_workflow_not_found(self):
|
||||
"""Test _get_workflow raises WorkflowNotFoundError when app has no published workflow."""
|
||||
# Arrange
|
||||
workflow_service = MagicMock()
|
||||
app_model = MagicMock()
|
||||
app_model.id = "app-123"
|
||||
workflow_service.get_published_workflow.return_value = None
|
||||
|
||||
# Act / Assert
|
||||
with pytest.raises(WorkflowNotFoundError, match="No published workflow found for app: app-123"):
|
||||
AsyncWorkflowService._get_workflow(workflow_service, app_model)
|
||||
@@ -1,73 +0,0 @@
|
||||
import base64
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import create_engine
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
from werkzeug.exceptions import NotFound
|
||||
|
||||
import services.attachment_service as attachment_service_module
|
||||
from models.model import UploadFile
|
||||
from services.attachment_service import AttachmentService
|
||||
|
||||
|
||||
class TestAttachmentService:
|
||||
def test_should_initialize_with_sessionmaker_when_sessionmaker_is_provided(self):
|
||||
"""Test that AttachmentService keeps the provided sessionmaker instance."""
|
||||
session_factory = sessionmaker()
|
||||
|
||||
service = AttachmentService(session_factory=session_factory)
|
||||
|
||||
assert service._session_maker is session_factory
|
||||
|
||||
def test_should_initialize_with_bound_sessionmaker_when_engine_is_provided(self):
|
||||
"""Test that AttachmentService builds a sessionmaker bound to the provided engine."""
|
||||
engine = create_engine("sqlite:///:memory:")
|
||||
|
||||
service = AttachmentService(session_factory=engine)
|
||||
session = service._session_maker()
|
||||
try:
|
||||
assert session.bind == engine
|
||||
finally:
|
||||
session.close()
|
||||
engine.dispose()
|
||||
|
||||
@pytest.mark.parametrize("invalid_session_factory", [None, "not-a-session-factory", 1])
|
||||
def test_should_raise_assertion_error_when_session_factory_type_is_invalid(self, invalid_session_factory):
|
||||
"""Test that invalid session_factory types are rejected."""
|
||||
with pytest.raises(AssertionError, match="must be a sessionmaker or an Engine."):
|
||||
AttachmentService(session_factory=invalid_session_factory)
|
||||
|
||||
def test_should_return_base64_encoded_blob_when_file_exists(self):
|
||||
"""Test that existing files are loaded from storage and returned as base64."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
upload_file = MagicMock(spec=UploadFile)
|
||||
upload_file.key = "upload-file-key"
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = upload_file
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once", return_value=b"binary-content") as mock_load:
|
||||
result = service.get_file_base64("file-123")
|
||||
|
||||
assert result == base64.b64encode(b"binary-content").decode()
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_called_once_with("upload-file-key")
|
||||
|
||||
def test_should_raise_not_found_when_file_does_not_exist(self):
|
||||
"""Test that missing files raise NotFound and never call storage."""
|
||||
service = AttachmentService(session_factory=sessionmaker())
|
||||
|
||||
session = MagicMock()
|
||||
session.query.return_value.where.return_value.first.return_value = None
|
||||
service._session_maker = MagicMock(return_value=session)
|
||||
|
||||
with patch.object(attachment_service_module.storage, "load_once") as mock_load:
|
||||
with pytest.raises(NotFound, match="File not found"):
|
||||
service.get_file_base64("missing-file")
|
||||
|
||||
service._session_maker.assert_called_once_with(expire_on_commit=False)
|
||||
session.query.assert_called_once_with(UploadFile)
|
||||
mock_load.assert_not_called()
|
||||
@@ -1,89 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
|
||||
from services.code_based_extension_service import CodeBasedExtensionService
|
||||
|
||||
|
||||
class TestCodeBasedExtensionService:
|
||||
def test_should_return_only_non_builtin_extensions_with_public_fields(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test service returns only non-builtin extensions with name/label/form_schema fields."""
|
||||
moderation_extension = SimpleNamespace(
|
||||
name="custom-moderation",
|
||||
label={"en-US": "Custom Moderation"},
|
||||
form_schema=[{"variable": "api_key"}],
|
||||
builtin=False,
|
||||
extension_class=object,
|
||||
position=20,
|
||||
)
|
||||
builtin_extension = SimpleNamespace(
|
||||
name="builtin-moderation",
|
||||
label={"en-US": "Builtin Moderation"},
|
||||
form_schema=[{"variable": "token"}],
|
||||
builtin=True,
|
||||
extension_class=object,
|
||||
position=1,
|
||||
)
|
||||
retrieval_extension = SimpleNamespace(
|
||||
name="custom-retrieval",
|
||||
label={"en-US": "Custom Retrieval"},
|
||||
form_schema=None,
|
||||
builtin=False,
|
||||
extension_class=object,
|
||||
position=30,
|
||||
)
|
||||
module_extensions_mock = MagicMock(return_value=[moderation_extension, builtin_extension, retrieval_extension])
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
result = CodeBasedExtensionService.get_code_based_extension("external_data_tool")
|
||||
|
||||
assert result == [
|
||||
{
|
||||
"name": "custom-moderation",
|
||||
"label": {"en-US": "Custom Moderation"},
|
||||
"form_schema": [{"variable": "api_key"}],
|
||||
},
|
||||
{
|
||||
"name": "custom-retrieval",
|
||||
"label": {"en-US": "Custom Retrieval"},
|
||||
"form_schema": None,
|
||||
},
|
||||
]
|
||||
assert set(result[0].keys()) == {"name", "label", "form_schema"}
|
||||
module_extensions_mock.assert_called_once_with("external_data_tool")
|
||||
|
||||
def test_should_return_empty_list_when_all_extensions_are_builtin(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test builtin extensions are filtered out completely."""
|
||||
builtin_extension = SimpleNamespace(
|
||||
name="builtin-moderation",
|
||||
label={"en-US": "Builtin Moderation"},
|
||||
form_schema=[{"variable": "token"}],
|
||||
builtin=True,
|
||||
)
|
||||
module_extensions_mock = MagicMock(return_value=[builtin_extension])
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
result = CodeBasedExtensionService.get_code_based_extension("moderation")
|
||||
|
||||
assert result == []
|
||||
module_extensions_mock.assert_called_once_with("moderation")
|
||||
|
||||
def test_should_propagate_error_when_module_extensions_lookup_fails(self, monkeypatch: pytest.MonkeyPatch):
|
||||
"""Test ValueError from extension lookup bubbles up unchanged."""
|
||||
module_extensions_mock = MagicMock(side_effect=ValueError("Extension Module invalid-module not found"))
|
||||
monkeypatch.setattr(
|
||||
"services.code_based_extension_service.code_based_extension.module_extensions",
|
||||
module_extensions_mock,
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Extension Module invalid-module not found"):
|
||||
CodeBasedExtensionService.get_code_based_extension("invalid-module")
|
||||
|
||||
module_extensions_mock.assert_called_once_with("invalid-module")
|
||||
@@ -1,75 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock
|
||||
|
||||
import pytest
|
||||
from core.variables.variables import StringVariable
|
||||
|
||||
from services.conversation_variable_updater import ConversationVariableNotFoundError, ConversationVariableUpdater
|
||||
|
||||
|
||||
class TestConversationVariableUpdater:
|
||||
def test_should_update_conversation_variable_data_and_commit(self):
|
||||
"""Test update persists serialized variable data when the row exists."""
|
||||
conversation_id = "conv-123"
|
||||
variable = StringVariable(
|
||||
id="var-123",
|
||||
name="topic",
|
||||
value="new value",
|
||||
)
|
||||
expected_json = variable.model_dump_json()
|
||||
|
||||
row = SimpleNamespace(data="old value")
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = row
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session_maker.assert_called_once_with()
|
||||
session.scalar.assert_called_once()
|
||||
stmt = session.scalar.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert variable.id in compiled_params.values()
|
||||
assert conversation_id in compiled_params.values()
|
||||
assert row.data == expected_json
|
||||
session.commit.assert_called_once()
|
||||
|
||||
def test_should_raise_not_found_error_when_conversation_variable_missing(self):
|
||||
"""Test update raises ConversationVariableNotFoundError when no matching row exists."""
|
||||
conversation_id = "conv-404"
|
||||
variable = StringVariable(
|
||||
id="var-404",
|
||||
name="topic",
|
||||
value="value",
|
||||
)
|
||||
|
||||
session = MagicMock()
|
||||
session.scalar.return_value = None
|
||||
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
session_maker = MagicMock(return_value=session_context)
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
with pytest.raises(ConversationVariableNotFoundError, match="conversation variable not found in the database"):
|
||||
updater.update(conversation_id=conversation_id, variable=variable)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
|
||||
def test_should_do_nothing_when_flush_is_called(self):
|
||||
"""Test flush currently behaves as a no-op and returns None."""
|
||||
session_maker = MagicMock()
|
||||
updater = ConversationVariableUpdater(session_maker)
|
||||
|
||||
result = updater.flush()
|
||||
|
||||
assert result is None
|
||||
session_maker.assert_not_called()
|
||||
@@ -1,157 +0,0 @@
|
||||
from types import SimpleNamespace
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
import services.credit_pool_service as credit_pool_service_module
|
||||
from core.errors.error import QuotaExceededError
|
||||
from models import TenantCreditPool
|
||||
from services.credit_pool_service import CreditPoolService
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_credit_deduction_setup():
|
||||
"""Fixture providing common setup for credit deduction tests."""
|
||||
pool = SimpleNamespace(remaining_credits=50)
|
||||
fake_engine = MagicMock()
|
||||
session = MagicMock()
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
session_context.__exit__.return_value = None
|
||||
|
||||
mock_get_pool = patch.object(CreditPoolService, "get_pool", return_value=pool)
|
||||
mock_db = patch.object(credit_pool_service_module, "db", new=SimpleNamespace(engine=fake_engine))
|
||||
mock_session = patch.object(credit_pool_service_module, "Session", return_value=session_context)
|
||||
|
||||
return {
|
||||
"pool": pool,
|
||||
"fake_engine": fake_engine,
|
||||
"session": session,
|
||||
"session_context": session_context,
|
||||
"patches": (mock_get_pool, mock_db, mock_session),
|
||||
}
|
||||
|
||||
|
||||
class TestCreditPoolService:
|
||||
def test_should_create_default_pool_with_trial_type_and_configured_quota(self):
|
||||
"""Test create_default_pool persists a trial pool using configured hosted credits."""
|
||||
tenant_id = "tenant-123"
|
||||
hosted_pool_credits = 5000
|
||||
|
||||
with (
|
||||
patch.object(credit_pool_service_module.dify_config, "HOSTED_POOL_CREDITS", hosted_pool_credits),
|
||||
patch.object(credit_pool_service_module, "db") as mock_db,
|
||||
):
|
||||
pool = CreditPoolService.create_default_pool(tenant_id)
|
||||
|
||||
assert isinstance(pool, TenantCreditPool)
|
||||
assert pool.tenant_id == tenant_id
|
||||
assert pool.pool_type == "trial"
|
||||
assert pool.quota_limit == hosted_pool_credits
|
||||
assert pool.quota_used == 0
|
||||
mock_db.session.add.assert_called_once_with(pool)
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
def test_should_return_first_pool_from_query_when_get_pool_called(self):
|
||||
"""Test get_pool queries by tenant and pool_type and returns first result."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "enterprise"
|
||||
expected_pool = MagicMock(spec=TenantCreditPool)
|
||||
|
||||
with patch.object(credit_pool_service_module, "db") as mock_db:
|
||||
query = mock_db.session.query.return_value
|
||||
filtered_query = query.filter_by.return_value
|
||||
filtered_query.first.return_value = expected_pool
|
||||
|
||||
result = CreditPoolService.get_pool(tenant_id=tenant_id, pool_type=pool_type)
|
||||
|
||||
assert result == expected_pool
|
||||
mock_db.session.query.assert_called_once_with(TenantCreditPool)
|
||||
query.filter_by.assert_called_once_with(tenant_id=tenant_id, pool_type=pool_type)
|
||||
filtered_query.first.assert_called_once()
|
||||
|
||||
def test_should_return_false_when_pool_not_found_in_check_credits_available(self):
|
||||
"""Test check_credits_available returns False when tenant has no pool."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
assert result is False
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_true_when_remaining_credits_cover_required_amount(self):
|
||||
"""Test check_credits_available returns True when remaining credits are sufficient."""
|
||||
pool = SimpleNamespace(remaining_credits=100)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool) as mock_get_pool:
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is True
|
||||
mock_get_pool.assert_called_once_with("tenant-123", "trial")
|
||||
|
||||
def test_should_return_false_when_remaining_credits_are_insufficient(self):
|
||||
"""Test check_credits_available returns False when required credits exceed remaining credits."""
|
||||
pool = SimpleNamespace(remaining_credits=30)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
result = CreditPoolService.check_credits_available(tenant_id="tenant-123", credits_required=60)
|
||||
|
||||
assert result is False
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_not_found_in_check_and_deduct(self):
|
||||
"""Test check_and_deduct_credits raises when tenant credit pool does not exist."""
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=None):
|
||||
with pytest.raises(QuotaExceededError, match="Credit pool not found"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_raise_quota_exceeded_when_pool_has_no_remaining_credits(self):
|
||||
"""Test check_and_deduct_credits raises when remaining credits are zero or negative."""
|
||||
pool = SimpleNamespace(remaining_credits=0)
|
||||
|
||||
with patch.object(CreditPoolService, "get_pool", return_value=pool):
|
||||
with pytest.raises(QuotaExceededError, match="No credits remaining"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
def test_should_deduct_minimum_of_required_and_remaining_credits(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits updates quota_used by the actual deducted amount."""
|
||||
tenant_id = "tenant-123"
|
||||
pool_type = "trial"
|
||||
credits_required = 200
|
||||
remaining_credits = 120
|
||||
expected_deducted_credits = 120
|
||||
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = remaining_credits
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
with patches[0], patches[1], patches[2]:
|
||||
result = CreditPoolService.check_and_deduct_credits(
|
||||
tenant_id=tenant_id,
|
||||
credits_required=credits_required,
|
||||
pool_type=pool_type,
|
||||
)
|
||||
|
||||
assert result == expected_deducted_credits
|
||||
session.execute.assert_called_once()
|
||||
session.commit.assert_called_once()
|
||||
|
||||
stmt = session.execute.call_args.args[0]
|
||||
compiled_params = stmt.compile().params
|
||||
assert tenant_id in compiled_params.values()
|
||||
assert pool_type in compiled_params.values()
|
||||
assert expected_deducted_credits in compiled_params.values()
|
||||
|
||||
def test_should_raise_quota_exceeded_when_deduction_update_fails(self, mock_credit_deduction_setup):
|
||||
"""Test check_and_deduct_credits translates DB update failures to QuotaExceededError."""
|
||||
mock_credit_deduction_setup["pool"].remaining_credits = 50
|
||||
mock_credit_deduction_setup["session"].execute.side_effect = Exception("db failure")
|
||||
session = mock_credit_deduction_setup["session"]
|
||||
|
||||
patches = mock_credit_deduction_setup["patches"]
|
||||
mock_logger = patch.object(credit_pool_service_module, "logger")
|
||||
|
||||
with patches[0], patches[1], patches[2], mock_logger as mock_logger_obj:
|
||||
with pytest.raises(QuotaExceededError, match="Failed to deduct credits"):
|
||||
CreditPoolService.check_and_deduct_credits(tenant_id="tenant-123", credits_required=10)
|
||||
|
||||
session.commit.assert_not_called()
|
||||
mock_logger_obj.exception.assert_called_once()
|
||||
@@ -7,7 +7,7 @@ import pytest
|
||||
from core.mcp.types import Tool as MCPTool
|
||||
from core.tools.entities.api_entities import ToolApiEntity, ToolProviderApiEntity
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
from core.tools.entities.tool_entities import ToolParameter, ToolProviderType
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from models.tools import MCPToolProvider
|
||||
from services.tools.tools_transform_service import ToolTransformService
|
||||
|
||||
@@ -175,137 +175,6 @@ class TestMCPToolTransform:
|
||||
# The actual parameter conversion is handled by convert_mcp_schema_to_parameter
|
||||
# which should be tested separately
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_anyof_object_type(self):
|
||||
"""Nullable object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"anyOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_oneof_object_type(self):
|
||||
"""Nullable oneOf object schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"oneOf": [{"type": "object"}, {"type": "null"}],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_handles_null_type(self):
|
||||
"""Schemas with only a null type should fall back to string."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"null_prop_str": {"type": "null"},
|
||||
"null_prop_list": {"type": ["null"]},
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 2
|
||||
param_map = {parameter.name: parameter for parameter in result}
|
||||
assert "null_prop_str" in param_map
|
||||
assert param_map["null_prop_str"].type == ToolParameter.ToolParameterType.STRING
|
||||
assert "null_prop_list" in param_map
|
||||
assert param_map["null_prop_list"].type == ToolParameter.ToolParameterType.STRING
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type_with_multiple_object_items(self):
|
||||
"""Property-level allOf with multiple object items should still resolve to object."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"allOf": [
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"enabled": {"type": "boolean"},
|
||||
},
|
||||
"required": ["enabled"],
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"priority": {"type": "integer", "minimum": 1, "maximum": 10},
|
||||
},
|
||||
"required": ["priority"],
|
||||
},
|
||||
],
|
||||
"description": "Config must match all schemas (allOf)",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["config"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_preserves_allof_object_type(self):
|
||||
"""Composed property schemas should keep the object parameter type."""
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"retrieval_model": {
|
||||
"allOf": [
|
||||
{"type": "object"},
|
||||
{"properties": {"top_k": {"type": "integer"}}},
|
||||
],
|
||||
"description": "检索模型配置",
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "retrieval_model"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.OBJECT
|
||||
assert result[0].input_schema == schema["properties"]["retrieval_model"]
|
||||
|
||||
def test_convert_mcp_schema_to_parameter_limits_recursive_schema_depth(self):
|
||||
"""Self-referential composed schemas should stop resolving after the configured max depth."""
|
||||
recursive_property: dict[str, object] = {"description": "Recursive schema"}
|
||||
recursive_property["anyOf"] = [recursive_property]
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"recursive_config": recursive_property,
|
||||
},
|
||||
}
|
||||
|
||||
result = ToolTransformService.convert_mcp_schema_to_parameter(schema)
|
||||
|
||||
assert len(result) == 1
|
||||
assert result[0].name == "recursive_config"
|
||||
assert result[0].type == ToolParameter.ToolParameterType.STRING
|
||||
assert result[0].input_schema is None
|
||||
|
||||
def test_mcp_provider_to_user_provider_for_list(self, mock_provider_full):
|
||||
"""Test mcp_provider_to_user_provider with for_list=True."""
|
||||
# Set tools data with null description
|
||||
|
||||
6
api/uv.lock
generated
6
api/uv.lock
generated
@@ -5113,11 +5113,11 @@ wheels = [
|
||||
|
||||
[[package]]
|
||||
name = "pypdf"
|
||||
version = "6.8.0"
|
||||
version = "6.7.5"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/b4/a3/e705b0805212b663a4c27b861c8a603dba0f8b4bb281f96f8e746576a50d/pypdf-6.8.0.tar.gz", hash = "sha256:cb7eaeaa4133ce76f762184069a854e03f4d9a08568f0e0623f7ea810407833b", size = 5307831, upload-time = "2026-03-09T13:37:40.591Z" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/f6/52/37cc0aa9e9d1bf7729a737a0d83f8b3f851c8eb137373d9f71eafb0a3405/pypdf-6.7.5.tar.gz", hash = "sha256:40bb2e2e872078655f12b9b89e2f900888bb505e88a82150b64f9f34fa25651d", size = 5304278, upload-time = "2026-03-02T09:05:21.464Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/ec/4ccf3bb86b1afe5d7176e1c8abcdbf22b53dd682ec2eda50e1caadcf6846/pypdf-6.8.0-py3-none-any.whl", hash = "sha256:2a025080a8dd73f48123c89c57174a5ff3806c71763ee4e49572dc90454943c7", size = 332177, upload-time = "2026-03-09T13:37:38.774Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/05/89/336673efd0a88956562658aba4f0bbef7cb92a6fbcbcaf94926dbc82b408/pypdf-6.7.5-py3-none-any.whl", hash = "sha256:07ba7f1d6e6d9aa2a17f5452e320a84718d4ce863367f7ede2fd72280349ab13", size = 331421, upload-time = "2026-03-02T09:05:19.722Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
|
||||
@@ -35,7 +35,7 @@ COPY --from=packages /app/web/ .
|
||||
COPY . .
|
||||
|
||||
ENV NODE_OPTIONS="--max-old-space-size=4096"
|
||||
RUN pnpm build
|
||||
RUN pnpm build:docker
|
||||
|
||||
|
||||
# production stage
|
||||
|
||||
@@ -295,7 +295,24 @@ describe('Pricing Modal Flow', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 6. Pricing URL ─────────────────────────────────────────────────────
|
||||
// ─── 6. Close Handling ───────────────────────────────────────────────────
|
||||
describe('Close handling', () => {
|
||||
it('should call onCancel when pressing ESC key', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
// ahooks useKeyPress listens on document for keydown events
|
||||
document.dispatchEvent(new KeyboardEvent('keydown', {
|
||||
key: 'Escape',
|
||||
code: 'Escape',
|
||||
keyCode: 27,
|
||||
bubbles: true,
|
||||
}))
|
||||
|
||||
expect(onCancel).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ─── 7. Pricing URL ─────────────────────────────────────────────────────
|
||||
describe('Pricing page URL', () => {
|
||||
it('should render pricing link with correct URL', () => {
|
||||
render(<Pricing onCancel={onCancel} />)
|
||||
|
||||
@@ -48,7 +48,7 @@ export default function ChartView({ appId, headerRight }: IChartViewProps) {
|
||||
return (
|
||||
<div>
|
||||
<div className="mb-4">
|
||||
<div className="mb-2 text-text-primary system-xl-semibold">{t('appMenus.overview', { ns: 'common' })}</div>
|
||||
<div className="system-xl-semibold mb-2 text-text-primary">{t('appMenus.overview', { ns: 'common' })}</div>
|
||||
<div className="flex items-center justify-between">
|
||||
{IS_CLOUD_EDITION
|
||||
? (
|
||||
|
||||
@@ -30,7 +30,7 @@ const DatePicker: FC<Props> = ({
|
||||
|
||||
const renderDate = useCallback(({ value, handleClickTrigger, isOpen }: TriggerProps) => {
|
||||
return (
|
||||
<div className={cn('flex h-7 cursor-pointer items-center rounded-lg px-1 text-components-input-text-filled system-sm-regular hover:bg-state-base-hover', isOpen && 'bg-state-base-hover')} onClick={handleClickTrigger}>
|
||||
<div className={cn('system-sm-regular flex h-7 cursor-pointer items-center rounded-lg px-1 text-components-input-text-filled hover:bg-state-base-hover', isOpen && 'bg-state-base-hover')} onClick={handleClickTrigger}>
|
||||
{value ? formatToLocalTime(value, locale, 'MMM D') : ''}
|
||||
</div>
|
||||
)
|
||||
@@ -64,7 +64,7 @@ const DatePicker: FC<Props> = ({
|
||||
noConfirm
|
||||
getIsDateDisabled={startDateDisabled}
|
||||
/>
|
||||
<span className="text-text-tertiary system-sm-regular">-</span>
|
||||
<span className="system-sm-regular text-text-tertiary">-</span>
|
||||
<Picker
|
||||
value={end}
|
||||
onChange={onEndChange}
|
||||
|
||||
@@ -45,7 +45,7 @@ const RangeSelector: FC<Props> = ({
|
||||
const renderTrigger = useCallback((item: Item | null, isOpen: boolean) => {
|
||||
return (
|
||||
<div className={cn('flex h-8 cursor-pointer items-center space-x-1.5 rounded-lg bg-components-input-bg-normal pl-3 pr-2', isOpen && 'bg-state-base-hover-alt')}>
|
||||
<div className="text-components-input-text-filled system-sm-regular">{isCustomRange ? t('filter.period.custom', { ns: 'appLog' }) : item?.name}</div>
|
||||
<div className="system-sm-regular text-components-input-text-filled">{isCustomRange ? t('filter.period.custom', { ns: 'appLog' }) : item?.name}</div>
|
||||
<RiArrowDownSLine className={cn('size-4 text-text-quaternary', isOpen && 'text-text-secondary')} />
|
||||
</div>
|
||||
)
|
||||
@@ -57,13 +57,13 @@ const RangeSelector: FC<Props> = ({
|
||||
{selected && (
|
||||
<span
|
||||
className={cn(
|
||||
'absolute left-2 top-[9px] flex items-center text-text-accent',
|
||||
'absolute left-2 top-[9px] flex items-center text-text-accent',
|
||||
)}
|
||||
>
|
||||
<RiCheckLine className="h-4 w-4" aria-hidden="true" />
|
||||
</span>
|
||||
)}
|
||||
<span className={cn('block truncate system-md-regular')}>{item.name}</span>
|
||||
<span className={cn('system-md-regular block truncate')}>{item.name}</span>
|
||||
</>
|
||||
)
|
||||
}, [])
|
||||
|
||||
@@ -327,11 +327,11 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
<div className="flex items-center justify-between">
|
||||
<div className="flex items-center">
|
||||
<TracingIcon size="md" className="mr-2" />
|
||||
<div className="text-text-primary title-2xl-semi-bold">{t(`${I18N_PREFIX}.tracing`, { ns: 'app' })}</div>
|
||||
<div className="title-2xl-semi-bold text-text-primary">{t(`${I18N_PREFIX}.tracing`, { ns: 'app' })}</div>
|
||||
</div>
|
||||
<div className="flex items-center">
|
||||
<Indicator color={enabled ? 'green' : 'gray'} />
|
||||
<div className={cn('ml-1 text-text-tertiary system-xs-semibold-uppercase', enabled && 'text-util-colors-green-green-600')}>
|
||||
<div className={cn('system-xs-semibold-uppercase ml-1 text-text-tertiary', enabled && 'text-util-colors-green-green-600')}>
|
||||
{t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
|
||||
</div>
|
||||
{!readOnly && (
|
||||
@@ -350,7 +350,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div className="mt-2 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular mt-2 text-text-tertiary">
|
||||
{t(`${I18N_PREFIX}.tracingDescription`, { ns: 'app' })}
|
||||
</div>
|
||||
<Divider className="my-3" />
|
||||
@@ -358,7 +358,7 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
{(providerAllConfigured || providerAllNotConfigured)
|
||||
? (
|
||||
<>
|
||||
<div className="text-text-tertiary system-xs-medium-uppercase">{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`, { ns: 'app' })}</div>
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.${providerAllConfigured ? 'configured' : 'notConfigured'}`, { ns: 'app' })}</div>
|
||||
<div className="mt-2 max-h-96 space-y-2 overflow-y-auto">
|
||||
{langfusePanel}
|
||||
{langSmithPanel}
|
||||
@@ -375,11 +375,11 @@ const ConfigPopup: FC<PopupProps> = ({
|
||||
)
|
||||
: (
|
||||
<>
|
||||
<div className="text-text-tertiary system-xs-medium-uppercase">{t(`${I18N_PREFIX}.configProviderTitle.configured`, { ns: 'app' })}</div>
|
||||
<div className="system-xs-medium-uppercase text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.configured`, { ns: 'app' })}</div>
|
||||
<div className="mt-2 max-h-40 space-y-2 overflow-y-auto">
|
||||
{configuredProviderPanel()}
|
||||
</div>
|
||||
<div className="mt-3 text-text-tertiary system-xs-medium-uppercase">{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`, { ns: 'app' })}</div>
|
||||
<div className="system-xs-medium-uppercase mt-3 text-text-tertiary">{t(`${I18N_PREFIX}.configProviderTitle.moreProvider`, { ns: 'app' })}</div>
|
||||
<div className="mt-2 max-h-40 space-y-2 overflow-y-auto">
|
||||
{moreProviderPanel()}
|
||||
</div>
|
||||
|
||||
@@ -254,7 +254,7 @@ const Panel: FC = () => {
|
||||
)}
|
||||
>
|
||||
<TracingIcon size="md" />
|
||||
<div className="mx-2 text-text-secondary system-sm-semibold">{t(`${I18N_PREFIX}.title`, { ns: 'app' })}</div>
|
||||
<div className="system-sm-semibold mx-2 text-text-secondary">{t(`${I18N_PREFIX}.title`, { ns: 'app' })}</div>
|
||||
<div className="rounded-md p-1">
|
||||
<RiEqualizer2Line className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
@@ -294,7 +294,7 @@ const Panel: FC = () => {
|
||||
>
|
||||
<div className="ml-4 mr-1 flex items-center">
|
||||
<Indicator color={enabled ? 'green' : 'gray'} />
|
||||
<div className="ml-1.5 text-text-tertiary system-xs-semibold-uppercase">
|
||||
<div className="system-xs-semibold-uppercase ml-1.5 text-text-tertiary">
|
||||
{t(`${I18N_PREFIX}.${enabled ? 'enabled' : 'disabled'}`, { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -302,7 +302,7 @@ const ProviderConfigModal: FC<Props> = ({
|
||||
<div className="mx-2 max-h-[calc(100vh-120px)] w-[640px] overflow-y-auto rounded-2xl bg-components-panel-bg shadow-xl">
|
||||
<div className="px-8 pt-8">
|
||||
<div className="mb-4 flex items-center justify-between">
|
||||
<div className="text-text-primary title-2xl-semi-bold">
|
||||
<div className="title-2xl-semi-bold text-text-primary">
|
||||
{t(`${I18N_PREFIX}.title`, { ns: 'app' })}
|
||||
{t(`tracing.${type}.title`, { ns: 'app' })}
|
||||
</div>
|
||||
|
||||
@@ -82,7 +82,7 @@ const ProviderPanel: FC<Props> = ({
|
||||
<div className="flex items-center justify-between space-x-1">
|
||||
<div className="flex items-center">
|
||||
<Icon className="h-6" />
|
||||
{isChosen && <div className="ml-1 flex h-4 items-center rounded-[4px] border border-text-accent-secondary px-1 text-text-accent-secondary system-2xs-medium-uppercase">{t(`${I18N_PREFIX}.inUse`, { ns: 'app' })}</div>}
|
||||
{isChosen && <div className="system-2xs-medium-uppercase ml-1 flex h-4 items-center rounded-[4px] border border-text-accent-secondary px-1 text-text-accent-secondary">{t(`${I18N_PREFIX}.inUse`, { ns: 'app' })}</div>}
|
||||
</div>
|
||||
{!readOnly && (
|
||||
<div className="flex items-center justify-between space-x-1">
|
||||
@@ -102,7 +102,7 @@ const ProviderPanel: FC<Props> = ({
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-2 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular mt-2 text-text-tertiary">
|
||||
{t(`${I18N_PREFIX}.${type}.description`, { ns: 'app' })}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -7,8 +7,8 @@ const Settings = () => {
|
||||
return (
|
||||
<div className="h-full overflow-y-auto">
|
||||
<div className="flex flex-col gap-y-0.5 px-6 pb-2 pt-3">
|
||||
<div className="text-text-primary system-xl-semibold">{t('title')}</div>
|
||||
<div className="text-text-tertiary system-sm-regular">{t('desc')}</div>
|
||||
<div className="system-xl-semibold text-text-primary">{t('title')}</div>
|
||||
<div className="system-sm-regular text-text-tertiary">{t('desc')}</div>
|
||||
</div>
|
||||
<Form />
|
||||
</div>
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import type { ReactNode } from 'react'
|
||||
import * as React from 'react'
|
||||
import { AppInitializer } from '@/app/components/app-initializer'
|
||||
import InSiteMessageNotification from '@/app/components/app/in-site-message/notification'
|
||||
import AmplitudeProvider from '@/app/components/base/amplitude'
|
||||
import GA, { GaType } from '@/app/components/base/ga'
|
||||
import Zendesk from '@/app/components/base/zendesk'
|
||||
@@ -33,7 +32,6 @@ const Layout = ({ children }: { children: ReactNode }) => {
|
||||
<RoleRouteGuard>
|
||||
{children}
|
||||
</RoleRouteGuard>
|
||||
<InSiteMessageNotification />
|
||||
<PartnerStack />
|
||||
<ReadmePanel />
|
||||
<GotoAnything />
|
||||
|
||||
@@ -106,17 +106,17 @@ const FormContent = () => {
|
||||
<RiCheckboxCircleFill className="h-8 w-8 text-text-success" />
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.thanks', { ns: 'share' })}</div>
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.recorded', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.thanks', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.recorded', { ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="shrink-0 text-text-tertiary system-2xs-regular-uppercase">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
<div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="flex flex-row-reverse px-2 py-3">
|
||||
<div className={cn(
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -134,17 +134,17 @@ const FormContent = () => {
|
||||
<RiInformation2Fill className="h-8 w-8 text-text-accent" />
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.sorry', { ns: 'share' })}</div>
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.expired', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.sorry', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.expired', { ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="shrink-0 text-text-tertiary system-2xs-regular-uppercase">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
<div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="flex flex-row-reverse px-2 py-3">
|
||||
<div className={cn(
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -162,17 +162,17 @@ const FormContent = () => {
|
||||
<RiInformation2Fill className="h-8 w-8 text-text-accent" />
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.sorry', { ns: 'share' })}</div>
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.completed', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.sorry', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.completed', { ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="shrink-0 text-text-tertiary system-2xs-regular-uppercase">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
<div className="system-2xs-regular-uppercase shrink-0 text-text-tertiary">{t('humanInput.submissionID', { id: token, ns: 'share' })}</div>
|
||||
</div>
|
||||
<div className="flex flex-row-reverse px-2 py-3">
|
||||
<div className={cn(
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -190,7 +190,7 @@ const FormContent = () => {
|
||||
<RiErrorWarningFill className="h-8 w-8 text-text-destructive" />
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.rateLimitExceeded', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.rateLimitExceeded', { ns: 'share' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-row-reverse px-2 py-3">
|
||||
@@ -198,7 +198,7 @@ const FormContent = () => {
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -216,7 +216,7 @@ const FormContent = () => {
|
||||
<RiErrorWarningFill className="h-8 w-8 text-text-destructive" />
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="text-text-primary title-4xl-semi-bold">{t('humanInput.formNotFound', { ns: 'share' })}</div>
|
||||
<div className="title-4xl-semi-bold text-text-primary">{t('humanInput.formNotFound', { ns: 'share' })}</div>
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-row-reverse px-2 py-3">
|
||||
@@ -224,7 +224,7 @@ const FormContent = () => {
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
@@ -245,7 +245,7 @@ const FormContent = () => {
|
||||
background={site.icon_background}
|
||||
imageUrl={site.icon_url}
|
||||
/>
|
||||
<div className="grow text-text-primary system-xl-semibold">{site.title}</div>
|
||||
<div className="system-xl-semibold grow text-text-primary">{site.title}</div>
|
||||
</div>
|
||||
<div className="h-0 w-full grow overflow-y-auto">
|
||||
<div className="border-components-divider-subtle rounded-[20px] border bg-chat-bubble-bg p-4 shadow-lg backdrop-blur-sm">
|
||||
@@ -277,7 +277,7 @@ const FormContent = () => {
|
||||
'flex shrink-0 items-center gap-1.5 px-1',
|
||||
)}
|
||||
>
|
||||
<div className="text-text-tertiary system-2xs-medium-uppercase">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<div className="system-2xs-medium-uppercase text-text-tertiary">{t('chat.poweredBy', { ns: 'share' })}</div>
|
||||
<DifyLogo size="small" />
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -81,7 +81,7 @@ const AuthenticatedLayout = ({ children }: { children: React.ReactNode }) => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-2">
|
||||
<AppUnavailable className="h-auto w-auto" code={403} unknownReason="no permission." />
|
||||
<span className="cursor-pointer text-text-tertiary system-sm-regular" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('userProfile.logout', { ns: 'common' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -95,7 +95,7 @@ const Splash: FC<PropsWithChildren> = ({ children }) => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-4">
|
||||
<AppUnavailable className="h-auto w-auto" code={code || t('common.appUnavailable', { ns: 'share' })} unknownReason={message} />
|
||||
<span className="cursor-pointer text-text-tertiary system-sm-regular" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{code === '403' ? t('userProfile.logout', { ns: 'common' }) : t('login.backToHome', { ns: 'share' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -69,8 +69,8 @@ export default function CheckCode() {
|
||||
<RiMailSendFill className="h-6 w-6 text-2xl" />
|
||||
</div>
|
||||
<div className="pb-4 pt-2">
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">{t('checkCode.checkYourEmail', { ns: 'login' })}</h2>
|
||||
<p className="mt-2 text-text-secondary body-md-regular">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">{t('checkCode.checkYourEmail', { ns: 'login' })}</h2>
|
||||
<p className="body-md-regular mt-2 text-text-secondary">
|
||||
<span>
|
||||
{t('checkCode.tipsPrefix', { ns: 'login' })}
|
||||
<strong>{email}</strong>
|
||||
@@ -82,7 +82,7 @@ export default function CheckCode() {
|
||||
|
||||
<form action="">
|
||||
<input type="text" className="hidden" />
|
||||
<label htmlFor="code" className="mb-1 text-text-secondary system-md-semibold">{t('checkCode.verificationCode', { ns: 'login' })}</label>
|
||||
<label htmlFor="code" className="system-md-semibold mb-1 text-text-secondary">{t('checkCode.verificationCode', { ns: 'login' })}</label>
|
||||
<Input value={code} onChange={e => setVerifyCode(e.target.value)} maxLength={6} className="mt-1" placeholder={t('checkCode.verificationCodePlaceholder', { ns: 'login' }) || ''} />
|
||||
<Button loading={loading} disabled={loading} className="my-3 w-full" variant="primary" onClick={verify}>{t('checkCode.verify', { ns: 'login' })}</Button>
|
||||
<Countdown onResend={resendCode} />
|
||||
@@ -94,7 +94,7 @@ export default function CheckCode() {
|
||||
<div className="bg-background-default-dimm inline-block rounded-full p-1">
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className="ml-2 system-xs-regular">{t('back', { ns: 'login' })}</span>
|
||||
<span className="system-xs-regular ml-2">{t('back', { ns: 'login' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -24,7 +24,7 @@ export default function SignInLayout({ children }: any) {
|
||||
</div>
|
||||
</div>
|
||||
{!systemFeatures.branding.enabled && (
|
||||
<div className="px-8 py-6 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular px-8 py-6 text-text-tertiary">
|
||||
©
|
||||
{' '}
|
||||
{new Date().getFullYear()}
|
||||
|
||||
@@ -74,8 +74,8 @@ export default function CheckCode() {
|
||||
<RiLockPasswordLine className="h-6 w-6 text-2xl text-text-accent-light-mode-only" />
|
||||
</div>
|
||||
<div className="pb-4 pt-2">
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">{t('resetPassword', { ns: 'login' })}</h2>
|
||||
<p className="mt-2 text-text-secondary body-md-regular">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">{t('resetPassword', { ns: 'login' })}</h2>
|
||||
<p className="body-md-regular mt-2 text-text-secondary">
|
||||
{t('resetPasswordDesc', { ns: 'login' })}
|
||||
</p>
|
||||
</div>
|
||||
@@ -83,7 +83,7 @@ export default function CheckCode() {
|
||||
<form onSubmit={noop}>
|
||||
<input type="text" className="hidden" />
|
||||
<div className="mb-2">
|
||||
<label htmlFor="email" className="my-2 text-text-secondary system-md-semibold">{t('email', { ns: 'login' })}</label>
|
||||
<label htmlFor="email" className="system-md-semibold my-2 text-text-secondary">{t('email', { ns: 'login' })}</label>
|
||||
<div className="mt-1">
|
||||
<Input id="email" type="email" disabled={loading} value={email} placeholder={t('emailPlaceholder', { ns: 'login' }) as string} onChange={e => setEmail(e.target.value)} />
|
||||
</div>
|
||||
@@ -99,7 +99,7 @@ export default function CheckCode() {
|
||||
<div className="inline-block rounded-full bg-background-default-dimmed p-1">
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className="ml-2 system-xs-regular">{t('backToLogin', { ns: 'login' })}</span>
|
||||
<span className="system-xs-regular ml-2">{t('backToLogin', { ns: 'login' })}</span>
|
||||
</Link>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -91,10 +91,10 @@ const ChangePasswordForm = () => {
|
||||
{!showSuccess && (
|
||||
<div className="flex flex-col md:w-[400px]">
|
||||
<div className="mx-auto w-full">
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">
|
||||
{t('changePassword', { ns: 'login' })}
|
||||
</h2>
|
||||
<p className="mt-2 text-text-secondary body-md-regular">
|
||||
<p className="body-md-regular mt-2 text-text-secondary">
|
||||
{t('changePasswordTip', { ns: 'login' })}
|
||||
</p>
|
||||
</div>
|
||||
@@ -103,7 +103,7 @@ const ChangePasswordForm = () => {
|
||||
<div className="bg-white">
|
||||
{/* Password */}
|
||||
<div className="mb-5">
|
||||
<label htmlFor="password" className="my-2 text-text-secondary system-md-semibold">
|
||||
<label htmlFor="password" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('account.newPassword', { ns: 'common' })}
|
||||
</label>
|
||||
<div className="relative mt-1">
|
||||
@@ -125,11 +125,11 @@ const ChangePasswordForm = () => {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-1 text-text-secondary body-xs-regular">{t('error.passwordInvalid', { ns: 'login' })}</div>
|
||||
<div className="body-xs-regular mt-1 text-text-secondary">{t('error.passwordInvalid', { ns: 'login' })}</div>
|
||||
</div>
|
||||
{/* Confirm Password */}
|
||||
<div className="mb-5">
|
||||
<label htmlFor="confirmPassword" className="my-2 text-text-secondary system-md-semibold">
|
||||
<label htmlFor="confirmPassword" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('account.confirmPassword', { ns: 'common' })}
|
||||
</label>
|
||||
<div className="relative mt-1">
|
||||
@@ -170,7 +170,7 @@ const ChangePasswordForm = () => {
|
||||
<div className="mb-3 flex h-14 w-14 items-center justify-center rounded-2xl border border-components-panel-border-subtle font-bold shadow-lg">
|
||||
<RiCheckboxCircleFill className="h-6 w-6 text-text-success" />
|
||||
</div>
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">
|
||||
{t('passwordChangedTip', { ns: 'login' })}
|
||||
</h2>
|
||||
</div>
|
||||
|
||||
@@ -110,8 +110,8 @@ export default function CheckCode() {
|
||||
<RiMailSendFill className="h-6 w-6 text-2xl text-text-accent-light-mode-only" />
|
||||
</div>
|
||||
<div className="pb-4 pt-2">
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">{t('checkCode.checkYourEmail', { ns: 'login' })}</h2>
|
||||
<p className="mt-2 text-text-secondary body-md-regular">
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">{t('checkCode.checkYourEmail', { ns: 'login' })}</h2>
|
||||
<p className="body-md-regular mt-2 text-text-secondary">
|
||||
<span>
|
||||
{t('checkCode.tipsPrefix', { ns: 'login' })}
|
||||
<strong>{email}</strong>
|
||||
@@ -122,7 +122,7 @@ export default function CheckCode() {
|
||||
</div>
|
||||
|
||||
<form onSubmit={handleSubmit}>
|
||||
<label htmlFor="code" className="mb-1 text-text-secondary system-md-semibold">{t('checkCode.verificationCode', { ns: 'login' })}</label>
|
||||
<label htmlFor="code" className="system-md-semibold mb-1 text-text-secondary">{t('checkCode.verificationCode', { ns: 'login' })}</label>
|
||||
<Input
|
||||
ref={codeInputRef}
|
||||
id="code"
|
||||
@@ -142,7 +142,7 @@ export default function CheckCode() {
|
||||
<div className="bg-background-default-dimm inline-block rounded-full p-1">
|
||||
<RiArrowLeftLine size={12} />
|
||||
</div>
|
||||
<span className="ml-2 system-xs-regular">{t('back', { ns: 'login' })}</span>
|
||||
<span className="system-xs-regular ml-2">{t('back', { ns: 'login' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
|
||||
@@ -55,7 +55,7 @@ export default function MailAndCodeAuth() {
|
||||
<form onSubmit={noop}>
|
||||
<input type="text" className="hidden" />
|
||||
<div className="mb-2">
|
||||
<label htmlFor="email" className="my-2 text-text-secondary system-md-semibold">{t('email', { ns: 'login' })}</label>
|
||||
<label htmlFor="email" className="system-md-semibold my-2 text-text-secondary">{t('email', { ns: 'login' })}</label>
|
||||
<div className="mt-1">
|
||||
<Input id="email" type="email" value={email} placeholder={t('emailPlaceholder', { ns: 'login' }) as string} onChange={e => setEmail(e.target.value)} />
|
||||
</div>
|
||||
|
||||
@@ -112,7 +112,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
||||
return (
|
||||
<form onSubmit={noop}>
|
||||
<div className="mb-3">
|
||||
<label htmlFor="email" className="my-2 text-text-secondary system-md-semibold">
|
||||
<label htmlFor="email" className="system-md-semibold my-2 text-text-secondary">
|
||||
{t('email', { ns: 'login' })}
|
||||
</label>
|
||||
<div className="mt-1">
|
||||
@@ -130,7 +130,7 @@ export default function MailAndPasswordAuth({ isEmailSetup }: MailAndPasswordAut
|
||||
|
||||
<div className="mb-3">
|
||||
<label htmlFor="password" className="my-2 flex items-center justify-between">
|
||||
<span className="text-text-secondary system-md-semibold">{t('password', { ns: 'login' })}</span>
|
||||
<span className="system-md-semibold text-text-secondary">{t('password', { ns: 'login' })}</span>
|
||||
<Link
|
||||
href={`/webapp-reset-password?${searchParams.toString()}`}
|
||||
className={`system-xs-regular ${isEmailSetup ? 'text-components-button-secondary-accent-text' : 'pointer-events-none text-components-button-secondary-accent-text-disabled'}`}
|
||||
|
||||
@@ -21,7 +21,7 @@ export default function SignInLayout({ children }: PropsWithChildren) {
|
||||
</div>
|
||||
</div>
|
||||
{systemFeatures.branding.enabled === false && (
|
||||
<div className="px-8 py-6 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular px-8 py-6 text-text-tertiary">
|
||||
©
|
||||
{' '}
|
||||
{new Date().getFullYear()}
|
||||
|
||||
@@ -60,8 +60,8 @@ const NormalForm = () => {
|
||||
<RiContractLine className="h-5 w-5" />
|
||||
<RiErrorWarningFill className="absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<p className="text-text-primary system-sm-medium">{t('licenseLost', { ns: 'login' })}</p>
|
||||
<p className="mt-1 text-text-tertiary system-xs-regular">{t('licenseLostTip', { ns: 'login' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('licenseLost', { ns: 'login' })}</p>
|
||||
<p className="system-xs-regular mt-1 text-text-tertiary">{t('licenseLostTip', { ns: 'login' })}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -76,8 +76,8 @@ const NormalForm = () => {
|
||||
<RiContractLine className="h-5 w-5" />
|
||||
<RiErrorWarningFill className="absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<p className="text-text-primary system-sm-medium">{t('licenseExpired', { ns: 'login' })}</p>
|
||||
<p className="mt-1 text-text-tertiary system-xs-regular">{t('licenseExpiredTip', { ns: 'login' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('licenseExpired', { ns: 'login' })}</p>
|
||||
<p className="system-xs-regular mt-1 text-text-tertiary">{t('licenseExpiredTip', { ns: 'login' })}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -92,8 +92,8 @@ const NormalForm = () => {
|
||||
<RiContractLine className="h-5 w-5" />
|
||||
<RiErrorWarningFill className="absolute -right-1 -top-1 h-4 w-4 text-text-warning-secondary" />
|
||||
</div>
|
||||
<p className="text-text-primary system-sm-medium">{t('licenseInactive', { ns: 'login' })}</p>
|
||||
<p className="mt-1 text-text-tertiary system-xs-regular">{t('licenseInactiveTip', { ns: 'login' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('licenseInactive', { ns: 'login' })}</p>
|
||||
<p className="system-xs-regular mt-1 text-text-tertiary">{t('licenseInactiveTip', { ns: 'login' })}</p>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
@@ -104,8 +104,8 @@ const NormalForm = () => {
|
||||
<>
|
||||
<div className="mx-auto mt-8 w-full">
|
||||
<div className="mx-auto w-full">
|
||||
<h2 className="text-text-primary title-4xl-semi-bold">{systemFeatures.branding.enabled ? t('pageTitleForE', { ns: 'login' }) : t('pageTitle', { ns: 'login' })}</h2>
|
||||
<p className="mt-2 text-text-tertiary body-md-regular">{t('welcome', { ns: 'login' })}</p>
|
||||
<h2 className="title-4xl-semi-bold text-text-primary">{systemFeatures.branding.enabled ? t('pageTitleForE', { ns: 'login' }) : t('pageTitle', { ns: 'login' })}</h2>
|
||||
<p className="body-md-regular mt-2 text-text-tertiary">{t('welcome', { ns: 'login' })}</p>
|
||||
</div>
|
||||
<div className="relative">
|
||||
<div className="mt-6 flex flex-col gap-3">
|
||||
@@ -122,7 +122,7 @@ const NormalForm = () => {
|
||||
<div className="h-px w-full bg-gradient-to-r from-background-gradient-mask-transparent via-divider-regular to-background-gradient-mask-transparent"></div>
|
||||
</div>
|
||||
<div className="relative flex justify-center">
|
||||
<span className="px-2 text-text-tertiary system-xs-medium-uppercase">{t('or', { ns: 'login' })}</span>
|
||||
<span className="system-xs-medium-uppercase px-2 text-text-tertiary">{t('or', { ns: 'login' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
@@ -134,7 +134,7 @@ const NormalForm = () => {
|
||||
<MailAndCodeAuth />
|
||||
{systemFeatures.enable_email_password_login && (
|
||||
<div className="cursor-pointer py-1 text-center" onClick={() => { updateAuthType('password') }}>
|
||||
<span className="text-components-button-secondary-accent-text system-xs-medium">{t('usePassword', { ns: 'login' })}</span>
|
||||
<span className="system-xs-medium text-components-button-secondary-accent-text">{t('usePassword', { ns: 'login' })}</span>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
@@ -144,7 +144,7 @@ const NormalForm = () => {
|
||||
<MailAndPasswordAuth isEmailSetup={systemFeatures.is_email_setup} />
|
||||
{systemFeatures.enable_email_code_login && (
|
||||
<div className="cursor-pointer py-1 text-center" onClick={() => { updateAuthType('code') }}>
|
||||
<span className="text-components-button-secondary-accent-text system-xs-medium">{t('useVerificationCode', { ns: 'login' })}</span>
|
||||
<span className="system-xs-medium text-components-button-secondary-accent-text">{t('useVerificationCode', { ns: 'login' })}</span>
|
||||
</div>
|
||||
)}
|
||||
</>
|
||||
@@ -158,8 +158,8 @@ const NormalForm = () => {
|
||||
<div className="shadows-shadow-lg mb-2 flex h-10 w-10 items-center justify-center rounded-xl bg-components-card-bg shadow">
|
||||
<RiDoorLockLine className="h-5 w-5" />
|
||||
</div>
|
||||
<p className="text-text-primary system-sm-medium">{t('noLoginMethod', { ns: 'login' })}</p>
|
||||
<p className="mt-1 text-text-tertiary system-xs-regular">{t('noLoginMethodTip', { ns: 'login' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('noLoginMethod', { ns: 'login' })}</p>
|
||||
<p className="system-xs-regular mt-1 text-text-tertiary">{t('noLoginMethodTip', { ns: 'login' })}</p>
|
||||
</div>
|
||||
<div className="relative my-2 py-2">
|
||||
<div className="absolute inset-0 flex items-center" aria-hidden="true">
|
||||
@@ -170,11 +170,11 @@ const NormalForm = () => {
|
||||
)}
|
||||
{!systemFeatures.branding.enabled && (
|
||||
<>
|
||||
<div className="mt-2 block w-full text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular mt-2 block w-full text-text-tertiary">
|
||||
{t('tosDesc', { ns: 'login' })}
|
||||
|
||||
<Link
|
||||
className="text-text-secondary system-xs-medium hover:underline"
|
||||
className="system-xs-medium text-text-secondary hover:underline"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://dify.ai/terms"
|
||||
@@ -183,7 +183,7 @@ const NormalForm = () => {
|
||||
</Link>
|
||||
&
|
||||
<Link
|
||||
className="text-text-secondary system-xs-medium hover:underline"
|
||||
className="system-xs-medium text-text-secondary hover:underline"
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
href="https://dify.ai/privacy"
|
||||
@@ -192,11 +192,11 @@ const NormalForm = () => {
|
||||
</Link>
|
||||
</div>
|
||||
{IS_CE_EDITION && (
|
||||
<div className="w-hull mt-2 block text-text-tertiary system-xs-regular">
|
||||
<div className="w-hull system-xs-regular mt-2 block text-text-tertiary">
|
||||
{t('goToInit', { ns: 'login' })}
|
||||
|
||||
<Link
|
||||
className="text-text-secondary system-xs-medium hover:underline"
|
||||
className="system-xs-medium text-text-secondary hover:underline"
|
||||
href="/install"
|
||||
>
|
||||
{t('setAdminAccount', { ns: 'login' })}
|
||||
|
||||
@@ -45,7 +45,7 @@ const WebSSOForm: FC = () => {
|
||||
if (!systemFeatures.webapp_auth.enabled) {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<p className="text-text-tertiary system-xs-regular">{t('webapp.disabled', { ns: 'login' })}</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{t('webapp.disabled', { ns: 'login' })}</p>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -63,7 +63,7 @@ const WebSSOForm: FC = () => {
|
||||
return (
|
||||
<div className="flex h-full flex-col items-center justify-center gap-y-4">
|
||||
<AppUnavailable className="h-auto w-auto" isUnknownReason={true} />
|
||||
<span className="cursor-pointer text-text-tertiary system-sm-regular" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
|
||||
<span className="system-sm-regular cursor-pointer text-text-tertiary" onClick={backToHome}>{t('login.backToHome', { ns: 'share' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -11,7 +11,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
import ImageInput from '@/app/components/base/app-icon-picker/ImageInput'
|
||||
import getCroppedImg from '@/app/components/base/app-icon-picker/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Divider from '@/app/components/base/divider'
|
||||
import { useLocalFileUploader } from '@/app/components/base/image-uploader/hooks'
|
||||
@@ -103,7 +103,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
||||
<>
|
||||
<div>
|
||||
<div className="group relative">
|
||||
<Avatar {...props} onLoadingStatusChange={status => setOnAvatarError(status === 'error')} />
|
||||
<Avatar {...props} onError={(x: boolean) => setOnAvatarError(x)} />
|
||||
<div
|
||||
className="absolute inset-0 flex cursor-pointer items-center justify-center rounded-full bg-black/50 opacity-0 transition-opacity group-hover:opacity-100"
|
||||
onClick={() => {
|
||||
@@ -160,7 +160,7 @@ const AvatarWithEdit = ({ onSave, ...props }: AvatarWithEditProps) => {
|
||||
isShow={isShowDeleteConfirm}
|
||||
onClose={() => setIsShowDeleteConfirm(false)}
|
||||
>
|
||||
<div className="mb-3 text-text-primary title-2xl-semi-bold">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold mb-3 text-text-primary">{t('avatar.deleteTitle', { ns: 'common' })}</div>
|
||||
<p className="mb-8 text-text-secondary">{t('avatar.deleteDescription', { ns: 'common' })}</p>
|
||||
|
||||
<div className="flex w-full items-center justify-center gap-2">
|
||||
|
||||
@@ -209,14 +209,14 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
</div>
|
||||
{step === STEP.start && (
|
||||
<>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.title', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="text-text-warning body-md-medium">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<div className="body-md-medium text-text-warning">{t('account.changeEmail.authTip', { ns: 'common' })}</div>
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content1"
|
||||
ns="common"
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
@@ -241,19 +241,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyOrigin && (
|
||||
<>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content2"
|
||||
ns="common"
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
values={{ email }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -278,25 +278,25 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToOriginEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToOriginEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
{step === STEP.newEmail && (
|
||||
<>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.newEmail', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="text-text-secondary body-md-regular">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
<div className="body-md-regular text-text-secondary">{t('account.changeEmail.content3', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.emailLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.emailPlaceholder', { ns: 'common' })}
|
||||
@@ -305,10 +305,10 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
destructive={newEmailExited || unAvailableEmail}
|
||||
/>
|
||||
{newEmailExited && (
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.existingEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
{unAvailableEmail && (
|
||||
<div className="mt-1 py-0.5 text-text-destructive body-xs-regular">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
<div className="body-xs-regular mt-1 py-0.5 text-text-destructive">{t('account.changeEmail.unAvailableEmail', { ns: 'common' })}</div>
|
||||
)}
|
||||
</div>
|
||||
<div className="mt-3 space-y-2">
|
||||
@@ -331,19 +331,19 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
)}
|
||||
{step === STEP.verifyNew && (
|
||||
<>
|
||||
<div className="pb-3 text-text-primary title-2xl-semi-bold">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold pb-3 text-text-primary">{t('account.changeEmail.verifyNew', { ns: 'common' })}</div>
|
||||
<div className="space-y-0.5 pb-2 pt-1">
|
||||
<div className="text-text-secondary body-md-regular">
|
||||
<div className="body-md-regular text-text-secondary">
|
||||
<Trans
|
||||
i18nKey="account.changeEmail.content4"
|
||||
ns="common"
|
||||
components={{ email: <span className="text-text-primary body-md-medium"></span> }}
|
||||
components={{ email: <span className="body-md-medium text-text-primary"></span> }}
|
||||
values={{ email: mail }}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div className="pt-3">
|
||||
<div className="mb-1 flex h-6 items-center text-text-secondary system-sm-medium">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<div className="system-sm-medium mb-1 flex h-6 items-center text-text-secondary">{t('account.changeEmail.codeLabel', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="!w-full"
|
||||
placeholder={t('account.changeEmail.codePlaceholder', { ns: 'common' })}
|
||||
@@ -368,13 +368,13 @@ const EmailChangeModal = ({ onClose, email, show }: Props) => {
|
||||
{t('operation.cancel', { ns: 'common' })}
|
||||
</Button>
|
||||
</div>
|
||||
<div className="mt-3 flex items-center gap-1 text-text-tertiary system-xs-regular">
|
||||
<div className="system-xs-regular mt-3 flex items-center gap-1 text-text-tertiary">
|
||||
<span>{t('account.changeEmail.resendTip', { ns: 'common' })}</span>
|
||||
{time > 0 && (
|
||||
<span>{t('account.changeEmail.resendCount', { ns: 'common', count: time })}</span>
|
||||
)}
|
||||
{!time && (
|
||||
<span onClick={sendCodeToNewEmail} className="cursor-pointer text-text-accent-secondary system-xs-medium">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
<span onClick={sendCodeToNewEmail} className="system-xs-medium cursor-pointer text-text-accent-secondary">{t('account.changeEmail.resend', { ns: 'common' })}</span>
|
||||
)}
|
||||
</div>
|
||||
</>
|
||||
|
||||
@@ -4,7 +4,6 @@ import type { App } from '@/types/app'
|
||||
import {
|
||||
RiGraduationCapFill,
|
||||
} from '@remixicon/react'
|
||||
import { useQueryClient } from '@tanstack/react-query'
|
||||
import { useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { useContext } from 'use-context-selector'
|
||||
@@ -16,11 +15,11 @@ import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { ToastContext } from '@/app/components/base/toast/context'
|
||||
import Collapse from '@/app/components/header/account-setting/collapse'
|
||||
import { IS_CE_EDITION, validPassword } from '@/config'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useGlobalPublicStore } from '@/context/global-public-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { updateUserProfile } from '@/service/common'
|
||||
import { useAppList } from '@/service/use-apps'
|
||||
import { commonQueryKeys, useUserProfile } from '@/service/use-common'
|
||||
import DeleteAccount from '../delete-account'
|
||||
|
||||
import AvatarWithEdit from './AvatarWithEdit'
|
||||
@@ -38,10 +37,7 @@ export default function AccountPage() {
|
||||
const { systemFeatures } = useGlobalPublicStore()
|
||||
const { data: appList } = useAppList({ page: 1, limit: 100, name: '' })
|
||||
const apps = appList?.data || []
|
||||
const queryClient = useQueryClient()
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const mutateUserProfile = () => queryClient.invalidateQueries({ queryKey: commonQueryKeys.userProfile })
|
||||
const { mutateUserProfile, userProfile } = useAppContext()
|
||||
const { isEducationAccount } = useProviderContext()
|
||||
const { notify } = useContext(ToastContext)
|
||||
const [editNameModalVisible, setEditNameModalVisible] = useState(false)
|
||||
@@ -57,9 +53,6 @@ export default function AccountPage() {
|
||||
const [showConfirmPassword, setShowConfirmPassword] = useState(false)
|
||||
const [showUpdateEmail, setShowUpdateEmail] = useState(false)
|
||||
|
||||
if (!userProfile)
|
||||
return null
|
||||
|
||||
const handleEditName = () => {
|
||||
setEditNameModalVisible(true)
|
||||
setEditName(userProfile.name)
|
||||
@@ -145,7 +138,7 @@ export default function AccountPage() {
|
||||
imageUrl={icon_url}
|
||||
/>
|
||||
</div>
|
||||
<div className="mt-[3px] text-text-secondary system-sm-medium">{item.name}</div>
|
||||
<div className="system-sm-medium mt-[3px] text-text-secondary">{item.name}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -153,12 +146,12 @@ export default function AccountPage() {
|
||||
return (
|
||||
<>
|
||||
<div className="pb-3 pt-2">
|
||||
<h4 className="text-text-primary title-2xl-semi-bold">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
<h4 className="title-2xl-semi-bold text-text-primary">{t('account.myAccount', { ns: 'common' })}</h4>
|
||||
</div>
|
||||
<div className="mb-8 flex items-center rounded-xl bg-gradient-to-r from-background-gradient-bg-fill-chat-bg-2 to-background-gradient-bg-fill-chat-bg-1 p-6">
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size="3xl" />
|
||||
<AvatarWithEdit avatar={userProfile.avatar_url} name={userProfile.name} onSave={mutateUserProfile} size={64} />
|
||||
<div className="ml-4">
|
||||
<p className="text-text-primary system-xl-semibold">
|
||||
<p className="system-xl-semibold text-text-primary">
|
||||
{userProfile.name}
|
||||
{isEducationAccount && (
|
||||
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
|
||||
@@ -167,16 +160,16 @@ export default function AccountPage() {
|
||||
</PremiumBadge>
|
||||
)}
|
||||
</p>
|
||||
<p className="text-text-tertiary system-xs-regular">{userProfile.email}</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{userProfile.email}</p>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<span className="pl-1">{userProfile.name}</span>
|
||||
</div>
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={handleEditName}>
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={handleEditName}>
|
||||
{t('operation.edit', { ns: 'common' })}
|
||||
</div>
|
||||
</div>
|
||||
@@ -184,11 +177,11 @@ export default function AccountPage() {
|
||||
<div className="mb-8">
|
||||
<div className={titleClassName}>{t('account.email', { ns: 'common' })}</div>
|
||||
<div className="mt-2 flex w-full items-center justify-between gap-2">
|
||||
<div className="flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled system-sm-regular">
|
||||
<div className="system-sm-regular flex-1 rounded-lg bg-components-input-bg-normal p-2 text-components-input-text-filled ">
|
||||
<span className="pl-1">{userProfile.email}</span>
|
||||
</div>
|
||||
{systemFeatures.enable_change_email && (
|
||||
<div className="cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text system-sm-medium" onClick={() => setShowUpdateEmail(true)}>
|
||||
<div className="system-sm-medium cursor-pointer rounded-lg bg-components-button-tertiary-bg px-3 py-2 text-components-button-tertiary-text" onClick={() => setShowUpdateEmail(true)}>
|
||||
{t('operation.change', { ns: 'common' })}
|
||||
</div>
|
||||
)}
|
||||
@@ -198,8 +191,8 @@ export default function AccountPage() {
|
||||
systemFeatures.enable_email_password_login && (
|
||||
<div className="mb-8 flex justify-between gap-2">
|
||||
<div>
|
||||
<div className="mb-1 text-text-secondary system-sm-semibold">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="mb-2 text-text-tertiary body-xs-regular">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
<div className="system-sm-semibold mb-1 text-text-secondary">{t('account.password', { ns: 'common' })}</div>
|
||||
<div className="body-xs-regular mb-2 text-text-tertiary">{t('account.passwordTip', { ns: 'common' })}</div>
|
||||
</div>
|
||||
<Button onClick={() => setEditPasswordModalVisible(true)}>{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</Button>
|
||||
</div>
|
||||
@@ -226,7 +219,7 @@ export default function AccountPage() {
|
||||
onClose={() => setEditNameModalVisible(false)}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{t('account.editName', { ns: 'common' })}</div>
|
||||
<div className={titleClassName}>{t('account.name', { ns: 'common' })}</div>
|
||||
<Input
|
||||
className="mt-2"
|
||||
@@ -256,7 +249,7 @@ export default function AccountPage() {
|
||||
}}
|
||||
className="!w-[420px] !p-6"
|
||||
>
|
||||
<div className="mb-6 text-text-primary title-2xl-semi-bold">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
<div className="title-2xl-semi-bold mb-6 text-text-primary">{userProfile.is_password_set ? t('account.resetPassword', { ns: 'common' }) : t('account.setPassword', { ns: 'common' })}</div>
|
||||
{userProfile.is_password_set && (
|
||||
<>
|
||||
<div className={titleClassName}>{t('account.currentPassword', { ns: 'common' })}</div>
|
||||
@@ -279,7 +272,7 @@ export default function AccountPage() {
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">
|
||||
{userProfile.is_password_set ? t('account.newPassword', { ns: 'common' }) : t('account.password', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="relative mt-2">
|
||||
@@ -298,7 +291,7 @@ export default function AccountPage() {
|
||||
</Button>
|
||||
</div>
|
||||
</div>
|
||||
<div className="mt-8 text-text-secondary system-sm-semibold">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="system-sm-semibold mt-8 text-text-secondary">{t('account.confirmPassword', { ns: 'common' })}</div>
|
||||
<div className="relative mt-2">
|
||||
<Input
|
||||
type={showConfirmPassword ? 'text' : 'password'}
|
||||
|
||||
@@ -7,11 +7,12 @@ import { useRouter } from 'next/navigation'
|
||||
import { Fragment } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { resetUser } from '@/app/components/base/amplitude/utils'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import { LogOut01 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import PremiumBadge from '@/app/components/base/premium-badge'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useProviderContext } from '@/context/provider-context'
|
||||
import { useLogout, useUserProfile } from '@/service/use-common'
|
||||
import { useLogout } from '@/service/use-common'
|
||||
|
||||
export type IAppSelector = {
|
||||
isMobile: boolean
|
||||
@@ -20,15 +21,10 @@ export type IAppSelector = {
|
||||
export default function AppSelector() {
|
||||
const router = useRouter()
|
||||
const { t } = useTranslation()
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const { userProfile } = useAppContext()
|
||||
const { isEducationAccount } = useProviderContext()
|
||||
|
||||
const { mutateAsync: logout } = useLogout()
|
||||
|
||||
if (!userProfile)
|
||||
return null
|
||||
|
||||
const handleLogout = async () => {
|
||||
await logout()
|
||||
|
||||
@@ -54,7 +50,7 @@ export default function AppSelector() {
|
||||
${open && 'bg-components-panel-bg-blur'}
|
||||
`}
|
||||
>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={32} />
|
||||
</MenuButton>
|
||||
</div>
|
||||
<Transition
|
||||
@@ -77,7 +73,7 @@ export default function AppSelector() {
|
||||
<div className="p-1">
|
||||
<div className="flex flex-nowrap items-center px-3 py-2">
|
||||
<div className="grow">
|
||||
<div className="break-all text-text-primary system-md-medium">
|
||||
<div className="system-md-medium break-all text-text-primary">
|
||||
{userProfile.name}
|
||||
{isEducationAccount && (
|
||||
<PremiumBadge size="s" color="blue" className="ml-1 !px-2">
|
||||
@@ -86,9 +82,9 @@ export default function AppSelector() {
|
||||
</PremiumBadge>
|
||||
)}
|
||||
</div>
|
||||
<div className="break-all text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
<div className="system-xs-regular break-all text-text-tertiary">{userProfile.email}</div>
|
||||
</div>
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={32} />
|
||||
</div>
|
||||
</div>
|
||||
</MenuItem>
|
||||
|
||||
@@ -30,14 +30,14 @@ export default function CheckEmail(props: DeleteAccountProps) {
|
||||
|
||||
return (
|
||||
<>
|
||||
<div className="py-1 text-text-destructive body-md-medium">
|
||||
<div className="body-md-medium py-1 text-text-destructive">
|
||||
{t('account.deleteTip', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="pb-2 pt-1 text-text-secondary body-md-regular">
|
||||
<div className="body-md-regular pb-2 pt-1 text-text-secondary">
|
||||
{t('account.deletePrivacyLinkTip', { ns: 'common' })}
|
||||
<Link href="https://dify.ai/privacy" className="text-text-accent">{t('account.deletePrivacyLink', { ns: 'common' })}</Link>
|
||||
</div>
|
||||
<label className="mb-1 mt-3 flex h-6 items-center text-text-secondary system-sm-semibold">{t('account.deleteLabel', { ns: 'common' })}</label>
|
||||
<label className="system-sm-semibold mb-1 mt-3 flex h-6 items-center text-text-secondary">{t('account.deleteLabel', { ns: 'common' })}</label>
|
||||
<Input
|
||||
placeholder={t('account.deletePlaceholder', { ns: 'common' }) as string}
|
||||
onChange={(e) => {
|
||||
|
||||
@@ -54,7 +54,7 @@ export default function FeedBack(props: DeleteAccountProps) {
|
||||
className="max-w-[480px]"
|
||||
footer={false}
|
||||
>
|
||||
<label className="mb-1 mt-3 flex items-center text-text-secondary system-sm-semibold">{t('account.feedbackLabel', { ns: 'common' })}</label>
|
||||
<label className="system-sm-semibold mb-1 mt-3 flex items-center text-text-secondary">{t('account.feedbackLabel', { ns: 'common' })}</label>
|
||||
<Textarea
|
||||
rows={6}
|
||||
value={userFeedback}
|
||||
|
||||
@@ -36,14 +36,14 @@ export default function VerifyEmail(props: DeleteAccountProps) {
|
||||
}, [emailToken, verificationCode, confirmDeleteAccount, props])
|
||||
return (
|
||||
<>
|
||||
<div className="pt-1 text-text-destructive body-md-medium">
|
||||
<div className="body-md-medium pt-1 text-text-destructive">
|
||||
{t('account.deleteTip', { ns: 'common' })}
|
||||
</div>
|
||||
<div className="pb-2 pt-1 text-text-secondary body-md-regular">
|
||||
<div className="body-md-regular pb-2 pt-1 text-text-secondary">
|
||||
{t('account.deletePrivacyLinkTip', { ns: 'common' })}
|
||||
<Link href="https://dify.ai/privacy" className="text-text-accent">{t('account.deletePrivacyLink', { ns: 'common' })}</Link>
|
||||
</div>
|
||||
<label className="mb-1 mt-3 flex h-6 items-center text-text-secondary system-sm-semibold">{t('account.verificationLabel', { ns: 'common' })}</label>
|
||||
<label className="system-sm-semibold mb-1 mt-3 flex h-6 items-center text-text-secondary">{t('account.verificationLabel', { ns: 'common' })}</label>
|
||||
<Input
|
||||
minLength={6}
|
||||
maxLength={6}
|
||||
|
||||
@@ -32,10 +32,10 @@ const Header = () => {
|
||||
: <DifyLogo />}
|
||||
</div>
|
||||
<div className="h-4 w-[1px] origin-center rotate-[11.31deg] bg-divider-regular" />
|
||||
<p className="relative mt-[-2px] text-text-primary title-3xl-semi-bold">{t('account.account', { ns: 'common' })}</p>
|
||||
<p className="title-3xl-semi-bold relative mt-[-2px] text-text-primary">{t('account.account', { ns: 'common' })}</p>
|
||||
</div>
|
||||
<div className="flex shrink-0 items-center gap-3">
|
||||
<Button className="gap-2 px-3 py-2 system-sm-medium" onClick={goToStudio}>
|
||||
<Button className="system-sm-medium gap-2 px-3 py-2" onClick={goToStudio}>
|
||||
<RiRobot2Line className="h-4 w-4" />
|
||||
<p>{t('account.studio', { ns: 'common' })}</p>
|
||||
<RiArrowRightUpLine className="h-4 w-4" />
|
||||
|
||||
@@ -11,13 +11,14 @@ import { useRouter, useSearchParams } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useEffect, useRef } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { Avatar } from '@/app/components/base/avatar'
|
||||
import Avatar from '@/app/components/base/avatar'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import { useLanguage } from '@/app/components/header/account-setting/model-provider-page/hooks'
|
||||
import { setPostLoginRedirect } from '@/app/signin/utils/post-login-redirect'
|
||||
import { useIsLogin, useUserProfile } from '@/service/use-common'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { useIsLogin } from '@/service/use-common'
|
||||
import { useAuthorizeOAuthApp, useOAuthAppInfo } from '@/service/use-oauth'
|
||||
|
||||
function buildReturnUrl(pathname: string, search: string) {
|
||||
@@ -61,8 +62,7 @@ export default function OAuthAuthorize() {
|
||||
const searchParams = useSearchParams()
|
||||
const client_id = decodeURIComponent(searchParams.get('client_id') || '')
|
||||
const redirect_uri = decodeURIComponent(searchParams.get('redirect_uri') || '')
|
||||
const { data: userProfileResp } = useUserProfile()
|
||||
const userProfile = userProfileResp?.profile
|
||||
const { userProfile } = useAppContext()
|
||||
const { data: authAppInfo, isLoading: isOAuthLoading, isError } = useOAuthAppInfo(client_id, redirect_uri)
|
||||
const { mutateAsync: authorize, isPending: authorizing } = useAuthorizeOAuthApp()
|
||||
const hasNotifiedRef = useRef(false)
|
||||
@@ -138,7 +138,7 @@ export default function OAuthAuthorize() {
|
||||
{isLoggedIn && userProfile && (
|
||||
<div className="flex items-center justify-between rounded-xl bg-background-section-burn-inverted p-3">
|
||||
<div className="flex items-center gap-2.5">
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size="lg" />
|
||||
<Avatar avatar={userProfile.avatar_url} name={userProfile.name} size={36} />
|
||||
<div>
|
||||
<div className="system-md-semi-bold text-text-secondary">{userProfile.name}</div>
|
||||
<div className="text-text-tertiary system-xs-regular">{userProfile.email}</div>
|
||||
|
||||
@@ -31,7 +31,7 @@ const EditItem: FC<Props> = ({
|
||||
{avatar}
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="mb-1 text-text-primary system-xs-semibold">{name}</div>
|
||||
<div className="system-xs-semibold mb-1 text-text-primary">{name}</div>
|
||||
<Textarea
|
||||
value={content}
|
||||
onChange={(e: React.ChangeEvent<HTMLTextAreaElement>) => onChange(e.target.value)}
|
||||
|
||||
@@ -99,7 +99,7 @@ const AddAnnotationModal: FC<Props> = ({
|
||||
<AnnotationFull />
|
||||
</div>
|
||||
)}
|
||||
<div className="flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary system-sm-medium">
|
||||
<div className="system-sm-medium flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary">
|
||||
<div
|
||||
className="flex items-center space-x-2"
|
||||
>
|
||||
|
||||
@@ -33,7 +33,7 @@ const CSVDownload: FC = () => {
|
||||
|
||||
return (
|
||||
<div className="mt-6">
|
||||
<div className="text-text-primary system-sm-medium">{t('generation.csvStructureTitle', { ns: 'share' })}</div>
|
||||
<div className="system-sm-medium text-text-primary">{t('generation.csvStructureTitle', { ns: 'share' })}</div>
|
||||
<div className="mt-2 max-h-[500px] overflow-auto">
|
||||
<table className="w-full table-fixed border-separate border-spacing-0 rounded-lg border border-divider-regular text-xs">
|
||||
<thead className="text-text-tertiary">
|
||||
@@ -77,7 +77,7 @@ const CSVDownload: FC = () => {
|
||||
bom={true}
|
||||
data={getTemplate()}
|
||||
>
|
||||
<div className="flex h-[18px] items-center space-x-1 text-text-accent system-xs-medium">
|
||||
<div className="system-xs-medium flex h-[18px] items-center space-x-1 text-text-accent">
|
||||
<DownloadIcon className="mr-1 h-3 w-3" />
|
||||
{t('batchModal.template', { ns: 'appAnnotation' })}
|
||||
</div>
|
||||
|
||||
@@ -94,7 +94,7 @@ const CSVUploader: FC<Props> = ({
|
||||
/>
|
||||
<div ref={dropRef}>
|
||||
{!file && (
|
||||
<div className={cn('flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg system-sm-regular', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
|
||||
<div className={cn('system-sm-regular flex h-20 items-center rounded-xl border border-dashed border-components-dropzone-border bg-components-dropzone-bg', dragging && 'border border-components-dropzone-border-accent bg-components-dropzone-bg-accent')}>
|
||||
<div className="flex w-full items-center justify-center space-x-2">
|
||||
<CSVIcon className="shrink-0" />
|
||||
<div className="text-text-tertiary">
|
||||
|
||||
@@ -52,7 +52,7 @@ const BatchModal: FC<IBatchModalProps> = ({
|
||||
const res = await checkAnnotationBatchImportProgress({ jobID, appId })
|
||||
setImportStatus(res.job_status)
|
||||
if (res.job_status === ProcessStatus.WAITING || res.job_status === ProcessStatus.PROCESSING)
|
||||
setTimeout(checkProcess, 2500, res.job_id)
|
||||
setTimeout(() => checkProcess(res.job_id), 2500)
|
||||
if (res.job_status === ProcessStatus.ERROR)
|
||||
notify({ type: 'error', message: `${t('batchModal.runError', { ns: 'appAnnotation' })}` })
|
||||
if (res.job_status === ProcessStatus.COMPLETED) {
|
||||
@@ -90,7 +90,7 @@ const BatchModal: FC<IBatchModalProps> = ({
|
||||
|
||||
return (
|
||||
<Modal isShow={isShow} onClose={noop} className="!max-w-[520px] !rounded-xl px-8 py-6">
|
||||
<div className="relative pb-1 text-text-primary system-xl-medium">{t('batchModal.title', { ns: 'appAnnotation' })}</div>
|
||||
<div className="system-xl-medium relative pb-1 text-text-primary">{t('batchModal.title', { ns: 'appAnnotation' })}</div>
|
||||
<div className="absolute right-4 top-4 cursor-pointer p-2" onClick={onCancel}>
|
||||
<RiCloseLine className="h-4 w-4 text-text-tertiary" />
|
||||
</div>
|
||||
@@ -107,7 +107,7 @@ const BatchModal: FC<IBatchModalProps> = ({
|
||||
)}
|
||||
|
||||
<div className="mt-[28px] flex justify-end pt-6">
|
||||
<Button className="mr-2 text-text-tertiary system-sm-medium" onClick={onCancel}>
|
||||
<Button className="system-sm-medium mr-2 text-text-tertiary" onClick={onCancel}>
|
||||
{t('batchModal.cancel', { ns: 'appAnnotation' })}
|
||||
</Button>
|
||||
<Button
|
||||
|
||||
@@ -21,7 +21,7 @@ type Props = {
|
||||
}
|
||||
|
||||
export const EditTitle: FC<{ className?: string, title: string }> = ({ className, title }) => (
|
||||
<div className={cn(className, 'flex h-[18px] items-center text-text-tertiary system-xs-medium')}>
|
||||
<div className={cn(className, 'system-xs-medium flex h-[18px] items-center text-text-tertiary')}>
|
||||
<RiEditFill className="mr-1 h-3.5 w-3.5" />
|
||||
<div>{title}</div>
|
||||
<div
|
||||
@@ -75,21 +75,21 @@ const EditItem: FC<Props> = ({
|
||||
{avatar}
|
||||
</div>
|
||||
<div className="grow">
|
||||
<div className="mb-1 text-text-primary system-xs-semibold">{name}</div>
|
||||
<div className="text-text-primary system-sm-regular">{content}</div>
|
||||
<div className="system-xs-semibold mb-1 text-text-primary">{name}</div>
|
||||
<div className="system-sm-regular text-text-primary">{content}</div>
|
||||
{!isEdit
|
||||
? (
|
||||
<div>
|
||||
{showNewContent && (
|
||||
<div className="mt-3">
|
||||
<EditTitle title={editTitle} />
|
||||
<div className="mt-1 text-text-primary system-sm-regular">{newContent}</div>
|
||||
<div className="system-sm-regular mt-1 text-text-primary">{newContent}</div>
|
||||
</div>
|
||||
)}
|
||||
<div className="mt-2 flex items-center">
|
||||
{!readonly && (
|
||||
<div
|
||||
className="flex cursor-pointer items-center space-x-1 text-text-accent system-xs-medium"
|
||||
className="system-xs-medium flex cursor-pointer items-center space-x-1 text-text-accent"
|
||||
onClick={() => {
|
||||
setIsEdit(true)
|
||||
}}
|
||||
@@ -100,7 +100,7 @@ const EditItem: FC<Props> = ({
|
||||
)}
|
||||
|
||||
{showNewContent && (
|
||||
<div className="ml-2 flex items-center text-text-tertiary system-xs-medium">
|
||||
<div className="system-xs-medium ml-2 flex items-center text-text-tertiary">
|
||||
<div className="mr-2">·</div>
|
||||
<div
|
||||
className="flex cursor-pointer items-center space-x-1"
|
||||
|
||||
@@ -136,7 +136,7 @@ const EditAnnotationModal: FC<Props> = ({
|
||||
{
|
||||
annotationId
|
||||
? (
|
||||
<div className="flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary system-sm-medium">
|
||||
<div className="system-sm-medium flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary">
|
||||
<div
|
||||
className="flex cursor-pointer items-center space-x-2 pl-3"
|
||||
onClick={() => setShowModal(true)}
|
||||
|
||||
@@ -17,11 +17,11 @@ const EmptyElement: FC = () => {
|
||||
return (
|
||||
<div className="flex h-full items-center justify-center">
|
||||
<div className="box-border h-fit w-[560px] rounded-2xl bg-background-section-burn px-5 py-4">
|
||||
<span className="text-text-secondary system-md-semibold">
|
||||
<span className="system-md-semibold text-text-secondary">
|
||||
{t('noData.title', { ns: 'appAnnotation' })}
|
||||
<ThreeDotsIcon className="relative -left-1.5 -top-3 inline" />
|
||||
</span>
|
||||
<div className="mt-2 text-text-tertiary system-sm-regular">
|
||||
<div className="system-sm-regular mt-2 text-text-tertiary">
|
||||
{t('noData.description', { ns: 'appAnnotation' })}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
@@ -103,12 +103,12 @@ const HeaderOptions: FC<Props> = ({
|
||||
}}
|
||||
>
|
||||
<FilePlus02 className="h-4 w-4 text-text-tertiary" />
|
||||
<span className="grow text-left text-text-secondary system-sm-regular">{t('table.header.bulkImport', { ns: 'appAnnotation' })}</span>
|
||||
<span className="system-sm-regular grow text-left text-text-secondary">{t('table.header.bulkImport', { ns: 'appAnnotation' })}</span>
|
||||
</button>
|
||||
<Menu as="div" className="relative h-full w-full">
|
||||
<MenuButton className="mx-1 flex h-9 w-[calc(100%_-_8px)] cursor-pointer items-center space-x-2 rounded-lg px-3 py-2 hover:bg-components-panel-on-panel-item-bg-hover disabled:opacity-50">
|
||||
<FileDownload02 className="h-4 w-4 text-text-tertiary" />
|
||||
<span className="grow text-left text-text-secondary system-sm-regular">{t('table.header.bulkExport', { ns: 'appAnnotation' })}</span>
|
||||
<span className="system-sm-regular grow text-left text-text-secondary">{t('table.header.bulkExport', { ns: 'appAnnotation' })}</span>
|
||||
<ChevronRight className="h-[14px] w-[14px] shrink-0 text-text-tertiary" />
|
||||
</MenuButton>
|
||||
<Transition
|
||||
@@ -135,11 +135,11 @@ const HeaderOptions: FC<Props> = ({
|
||||
]}
|
||||
>
|
||||
<button type="button" disabled={annotationUnavailable} className="mx-1 flex h-9 w-[calc(100%_-_8px)] cursor-pointer items-center space-x-2 rounded-lg px-3 py-2 hover:bg-components-panel-on-panel-item-bg-hover disabled:opacity-50">
|
||||
<span className="grow text-left text-text-secondary system-sm-regular">CSV</span>
|
||||
<span className="system-sm-regular grow text-left text-text-secondary">CSV</span>
|
||||
</button>
|
||||
</CSVDownloader>
|
||||
<button type="button" disabled={annotationUnavailable} className={cn('mx-1 flex h-9 w-[calc(100%_-_8px)] cursor-pointer items-center space-x-2 rounded-lg px-3 py-2 hover:bg-components-panel-on-panel-item-bg-hover disabled:opacity-50', '!border-0')} onClick={JSONLOutput}>
|
||||
<span className="grow text-left text-text-secondary system-sm-regular">JSONL</span>
|
||||
<span className="system-sm-regular grow text-left text-text-secondary">JSONL</span>
|
||||
</button>
|
||||
</MenuItems>
|
||||
</Transition>
|
||||
@@ -150,7 +150,7 @@ const HeaderOptions: FC<Props> = ({
|
||||
className="mx-1 flex h-9 w-[calc(100%_-_8px)] cursor-pointer items-center space-x-2 rounded-lg px-3 py-2 text-red-600 hover:bg-red-50 disabled:opacity-50"
|
||||
>
|
||||
<RiDeleteBinLine className="h-4 w-4" />
|
||||
<span className="grow text-left system-sm-regular">
|
||||
<span className="system-sm-regular grow text-left">
|
||||
{t('table.header.clearAll', { ns: 'appAnnotation' })}
|
||||
</span>
|
||||
</button>
|
||||
|
||||
@@ -58,7 +58,7 @@ const List: FC<Props> = ({
|
||||
<>
|
||||
<div className="relative mt-2 grow overflow-x-auto">
|
||||
<table className={cn('w-full min-w-[440px] border-collapse border-0')}>
|
||||
<thead className="text-text-tertiary system-xs-medium-uppercase">
|
||||
<thead className="system-xs-medium-uppercase text-text-tertiary">
|
||||
<tr>
|
||||
<td className="w-12 whitespace-nowrap rounded-l-lg bg-background-section-burn px-2">
|
||||
<Checkbox
|
||||
@@ -75,7 +75,7 @@ const List: FC<Props> = ({
|
||||
<td className="w-[96px] whitespace-nowrap rounded-r-lg bg-background-section-burn py-1.5 pl-3">{t('table.header.actions', { ns: 'appAnnotation' })}</td>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="text-text-secondary system-sm-regular">
|
||||
<tbody className="system-sm-regular text-text-secondary">
|
||||
{list.map(item => (
|
||||
<tr
|
||||
key={item.id}
|
||||
|
||||
@@ -11,7 +11,7 @@ const HitHistoryNoData: FC = () => {
|
||||
<div className="inline-block rounded-lg border border-divider-subtle p-3">
|
||||
<ClockFastForward className="h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
<div className="text-text-tertiary system-sm-regular">{t('viewModal.noHitHistory', { ns: 'appAnnotation' })}</div>
|
||||
<div className="system-sm-regular text-text-tertiary">{t('viewModal.noHitHistory', { ns: 'appAnnotation' })}</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ const ViewAnnotationModal: FC<Props> = ({
|
||||
: (
|
||||
<div>
|
||||
<table className={cn('w-full min-w-[440px] border-collapse border-0')}>
|
||||
<thead className="text-text-tertiary system-xs-medium-uppercase">
|
||||
<thead className="system-xs-medium-uppercase text-text-tertiary">
|
||||
<tr>
|
||||
<td className="w-5 whitespace-nowrap rounded-l-lg bg-background-section-burn pl-2 pr-1">{t('hitHistoryTable.query', { ns: 'appAnnotation' })}</td>
|
||||
<td className="whitespace-nowrap bg-background-section-burn py-1.5 pl-3">{t('hitHistoryTable.match', { ns: 'appAnnotation' })}</td>
|
||||
@@ -147,7 +147,7 @@ const ViewAnnotationModal: FC<Props> = ({
|
||||
<td className="w-[160px] whitespace-nowrap rounded-r-lg bg-background-section-burn py-1.5 pl-3">{t('hitHistoryTable.time', { ns: 'appAnnotation' })}</td>
|
||||
</tr>
|
||||
</thead>
|
||||
<tbody className="text-text-secondary system-sm-regular">
|
||||
<tbody className="system-sm-regular text-text-secondary">
|
||||
{hitHistoryList.map(item => (
|
||||
<tr
|
||||
key={item.id}
|
||||
@@ -226,7 +226,7 @@ const ViewAnnotationModal: FC<Props> = ({
|
||||
)}
|
||||
foot={id
|
||||
? (
|
||||
<div className="flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary system-sm-medium">
|
||||
<div className="system-sm-medium flex h-16 items-center justify-between rounded-bl-xl rounded-br-xl border-t border-divider-subtle bg-background-section-burn px-4 text-text-tertiary">
|
||||
<div
|
||||
className="flex cursor-pointer items-center space-x-2 pl-3"
|
||||
onClick={() => setShowModal(true)}
|
||||
|
||||
@@ -10,7 +10,7 @@ import { SubjectType } from '@/models/access-control'
|
||||
import { useSearchForWhiteListCandidates } from '@/service/access-control'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Avatar from '../../base/avatar'
|
||||
import Button from '../../base/button'
|
||||
import Checkbox from '../../base/checkbox'
|
||||
import Input from '../../base/input'
|
||||
@@ -24,7 +24,7 @@ export default function AddMemberOrGroupDialog() {
|
||||
const selectedGroupsForBreadcrumb = useAccessControlStore(s => s.selectedGroupsForBreadcrumb)
|
||||
const debouncedKeyword = useDebounce(keyword, { wait: 500 })
|
||||
|
||||
const lastAvailableGroup = selectedGroupsForBreadcrumb.at(-1)
|
||||
const lastAvailableGroup = selectedGroupsForBreadcrumb[selectedGroupsForBreadcrumb.length - 1]
|
||||
const { isLoading, isFetchingNextPage, fetchNextPage, data } = useSearchForWhiteListCandidates({ keyword: debouncedKeyword, groupId: lastAvailableGroup?.id, resultsPerPage: 10 }, open)
|
||||
const handleKeywordChange = (e: React.ChangeEvent<HTMLInputElement>) => {
|
||||
setKeyword(e.target.value)
|
||||
@@ -76,7 +76,7 @@ export default function AddMemberOrGroupDialog() {
|
||||
)
|
||||
: (
|
||||
<div className="flex h-7 items-center justify-center px-2 py-0.5">
|
||||
<span className="text-text-tertiary system-xs-regular">{t('accessControlDialog.operateGroupAndMember.noResult', { ns: 'app' })}</span>
|
||||
<span className="system-xs-regular text-text-tertiary">{t('accessControlDialog.operateGroupAndMember.noResult', { ns: 'app' })}</span>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
@@ -115,10 +115,10 @@ function SelectedGroupsBreadCrumb() {
|
||||
}, [setSelectedGroupsForBreadcrumb])
|
||||
return (
|
||||
<div className="flex h-7 items-center gap-x-0.5 px-2 py-0.5">
|
||||
<span className={cn('text-text-tertiary system-xs-regular', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('accessControlDialog.operateGroupAndMember.allMembers', { ns: 'app' })}</span>
|
||||
<span className={cn('system-xs-regular text-text-tertiary', selectedGroupsForBreadcrumb.length > 0 && 'cursor-pointer text-text-accent')} onClick={handleReset}>{t('accessControlDialog.operateGroupAndMember.allMembers', { ns: 'app' })}</span>
|
||||
{selectedGroupsForBreadcrumb.map((group, index) => {
|
||||
return (
|
||||
<div key={index} className="flex items-center gap-x-0.5 text-text-tertiary system-xs-regular">
|
||||
<div key={index} className="system-xs-regular flex items-center gap-x-0.5 text-text-tertiary">
|
||||
<span>/</span>
|
||||
<span className={index === selectedGroupsForBreadcrumb.length - 1 ? '' : 'cursor-pointer text-text-accent'} onClick={() => handleBreadCrumbClick(index)}>{group.name}</span>
|
||||
</div>
|
||||
@@ -161,8 +161,8 @@ function GroupItem({ group }: GroupItemProps) {
|
||||
<RiOrganizationChart className="h-[14px] w-[14px] text-components-avatar-shape-fill-stop-0" />
|
||||
</div>
|
||||
</div>
|
||||
<p className="mr-1 text-text-secondary system-sm-medium">{group.name}</p>
|
||||
<p className="text-text-tertiary system-xs-regular">{group.groupSize}</p>
|
||||
<p className="system-sm-medium mr-1 text-text-secondary">{group.name}</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{group.groupSize}</p>
|
||||
</div>
|
||||
<Button
|
||||
size="small"
|
||||
@@ -203,19 +203,19 @@ function MemberItem({ member }: MemberItemProps) {
|
||||
<div className="flex grow items-center">
|
||||
<div className="mr-2 h-5 w-5 overflow-hidden rounded-full bg-components-icon-bg-blue-solid">
|
||||
<div className="bg-access-app-icon-mask-bg flex h-full w-full items-center justify-center">
|
||||
<Avatar size="xxs" avatar={null} name={member.name} />
|
||||
<Avatar className="h-[14px] w-[14px]" textClassName="text-[12px]" avatar={null} name={member.name} />
|
||||
</div>
|
||||
</div>
|
||||
<p className="mr-1 text-text-secondary system-sm-medium">{member.name}</p>
|
||||
<p className="system-sm-medium mr-1 text-text-secondary">{member.name}</p>
|
||||
{currentUser.email === member.email && (
|
||||
<p className="text-text-tertiary system-xs-regular">
|
||||
<p className="system-xs-regular text-text-tertiary">
|
||||
(
|
||||
{t('you', { ns: 'common' })}
|
||||
)
|
||||
</p>
|
||||
)}
|
||||
</div>
|
||||
<p className="text-text-quaternary system-xs-regular">{member.email}</p>
|
||||
<p className="system-xs-regular text-text-quaternary">{member.email}</p>
|
||||
</BaseItem>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -68,18 +68,18 @@ export default function AccessControl(props: AccessControlProps) {
|
||||
<AccessControlDialog show onClose={onClose}>
|
||||
<div className="flex flex-col gap-y-3">
|
||||
<div className="pb-3 pl-6 pr-14 pt-6">
|
||||
<DialogTitle className="text-text-primary title-2xl-semi-bold">{t('accessControlDialog.title', { ns: 'app' })}</DialogTitle>
|
||||
<DialogDescription className="mt-1 text-text-tertiary system-xs-regular">{t('accessControlDialog.description', { ns: 'app' })}</DialogDescription>
|
||||
<DialogTitle className="title-2xl-semi-bold text-text-primary">{t('accessControlDialog.title', { ns: 'app' })}</DialogTitle>
|
||||
<DialogDescription className="system-xs-regular mt-1 text-text-tertiary">{t('accessControlDialog.description', { ns: 'app' })}</DialogDescription>
|
||||
</div>
|
||||
<div className="flex flex-col gap-y-1 px-6 pb-3">
|
||||
<div className="leading-6">
|
||||
<p className="text-text-tertiary system-sm-medium">{t('accessControlDialog.accessLabel', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-tertiary">{t('accessControlDialog.accessLabel', { ns: 'app' })}</p>
|
||||
</div>
|
||||
<AccessControlItem type={AccessMode.ORGANIZATION}>
|
||||
<div className="flex items-center p-3">
|
||||
<div className="flex grow items-center gap-x-2">
|
||||
<RiBuildingLine className="h-4 w-4 text-text-primary" />
|
||||
<p className="text-text-primary system-sm-medium">{t('accessControlDialog.accessItems.organization', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('accessControlDialog.accessItems.organization', { ns: 'app' })}</p>
|
||||
</div>
|
||||
</div>
|
||||
</AccessControlItem>
|
||||
@@ -90,7 +90,7 @@ export default function AccessControl(props: AccessControlProps) {
|
||||
<div className="flex items-center p-3">
|
||||
<div className="flex grow items-center gap-x-2">
|
||||
<RiVerifiedBadgeLine className="h-4 w-4 text-text-primary" />
|
||||
<p className="text-text-primary system-sm-medium">{t('accessControlDialog.accessItems.external', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('accessControlDialog.accessItems.external', { ns: 'app' })}</p>
|
||||
</div>
|
||||
{!hideTip && <WebAppSSONotEnabledTip />}
|
||||
</div>
|
||||
@@ -98,7 +98,7 @@ export default function AccessControl(props: AccessControlProps) {
|
||||
<AccessControlItem type={AccessMode.PUBLIC}>
|
||||
<div className="flex items-center gap-x-2 p-3">
|
||||
<RiGlobalLine className="h-4 w-4 text-text-primary" />
|
||||
<p className="text-text-primary system-sm-medium">{t('accessControlDialog.accessItems.anyone', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('accessControlDialog.accessItems.anyone', { ns: 'app' })}</p>
|
||||
</div>
|
||||
</AccessControlItem>
|
||||
</div>
|
||||
|
||||
@@ -6,7 +6,7 @@ import { useTranslation } from 'react-i18next'
|
||||
import { AccessMode } from '@/models/access-control'
|
||||
import { useAppWhiteListSubjects } from '@/service/access-control'
|
||||
import useAccessControlStore from '../../../../context/access-control-store'
|
||||
import { Avatar } from '../../base/avatar'
|
||||
import Avatar from '../../base/avatar'
|
||||
import Loading from '../../base/loading'
|
||||
import Tooltip from '../../base/tooltip'
|
||||
import AddMemberOrGroupDialog from './add-member-or-group-pop'
|
||||
@@ -29,7 +29,7 @@ export default function SpecificGroupsOrMembers() {
|
||||
<div className="flex items-center p-3">
|
||||
<div className="flex grow items-center gap-x-2">
|
||||
<RiLockLine className="h-4 w-4 text-text-primary" />
|
||||
<p className="text-text-primary system-sm-medium">{t('accessControlDialog.accessItems.specific', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('accessControlDialog.accessItems.specific', { ns: 'app' })}</p>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
@@ -40,7 +40,7 @@ export default function SpecificGroupsOrMembers() {
|
||||
<div className="flex items-center gap-x-1 p-3">
|
||||
<div className="flex grow items-center gap-x-1">
|
||||
<RiLockLine className="h-4 w-4 text-text-primary" />
|
||||
<p className="text-text-primary system-sm-medium">{t('accessControlDialog.accessItems.specific', { ns: 'app' })}</p>
|
||||
<p className="system-sm-medium text-text-primary">{t('accessControlDialog.accessItems.specific', { ns: 'app' })}</p>
|
||||
</div>
|
||||
<div className="flex items-center gap-x-1">
|
||||
<AddMemberOrGroupDialog />
|
||||
@@ -60,14 +60,14 @@ function RenderGroupsAndMembers() {
|
||||
const specificGroups = useAccessControlStore(s => s.specificGroups)
|
||||
const specificMembers = useAccessControlStore(s => s.specificMembers)
|
||||
if (specificGroups.length <= 0 && specificMembers.length <= 0)
|
||||
return <div className="px-2 pb-1.5 pt-5"><p className="text-center text-text-tertiary system-xs-regular">{t('accessControlDialog.noGroupsOrMembers', { ns: 'app' })}</p></div>
|
||||
return <div className="px-2 pb-1.5 pt-5"><p className="system-xs-regular text-center text-text-tertiary">{t('accessControlDialog.noGroupsOrMembers', { ns: 'app' })}</p></div>
|
||||
return (
|
||||
<>
|
||||
<p className="sticky top-0 text-text-tertiary system-2xs-medium-uppercase">{t('accessControlDialog.groups', { ns: 'app', count: specificGroups.length ?? 0 })}</p>
|
||||
<p className="system-2xs-medium-uppercase sticky top-0 text-text-tertiary">{t('accessControlDialog.groups', { ns: 'app', count: specificGroups.length ?? 0 })}</p>
|
||||
<div className="flex flex-row flex-wrap gap-1">
|
||||
{specificGroups.map((group, index) => <GroupItem key={index} group={group} />)}
|
||||
</div>
|
||||
<p className="sticky top-0 text-text-tertiary system-2xs-medium-uppercase">{t('accessControlDialog.members', { ns: 'app', count: specificMembers.length ?? 0 })}</p>
|
||||
<p className="system-2xs-medium-uppercase sticky top-0 text-text-tertiary">{t('accessControlDialog.members', { ns: 'app', count: specificMembers.length ?? 0 })}</p>
|
||||
<div className="flex flex-row flex-wrap gap-1">
|
||||
{specificMembers.map((member, index) => <MemberItem key={index} member={member} />)}
|
||||
</div>
|
||||
@@ -89,8 +89,8 @@ function GroupItem({ group }: GroupItemProps) {
|
||||
icon={<RiOrganizationChart className="h-[14px] w-[14px] text-components-avatar-shape-fill-stop-0" />}
|
||||
onRemove={handleRemoveGroup}
|
||||
>
|
||||
<p className="text-text-primary system-xs-regular">{group.name}</p>
|
||||
<p className="text-text-tertiary system-xs-regular">{group.groupSize}</p>
|
||||
<p className="system-xs-regular text-text-primary">{group.name}</p>
|
||||
<p className="system-xs-regular text-text-tertiary">{group.groupSize}</p>
|
||||
</BaseItem>
|
||||
)
|
||||
}
|
||||
@@ -106,10 +106,10 @@ function MemberItem({ member }: MemberItemProps) {
|
||||
}, [member, setSpecificMembers, specificMembers])
|
||||
return (
|
||||
<BaseItem
|
||||
icon={<Avatar size="xxs" avatar={null} name={member.name} />}
|
||||
icon={<Avatar className="h-[14px] w-[14px]" textClassName="text-[12px]" avatar={null} name={member.name} />}
|
||||
onRemove={handleRemoveMember}
|
||||
>
|
||||
<p className="text-text-primary system-xs-regular">{member.name}</p>
|
||||
<p className="system-xs-regular text-text-primary">{member.name}</p>
|
||||
</BaseItem>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -24,7 +24,7 @@ const SuggestedAction = ({ icon, link, disabled, children, className, onClick, .
|
||||
{...props}
|
||||
>
|
||||
<div className="relative h-4 w-4">{icon}</div>
|
||||
<div className="shrink grow basis-0 system-sm-medium">{children}</div>
|
||||
<div className="system-sm-medium shrink grow basis-0">{children}</div>
|
||||
<RiArrowRightUpLine className="h-3.5 w-3.5" />
|
||||
</a>
|
||||
)
|
||||
|
||||
@@ -74,7 +74,7 @@ const VersionInfoModal: FC<VersionInfoModalProps> = ({
|
||||
return (
|
||||
<Modal className="p-0" isShow={isOpen} onClose={onClose}>
|
||||
<div className="relative w-full p-6 pb-4 pr-14">
|
||||
<div className="text-text-primary title-2xl-semi-bold first-letter:capitalize">
|
||||
<div className="title-2xl-semi-bold text-text-primary first-letter:capitalize">
|
||||
{versionInfo?.marked_name ? t('versionHistory.editVersionInfo', { ns: 'workflow' }) : t('versionHistory.nameThisVersion', { ns: 'workflow' })}
|
||||
</div>
|
||||
<div className="absolute right-5 top-5 flex h-8 w-8 cursor-pointer items-center justify-center p-1.5" onClick={onClose}>
|
||||
@@ -83,7 +83,7 @@ const VersionInfoModal: FC<VersionInfoModalProps> = ({
|
||||
</div>
|
||||
<div className="flex flex-col gap-y-4 px-6 py-3">
|
||||
<div className="flex flex-col gap-y-1">
|
||||
<div className="flex h-6 items-center text-text-secondary system-sm-semibold">
|
||||
<div className="system-sm-semibold flex h-6 items-center text-text-secondary">
|
||||
{t('versionHistory.editField.title', { ns: 'workflow' })}
|
||||
</div>
|
||||
<Input
|
||||
@@ -94,7 +94,7 @@ const VersionInfoModal: FC<VersionInfoModalProps> = ({
|
||||
/>
|
||||
</div>
|
||||
<div className="flex flex-col gap-y-1">
|
||||
<div className="flex h-6 items-center text-text-secondary system-sm-semibold">
|
||||
<div className="system-sm-semibold flex h-6 items-center text-text-secondary">
|
||||
{t('versionHistory.editField.releaseNotes', { ns: 'workflow' })}
|
||||
</div>
|
||||
<Textarea
|
||||
|
||||
@@ -29,7 +29,7 @@ const FeaturePanel: FC<IFeaturePanelProps> = ({
|
||||
<div className="flex h-8 items-center justify-between">
|
||||
<div className="flex shrink-0 items-center space-x-1">
|
||||
{!!headerIcon && <div className="flex h-6 w-6 items-center justify-center">{headerIcon}</div>}
|
||||
<div className="text-text-secondary system-sm-semibold">{title}</div>
|
||||
<div className="system-sm-semibold text-text-secondary">{title}</div>
|
||||
</div>
|
||||
<div className="flex items-center gap-2">
|
||||
{!!headerRight && <div>{headerRight}</div>}
|
||||
|
||||
@@ -2,19 +2,25 @@ import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import * as React from 'react'
|
||||
import HasNotSetAPI from './has-not-set-api'
|
||||
|
||||
describe('HasNotSetAPI', () => {
|
||||
it('should render the empty state copy', () => {
|
||||
render(<HasNotSetAPI onSetting={vi.fn()} />)
|
||||
describe('HasNotSetAPI WarningMask', () => {
|
||||
it('should show default title when trial not finished', () => {
|
||||
render(<HasNotSetAPI isTrailFinished={false} onSetting={vi.fn()} />)
|
||||
|
||||
expect(screen.getByText('appDebug.noModelProviderConfigured')).toBeInTheDocument()
|
||||
expect(screen.getByText('appDebug.noModelProviderConfiguredTip')).toBeInTheDocument()
|
||||
expect(screen.getByText('appDebug.notSetAPIKey.title')).toBeInTheDocument()
|
||||
expect(screen.getByText('appDebug.notSetAPIKey.description')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSetting when manage models button is clicked', () => {
|
||||
const onSetting = vi.fn()
|
||||
render(<HasNotSetAPI onSetting={onSetting} />)
|
||||
it('should show trail finished title when flag is true', () => {
|
||||
render(<HasNotSetAPI isTrailFinished onSetting={vi.fn()} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'appDebug.manageModels' }))
|
||||
expect(screen.getByText('appDebug.notSetAPIKey.trailFinished')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onSetting when primary button clicked', () => {
|
||||
const onSetting = vi.fn()
|
||||
render(<HasNotSetAPI isTrailFinished={false} onSetting={onSetting} />)
|
||||
|
||||
fireEvent.click(screen.getByRole('button', { name: 'appDebug.notSetAPIKey.settingBtn' }))
|
||||
expect(onSetting).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
})
|
||||
|
||||
@@ -2,38 +2,38 @@
|
||||
import type { FC } from 'react'
|
||||
import * as React from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import WarningMask from '.'
|
||||
|
||||
export type IHasNotSetAPIProps = {
|
||||
isTrailFinished: boolean
|
||||
onSetting: () => void
|
||||
}
|
||||
|
||||
const icon = (
|
||||
<svg width="16" height="16" viewBox="0 0 16 16" fill="none" xmlns="http://www.w3.org/2000/svg">
|
||||
<path d="M14 6.00001L14 2.00001M14 2.00001H9.99999M14 2.00001L8 8M6.66667 2H5.2C4.0799 2 3.51984 2 3.09202 2.21799C2.71569 2.40973 2.40973 2.71569 2.21799 3.09202C2 3.51984 2 4.07989 2 5.2V10.8C2 11.9201 2 12.4802 2.21799 12.908C2.40973 13.2843 2.71569 13.5903 3.09202 13.782C3.51984 14 4.07989 14 5.2 14H10.8C11.9201 14 12.4802 14 12.908 13.782C13.2843 13.5903 13.5903 13.2843 13.782 12.908C14 12.4802 14 11.9201 14 10.8V9.33333" stroke="white" strokeWidth="1.5" strokeLinecap="round" strokeLinejoin="round" />
|
||||
</svg>
|
||||
|
||||
)
|
||||
|
||||
const HasNotSetAPI: FC<IHasNotSetAPIProps> = ({
|
||||
isTrailFinished,
|
||||
onSetting,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
return (
|
||||
<div className="flex grow flex-col items-center justify-center pb-[120px]">
|
||||
<div className="flex w-full max-w-[400px] flex-col gap-2 px-4 py-4">
|
||||
<div className="flex h-10 w-10 items-center justify-center rounded-[10px]">
|
||||
<div className="flex h-full w-full items-center justify-center overflow-hidden rounded-[10px] border-[0.5px] border-components-card-border bg-components-card-bg p-1 shadow-lg backdrop-blur-[5px]">
|
||||
<span className="i-ri-brain-2-line h-5 w-5 text-text-tertiary" />
|
||||
</div>
|
||||
</div>
|
||||
<div className="flex flex-col gap-1">
|
||||
<div className="text-text-secondary system-md-semibold">{t('noModelProviderConfigured', { ns: 'appDebug' })}</div>
|
||||
<div className="text-text-tertiary system-xs-regular">{t('noModelProviderConfiguredTip', { ns: 'appDebug' })}</div>
|
||||
</div>
|
||||
<button
|
||||
type="button"
|
||||
className="flex w-fit items-center gap-1 rounded-lg border-[0.5px] border-components-button-secondary-border bg-components-button-secondary-bg px-3 py-2 shadow-xs backdrop-blur-[5px]"
|
||||
onClick={onSetting}
|
||||
>
|
||||
<span className="text-components-button-secondary-accent-text system-sm-medium">{t('manageModels', { ns: 'appDebug' })}</span>
|
||||
<span className="i-ri-arrow-right-line h-4 w-4 text-components-button-secondary-accent-text" />
|
||||
</button>
|
||||
</div>
|
||||
</div>
|
||||
<WarningMask
|
||||
title={isTrailFinished ? t('notSetAPIKey.trailFinished', { ns: 'appDebug' }) : t('notSetAPIKey.title', { ns: 'appDebug' })}
|
||||
description={t('notSetAPIKey.description', { ns: 'appDebug' })}
|
||||
footer={(
|
||||
<Button variant="primary" className="flex space-x-2" onClick={onSetting}>
|
||||
<span>{t('notSetAPIKey.settingBtn', { ns: 'appDebug' })}</span>
|
||||
{icon}
|
||||
</Button>
|
||||
)}
|
||||
/>
|
||||
)
|
||||
}
|
||||
export default React.memo(HasNotSetAPI)
|
||||
|
||||
@@ -35,7 +35,7 @@ const ConfirmAddVar: FC<IConfirmAddVarProps> = ({
|
||||
// }, mainContentRef)
|
||||
return (
|
||||
<div
|
||||
className="absolute inset-0 flex items-center justify-center rounded-xl"
|
||||
className="absolute inset-0 flex items-center justify-center rounded-xl"
|
||||
style={{
|
||||
backgroundColor: 'rgba(35, 56, 118, 0.2)',
|
||||
}}
|
||||
|
||||
@@ -28,7 +28,7 @@ const MessageTypeSelector: FC<Props> = ({
|
||||
className={cn(showOption && 'bg-indigo-100', 'flex h-7 cursor-pointer items-center space-x-0.5 rounded-lg pl-1.5 pr-1 text-indigo-800')}
|
||||
>
|
||||
<div className="text-sm font-semibold uppercase">{value}</div>
|
||||
<ChevronSelectorVertical className="h-3 w-3" />
|
||||
<ChevronSelectorVertical className="h-3 w-3 " />
|
||||
</div>
|
||||
{showOption && (
|
||||
<div className="absolute top-[30px] z-10 rounded-lg border border-components-panel-border bg-components-panel-bg p-1 shadow-lg">
|
||||
|
||||
@@ -178,7 +178,7 @@ const Prompt: FC<ISimplePromptInput> = ({
|
||||
{!noTitle && (
|
||||
<div className="flex h-11 items-center justify-between pl-3 pr-2.5">
|
||||
<div className="flex items-center space-x-1">
|
||||
<div className="h2 text-text-secondary system-sm-semibold-uppercase">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
|
||||
<div className="h2 system-sm-semibold-uppercase text-text-secondary">{mode !== AppModeEnum.COMPLETION ? t('chatSubTitle', { ns: 'appDebug' }) : t('completionSubTitle', { ns: 'appDebug' })}</div>
|
||||
{!readonly && (
|
||||
<Tooltip
|
||||
popupContent={(
|
||||
|
||||
@@ -482,12 +482,12 @@ const ConfigModal: FC<IConfigModalProps> = ({
|
||||
|
||||
<div className="!mt-5 flex h-6 items-center space-x-2">
|
||||
<Checkbox checked={tempPayload.required} disabled={tempPayload.hide} onCheck={() => handlePayloadChange('required')(!tempPayload.required)} />
|
||||
<span className="text-text-secondary system-sm-semibold">{t('variableConfig.required', { ns: 'appDebug' })}</span>
|
||||
<span className="system-sm-semibold text-text-secondary">{t('variableConfig.required', { ns: 'appDebug' })}</span>
|
||||
</div>
|
||||
|
||||
<div className="!mt-5 flex h-6 items-center space-x-2">
|
||||
<Checkbox checked={tempPayload.hide} disabled={tempPayload.required} onCheck={() => handlePayloadChange('hide')(!tempPayload.hide)} />
|
||||
<span className="text-text-secondary system-sm-semibold">{t('variableConfig.hide', { ns: 'appDebug' })}</span>
|
||||
<span className="system-sm-semibold text-text-secondary">{t('variableConfig.hide', { ns: 'appDebug' })}</span>
|
||||
</div>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user