Compare commits

...

10 Commits

Author SHA1 Message Date
FFXN
1ce0610c4c feat: Inject "Start" node for snippet before running the whole snippet workflow. 2026-02-14 13:28:30 +08:00
FFXN
b2b0be6b8a Merge remote-tracking branch 'origin/feat/evaluation' into feat/evaluation
# Conflicts:
#	api/controllers/console/snippets/payloads.py
#	api/controllers/console/snippets/snippet_workflow.py
#	api/services/snippet_service.py
2026-02-14 09:55:19 +08:00
FFXN
fb4584b776 feat: Features about running and debugging snippets. 2026-02-14 09:50:34 +08:00
FFXN
632d93f475 feat: Implement the APIs of downloading evaluation dataset template file and downloading evaluation dataset file/evaluation result file. 2026-02-14 09:50:34 +08:00
jyong
36dc948520 evaluation 2026-02-14 09:50:34 +08:00
jyong
bad6fb3470 evaluations 2026-02-14 09:50:34 +08:00
Coding On Star
210710e76d refactor(web): extract custom hooks from complex components and add comprehensive tests (#32301)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-13 17:21:34 +08:00
FFXN
a49504bd5b feat: Implement the APIs of downloading evaluation dataset template file and downloading evaluation dataset file/evaluation result file. 2026-02-12 13:32:43 +08:00
jyong
3dfc797645 evaluation 2026-02-11 16:56:30 +08:00
jyong
bea428e308 evaluations 2026-01-30 17:35:36 +08:00
35 changed files with 5121 additions and 983 deletions

View File

@@ -116,6 +116,12 @@ from .explore import (
trial,
)
# Import evaluation controllers
from .evaluation import evaluation
# Import snippet controllers
from .snippets import snippet_workflow
# Import tag controllers
from .tag import tags
@@ -129,6 +135,7 @@ from .workspace import (
model_providers,
models,
plugin,
snippets,
tool_providers,
trigger_providers,
workspace,
@@ -166,6 +173,7 @@ __all__ = [
"datasource_content_preview",
"email_register",
"endpoint",
"evaluation",
"extension",
"external",
"feature",
@@ -199,6 +207,8 @@ __all__ = [
"saved_message",
"setup",
"site",
"snippet_workflow",
"snippets",
"spec",
"statistic",
"tags",

View File

@@ -0,0 +1 @@
# Evaluation controller module

View File

@@ -0,0 +1,320 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from urllib.parse import quote
from flask import Response, request
from flask_restx import Resource, fields
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.evaluation_service import EvaluationService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}
class VersionQuery(BaseModel):
"""Query parameters for version endpoint."""
version: str
register_schema_models(
console_ns,
VersionQuery,
)
# Response field definitions
file_info_fields = {
"id": fields.String,
"name": fields.String,
}
evaluation_log_fields = {
"created_at": TimestampField,
"created_by": fields.String,
"test_file": fields.Nested(
console_ns.model(
"EvaluationTestFile",
file_info_fields,
)
),
"result_file": fields.Nested(
console_ns.model(
"EvaluationResultFile",
file_info_fields,
),
allow_null=True,
),
"version": fields.String,
}
evaluation_log_list_model = console_ns.model(
"EvaluationLogList",
{
"data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
},
)
customized_matrix_fields = {
"evaluation_workflow_id": fields.String,
"input_fields": fields.Raw,
"output_fields": fields.Raw,
}
condition_fields = {
"name": fields.List(fields.String),
"comparison_operator": fields.String,
"value": fields.String,
}
judgement_conditions_fields = {
"logical_operator": fields.String,
"conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
}
evaluation_detail_fields = {
"evaluation_model": fields.String,
"evaluation_model_provider": fields.String,
"customized_matrix": fields.Nested(
console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
allow_null=True,
),
"judgement_conditions": fields.Nested(
console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
allow_null=True,
),
}
evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)
def get_evaluation_target(view_func: Callable[P, R]):
"""
Decorator to resolve polymorphic evaluation target (app or snippet).
Validates the target_type parameter and fetches the corresponding
model (App or CustomizedSnippet) with tenant isolation.
"""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
target_type = kwargs.get("evaluate_target_type")
target_id = kwargs.get("evaluate_target_id")
if target_type not in EVALUATE_TARGET_TYPES:
raise NotFound(f"Invalid evaluation target type: {target_type}")
_, current_tenant_id = current_account_with_tenant()
target_id = str(target_id)
# Remove path parameters
del kwargs["evaluate_target_type"]
del kwargs["evaluate_target_id"]
target: Union[App, CustomizedSnippet] | None = None
if target_type == "app":
target = (
db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
)
elif target_type == "snippets":
target = (
db.session.query(CustomizedSnippet)
.where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
.first()
)
if not target:
raise NotFound(f"{str(target_type)} not found")
kwargs["target"] = target
kwargs["target_type"] = target_type
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
@console_ns.doc("download_evaluation_dataset_template")
@console_ns.response(200, "Template file streamed as XLSX attachment")
@console_ns.response(400, "Invalid target type or excluded app mode")
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
@edit_permission_required
def post(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Download evaluation dataset template.
Generates an XLSX template based on the target's input parameters
and streams it directly as a file attachment.
"""
try:
xlsx_content, filename = EvaluationService.generate_dataset_template(
target=target,
target_type=target_type,
)
except ValueError as e:
return {"message": str(e)}, 400
encoded_filename = quote(filename)
response = Response(
xlsx_content,
mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
)
response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
response.headers["Content-Length"] = str(len(xlsx_content))
return response
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
@console_ns.doc("get_evaluation_detail")
@console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation details for the target.
Returns evaluation configuration including model settings,
customized matrix, and judgement conditions.
"""
# TODO: Implement actual evaluation detail retrieval
# This is a placeholder implementation
return {
"evaluation_model": None,
"evaluation_model_provider": None,
"customized_matrix": None,
"judgement_conditions": None,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
@console_ns.doc("get_evaluation_logs")
@console_ns.response(200, "Evaluation logs retrieved successfully", evaluation_log_list_model)
@console_ns.response(404, "Target not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get offline evaluation logs for the target.
Returns a list of evaluation runs with test files,
result files, and version information.
"""
# TODO: Implement actual evaluation logs retrieval
# This is a placeholder implementation
return {
"data": [],
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
@console_ns.doc("download_evaluation_file")
@console_ns.response(200, "File download URL generated successfully")
@console_ns.response(404, "Target or file not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
"""
Download evaluation test file or result file.
Looks up the specified file, verifies it belongs to the same tenant,
and returns file info and download URL.
"""
file_id = str(file_id)
_, current_tenant_id = current_account_with_tenant()
with Session(db.engine, expire_on_commit=False) as session:
stmt = select(UploadFile).where(
UploadFile.id == file_id,
UploadFile.tenant_id == current_tenant_id,
)
upload_file = session.execute(stmt).scalar_one_or_none()
if not upload_file:
raise NotFound("File not found")
download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)
return {
"id": upload_file.id,
"name": upload_file.name,
"size": upload_file.size,
"extension": upload_file.extension,
"mime_type": upload_file.mime_type,
"created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
"download_url": download_url,
}
@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
@console_ns.doc("get_evaluation_version_detail")
@console_ns.expect(console_ns.models.get(VersionQuery.__name__))
@console_ns.response(200, "Version details retrieved successfully")
@console_ns.response(404, "Target or version not found")
@setup_required
@login_required
@account_initialization_required
@get_evaluation_target
def get(self, target: Union[App, CustomizedSnippet], target_type: str):
"""
Get evaluation target version details.
Returns the workflow graph for the specified version.
"""
version = request.args.get("version")
if not version:
return {"message": "version parameter is required"}, 400
# TODO: Implement actual version detail retrieval
# For now, return the current graph if available
graph = {}
if target_type == "snippets" and isinstance(target, CustomizedSnippet):
graph = target.graph_dict
return {
"graph": graph,
}

View File

@@ -0,0 +1,102 @@
from typing import Any, Literal
from pydantic import BaseModel, Field
class SnippetListQuery(BaseModel):
"""Query parameters for listing snippets."""
page: int = Field(default=1, ge=1, le=99999)
limit: int = Field(default=20, ge=1, le=100)
keyword: str | None = None
class IconInfo(BaseModel):
"""Icon information model."""
icon: str | None = None
icon_type: Literal["emoji", "image"] | None = None
icon_background: str | None = None
icon_url: str | None = None
class InputFieldDefinition(BaseModel):
"""Input field definition for snippet parameters."""
default: str | None = None
hint: bool | None = None
label: str | None = None
max_length: int | None = None
options: list[str] | None = None
placeholder: str | None = None
required: bool | None = None
type: str | None = None # e.g., "text-input"
class CreateSnippetPayload(BaseModel):
"""Payload for creating a new snippet."""
name: str = Field(..., min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
type: Literal["node", "group"] = "node"
icon_info: IconInfo | None = None
graph: dict[str, Any] | None = None
input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)
class UpdateSnippetPayload(BaseModel):
"""Payload for updating a snippet."""
name: str | None = Field(default=None, min_length=1, max_length=255)
description: str | None = Field(default=None, max_length=2000)
icon_info: IconInfo | None = None
class SnippetDraftSyncPayload(BaseModel):
"""Payload for syncing snippet draft workflow."""
graph: dict[str, Any]
hash: str | None = None
environment_variables: list[dict[str, Any]] | None = None
conversation_variables: list[dict[str, Any]] | None = None
input_variables: list[dict[str, Any]] | None = None
class WorkflowRunQuery(BaseModel):
"""Query parameters for workflow runs."""
last_id: str | None = None
limit: int = Field(default=20, ge=1, le=100)
class SnippetDraftRunPayload(BaseModel):
"""Payload for running snippet draft workflow."""
inputs: dict[str, Any]
files: list[dict[str, Any]] | None = None
class SnippetDraftNodeRunPayload(BaseModel):
"""Payload for running a single node in snippet draft workflow."""
inputs: dict[str, Any]
query: str = ""
files: list[dict[str, Any]] | None = None
class SnippetIterationNodeRunPayload(BaseModel):
"""Payload for running an iteration node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class SnippetLoopNodeRunPayload(BaseModel):
"""Payload for running a loop node in snippet draft workflow."""
inputs: dict[str, Any] | None = None
class PublishWorkflowPayload(BaseModel):
"""Payload for publishing snippet workflow."""
knowledge_base_setting: dict[str, Any] | None = None

View File

@@ -0,0 +1,540 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar
from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
workflow_run_detail_model,
workflow_run_node_execution_list_model,
workflow_run_node_execution_model,
workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
PublishWorkflowPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetDraftSyncPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
P = ParamSpec("P")
R = TypeVar("R")
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetDraftSyncPayload,
SnippetDraftNodeRunPayload,
SnippetDraftRunPayload,
SnippetIterationNodeRunPayload,
SnippetLoopNodeRunPayload,
WorkflowRunQuery,
PublishWorkflowPayload,
)
class SnippetNotFoundError(Exception):
"""Snippet not found error."""
pass
def get_snippet(view_func: Callable[P, R]):
"""Decorator to fetch and validate snippet access."""
@wraps(view_func)
def decorated_view(*args: P.args, **kwargs: P.kwargs):
if not kwargs.get("snippet_id"):
raise ValueError("missing snippet_id in path parameters")
_, current_tenant_id = current_account_with_tenant()
snippet_id = str(kwargs.get("snippet_id"))
del kwargs["snippet_id"]
snippet = SnippetService.get_snippet_by_id(
snippet_id=snippet_id,
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
kwargs["snippet"] = snippet
return view_func(*args, **kwargs)
return decorated_view
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
@console_ns.doc("get_snippet_draft_workflow")
@console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get draft workflow for snippet."""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise DraftWorkflowNotExist()
return workflow
@console_ns.doc("sync_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
@console_ns.response(200, "Draft workflow synced successfully")
@console_ns.response(400, "Hash mismatch")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Sync draft workflow for snippet."""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})
try:
environment_variables_list = payload.environment_variables or []
environment_variables = [
variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
]
conversation_variables_list = payload.conversation_variables or []
conversation_variables = [
variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
]
snippet_service = SnippetService()
workflow = snippet_service.sync_draft_workflow(
snippet=snippet,
graph=payload.graph,
unique_hash=payload.hash,
account=current_user,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
input_variables=payload.input_variables,
)
except WorkflowHashNotEqualError:
raise DraftWorkflowNotSync()
return {
"result": "success",
"hash": workflow.unique_hash,
"updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
@console_ns.doc("get_snippet_draft_config")
@console_ns.response(200, "Draft config retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get snippet draft workflow configuration limits."""
return {
"parallel_depth_limit": 3,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
@console_ns.doc("get_snippet_published_workflow")
@console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
@marshal_with(workflow_model)
def get(self, snippet: CustomizedSnippet):
"""Get published workflow for snippet."""
if not snippet.is_published:
return None
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
return workflow
@console_ns.doc("publish_snippet_workflow")
@console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
@console_ns.response(200, "Workflow published successfully")
@console_ns.response(400, "No draft workflow found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""Publish snippet workflow."""
current_user, _ = current_account_with_tenant()
snippet_service = SnippetService()
with Session(db.engine) as session:
snippet = session.merge(snippet)
try:
workflow = snippet_service.publish_workflow(
session=session,
snippet=snippet,
account=current_user,
)
workflow_created_at = TimestampField().format(workflow.created_at)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return {
"result": "success",
"created_at": workflow_created_at,
}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
@console_ns.doc("get_snippet_default_block_configs")
@console_ns.response(200, "Default block configs retrieved successfully")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def get(self, snippet: CustomizedSnippet):
"""Get default block configurations for snippet workflow."""
snippet_service = SnippetService()
return snippet_service.get_default_block_configs()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
@console_ns.doc("list_snippet_workflow_runs")
@console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_pagination_model)
def get(self, snippet: CustomizedSnippet):
"""List workflow runs for snippet."""
query = WorkflowRunQuery.model_validate(
{
"last_id": request.args.get("last_id"),
"limit": request.args.get("limit", type=int, default=20),
}
)
args = {
"last_id": query.last_id,
"limit": query.limit,
}
snippet_service = SnippetService()
result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)
return result
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
@console_ns.doc("get_snippet_workflow_run_detail")
@console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
@console_ns.response(404, "Workflow run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_detail_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""Get workflow run detail for snippet."""
run_id = str(run_id)
snippet_service = SnippetService()
workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
raise NotFound("Workflow run not found")
return workflow_run
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
@console_ns.doc("list_snippet_workflow_run_node_executions")
@console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_list_model)
def get(self, snippet: CustomizedSnippet, run_id):
"""List node executions for a workflow run."""
run_id = str(run_id)
snippet_service = SnippetService()
node_executions = snippet_service.get_snippet_workflow_run_node_executions(
snippet=snippet,
run_id=run_id,
)
return {"data": node_executions}
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
@console_ns.doc("run_snippet_draft_node")
@console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
@console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a single node in snippet draft workflow.
Executes a specific node with provided inputs for single-step debugging.
Returns the node execution result including status, outputs, and timing.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})
user_inputs = payload.inputs
# Get draft workflow for file parsing
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
files = SnippetGenerateService.parse_files(draft_workflow, payload.files)
workflow_node_execution = SnippetGenerateService.run_draft_node(
snippet=snippet,
node_id=node_id,
user_inputs=user_inputs,
account=current_user,
query=payload.query,
files=files,
)
return workflow_node_execution
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
@console_ns.doc("get_snippet_draft_node_last_run")
@console_ns.doc(description="Get last run result for a node in snippet draft workflow")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
@console_ns.response(404, "Snippet, draft workflow, or node last run not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@marshal_with(workflow_run_node_execution_model)
def get(self, snippet: CustomizedSnippet, node_id: str):
"""
Get the last run result for a specific node in snippet draft workflow.
Returns the most recent execution record for the given node,
including status, inputs, outputs, and timing information.
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise NotFound("Draft workflow not found")
node_exec = snippet_service.get_snippet_node_last_run(
snippet=snippet,
workflow=draft_workflow,
node_id=node_id,
)
if node_exec is None:
raise NotFound("Node last run not found")
return node_exec
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
@console_ns.doc("run_snippet_draft_iteration_node")
@console_ns.doc(description="Run draft workflow iteration node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
@console_ns.response(200, "Iteration node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow iteration node for snippet.
Iteration nodes execute their internal sub-graph multiple times over an input list.
Returns an SSE event stream with iteration progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate_single_iteration(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
@console_ns.doc("run_snippet_draft_loop_node")
@console_ns.doc(description="Run draft workflow loop node for snippet")
@console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
@console_ns.response(200, "Loop node run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, node_id: str):
"""
Run a draft workflow loop node for snippet.
Loop nodes execute their internal sub-graph repeatedly until a condition is met.
Returns an SSE event stream with loop progress and results.
"""
current_user, _ = current_account_with_tenant()
args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})
try:
response = SnippetGenerateService.generate_single_loop(
snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
@console_ns.doc("run_snippet_draft_workflow")
@console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
@console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
@console_ns.response(404, "Snippet or draft workflow not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet):
"""
Run draft workflow for snippet.
Executes the snippet's draft workflow with the provided inputs
and returns an SSE event stream with execution progress and results.
"""
current_user, _ = current_account_with_tenant()
payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
args = payload.model_dump(exclude_none=True)
try:
response = SnippetGenerateService.generate(
snippet=snippet,
user=current_user,
args=args,
invoke_from=InvokeFrom.DEBUGGER,
streaming=True,
)
return helper.compact_generate_response(response)
except ValueError as e:
raise e
except Exception:
logger.exception("internal server error.")
raise InternalServerError()
@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
@console_ns.doc("stop_snippet_workflow_task")
@console_ns.response(200, "Task stopped successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@get_snippet
@edit_permission_required
def post(self, snippet: CustomizedSnippet, task_id: str):
"""
Stop a running snippet workflow task.
Uses both the legacy stop flag mechanism and the graph engine
command channel for backward compatibility.
"""
# Stop using both mechanisms for backward compatibility
# Legacy stop flag mechanism (without user check)
AppQueueManager.set_stop_flag_no_user_check(task_id)
# New graph engine command channel mechanism
GraphEngineManager.send_stop_command(task_id)
return {"result": "success"}

View File

@@ -0,0 +1,202 @@
import logging
from flask import request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound
from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
CreateSnippetPayload,
SnippetListQuery,
UpdateSnippetPayload,
)
from controllers.console.wraps import (
account_initialization_required,
edit_permission_required,
setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.snippet_service import SnippetService
logger = logging.getLogger(__name__)
# Register Pydantic models with Swagger
register_schema_models(
console_ns,
SnippetListQuery,
CreateSnippetPayload,
UpdateSnippetPayload,
)
# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)
@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
@console_ns.doc("list_customized_snippets")
@console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
@console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
@setup_required
@login_required
@account_initialization_required
def get(self):
"""List customized snippets with pagination and search."""
_, current_tenant_id = current_account_with_tenant()
query_params = request.args.to_dict()
query = SnippetListQuery.model_validate(query_params)
snippets, total, has_more = SnippetService.get_snippets(
tenant_id=current_tenant_id,
page=query.page,
limit=query.limit,
keyword=query.keyword,
)
return {
"data": marshal(snippets, snippet_list_fields),
"page": query.page,
"limit": query.limit,
"total": total,
"has_more": has_more,
}, 200
@console_ns.doc("create_customized_snippet")
@console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
@console_ns.response(201, "Snippet created successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def post(self):
"""Create a new customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
payload = CreateSnippetPayload.model_validate(console_ns.payload or {})
try:
snippet_type = SnippetType(payload.type)
except ValueError:
snippet_type = SnippetType.NODE
try:
snippet = SnippetService.create_snippet(
tenant_id=current_tenant_id,
name=payload.name,
description=payload.description,
snippet_type=snippet_type,
icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
graph=payload.graph,
input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
account=current_user,
)
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 201
@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
@console_ns.doc("get_customized_snippet")
@console_ns.response(200, "Snippet retrieved successfully", snippet_model)
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
def get(self, snippet_id: str):
"""Get customized snippet details."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
return marshal(snippet, snippet_fields), 200
@console_ns.doc("update_customized_snippet")
@console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
@console_ns.response(200, "Snippet updated successfully", snippet_model)
@console_ns.response(400, "Invalid request or name already exists")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def patch(self, snippet_id: str):
"""Update customized snippet."""
current_user, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
update_data = payload.model_dump(exclude_unset=True)
if "icon_info" in update_data and update_data["icon_info"] is not None:
update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None
if not update_data:
return {"message": "No valid fields to update"}, 400
try:
with Session(db.engine, expire_on_commit=False) as session:
snippet = session.merge(snippet)
snippet = SnippetService.update_snippet(
session=session,
snippet=snippet,
account_id=current_user.id,
data=update_data,
)
session.commit()
except ValueError as e:
return {"message": str(e)}, 400
return marshal(snippet, snippet_fields), 200
@console_ns.doc("delete_customized_snippet")
@console_ns.response(204, "Snippet deleted successfully")
@console_ns.response(404, "Snippet not found")
@setup_required
@login_required
@account_initialization_required
@edit_permission_required
def delete(self, snippet_id: str):
"""Delete customized snippet."""
_, current_tenant_id = current_account_with_tenant()
snippet = SnippetService.get_snippet_by_id(
snippet_id=str(snippet_id),
tenant_id=current_tenant_id,
)
if not snippet:
raise NotFound("Snippet not found")
with Session(db.engine) as session:
snippet = session.merge(snippet)
SnippetService.delete_snippet(
session=session,
snippet=snippet,
)
session.commit()
return "", 204

View File

@@ -0,0 +1,45 @@
from flask_restx import fields
from fields.member_fields import simple_account_fields
from libs.helper import TimestampField
# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"created_at": TimestampField,
"updated_at": TimestampField,
}
# Full snippet fields (includes creator info and graph data)
snippet_fields = {
"id": fields.String,
"name": fields.String,
"description": fields.String,
"type": fields.String,
"version": fields.Integer,
"use_count": fields.Integer,
"is_published": fields.Boolean,
"icon_info": fields.Raw,
"graph": fields.Raw(attribute="graph_dict"),
"input_fields": fields.Raw(attribute="input_fields_list"),
"created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
"created_at": TimestampField,
"updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
"updated_at": TimestampField,
}
# Pagination response fields
snippet_pagination_fields = {
"data": fields.List(fields.Nested(snippet_list_fields)),
"page": fields.Integer,
"limit": fields.Integer,
"total": fields.Integer,
"has_more": fields.Boolean,
}

View File

@@ -0,0 +1,83 @@
"""add_customized_snippets_table
Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000
"""
import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql
import models as models
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None
def upgrade():
conn = op.get_bind()
if _is_pg(conn):
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", sa.Text(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
sa.Column("graph", sa.Text(), nullable=True),
sa.Column("input_fields", sa.Text(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
else:
op.create_table(
"customized_snippets",
sa.Column("id", models.types.StringUUID(), nullable=False),
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
sa.Column("name", sa.String(length=255), nullable=False),
sa.Column("description", models.types.LongText(), nullable=True),
sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
sa.Column("graph", models.types.LongText(), nullable=True),
sa.Column("input_fields", models.types.LongText(), nullable=True),
sa.Column("created_by", models.types.StringUUID(), nullable=True),
sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column("updated_by", models.types.StringUUID(), nullable=True),
sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)
def downgrade():
with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
batch_op.drop_index("customized_snippet_tenant_idx")
op.drop_table("customized_snippets")

View File

@@ -81,6 +81,7 @@ from .provider import (
TenantDefaultModel,
TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@@ -140,6 +141,7 @@ __all__ = [
"Conversation",
"ConversationVariable",
"CreatorUserRole",
"CustomizedSnippet",
"DataSourceApiKeyAuthBinding",
"DataSourceOauthBinding",
"Dataset",
@@ -184,6 +186,7 @@ __all__ = [
"RecommendedApp",
"SavedMessage",
"Site",
"SnippetType",
"Tag",
"TagBinding",
"Tenant",

101
api/models/snippet.py Normal file
View File

@@ -0,0 +1,101 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any
import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column
from libs.uuid_utils import uuidv7
from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID
class SnippetType(StrEnum):
"""Snippet Type Enum"""
NODE = "node"
GROUP = "group"
class CustomizedSnippet(Base):
"""
Customized Snippet Model
Stores reusable workflow components (nodes or node groups) that can be
shared across applications within a workspace.
"""
__tablename__ = "customized_snippets"
__table_args__ = (
sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
sa.Index("customized_snippet_tenant_idx", "tenant_id"),
sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
name: Mapped[str] = mapped_column(String(255), nullable=False)
description: Mapped[str | None] = mapped_column(LongText, nullable=True)
type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))
# Workflow reference for published version
workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
# State flags
is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))
# Visual customization
icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)
# Snippet configuration (stored as JSON text)
input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)
# Audit fields
created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
updated_at: Mapped[datetime] = mapped_column(
DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
)
@property
def graph_dict(self) -> dict[str, Any]:
"""Get graph from associated workflow."""
if self.workflow_id:
from .workflow import Workflow
workflow = db.session.get(Workflow, self.workflow_id)
if workflow:
return json.loads(workflow.graph) if workflow.graph else {}
return {}
@property
def input_fields_list(self) -> list[dict[str, Any]]:
"""Parse input_fields JSON to list."""
return json.loads(self.input_fields) if self.input_fields else []
@property
def created_by_account(self) -> Account | None:
"""Get the account that created this snippet."""
if self.created_by:
return db.session.get(Account, self.created_by)
return None
@property
def updated_by_account(self) -> Account | None:
"""Get the account that last updated this snippet."""
if self.updated_by:
return db.session.get(Account, self.updated_by)
return None
@property
def version_str(self) -> str:
"""Get version as string for API response."""
return str(self.version)

View File

@@ -67,6 +67,7 @@ class WorkflowType(StrEnum):
WORKFLOW = "workflow"
CHAT = "chat"
RAG_PIPELINE = "rag-pipeline"
SNIPPET = "snippet"
@classmethod
def value_of(cls, value: str) -> "WorkflowType":

View File

@@ -0,0 +1,178 @@
import io
import logging
from typing import Union
from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter
from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class EvaluationService:
"""
Service for evaluation-related operations.
Provides functionality to generate evaluation dataset templates
based on App or Snippet input parameters.
"""
# Excluded app modes that don't support evaluation templates
EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}
@classmethod
def generate_dataset_template(
cls,
target: Union[App, CustomizedSnippet],
target_type: str,
) -> tuple[bytes, str]:
"""
Generate evaluation dataset template as XLSX bytes.
Creates an XLSX file with headers based on the evaluation target's input parameters.
The first column is index, followed by input parameter columns.
:param target: App or CustomizedSnippet instance
:param target_type: Target type string ("app" or "snippet")
:return: Tuple of (xlsx_content_bytes, filename)
:raises ValueError: If target type is not supported or app mode is excluded
"""
# Validate target type
if target_type == "app":
if not isinstance(target, App):
raise ValueError("Invalid target: expected App instance")
if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
input_fields = cls._get_app_input_fields(target)
elif target_type == "snippet":
if not isinstance(target, CustomizedSnippet):
raise ValueError("Invalid target: expected CustomizedSnippet instance")
input_fields = cls._get_snippet_input_fields(target)
else:
raise ValueError(f"Unsupported target type: {target_type}")
# Generate XLSX template
xlsx_content = cls._generate_xlsx_template(input_fields, target.name)
# Build filename
truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
filename = f"{truncated_name}-evaluation-dataset.xlsx"
return xlsx_content, filename
@classmethod
def _get_app_input_fields(cls, app: App) -> list[dict]:
"""
Get input fields from App's workflow.
:param app: App instance
:return: List of input field definitions
"""
workflow_service = WorkflowService()
workflow = workflow_service.get_published_workflow(app_model=app)
if not workflow:
workflow = workflow_service.get_draft_workflow(app_model=app)
if not workflow:
return []
# Get user input form from workflow
user_input_form = workflow.user_input_form()
return user_input_form
@classmethod
def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
"""
Get input fields from Snippet.
Tries to get from snippet's own input_fields first,
then falls back to workflow's user_input_form.
:param snippet: CustomizedSnippet instance
:return: List of input field definitions
"""
# Try snippet's own input_fields first
input_fields = snippet.input_fields_list
if input_fields:
return input_fields
# Fallback to workflow's user_input_form
snippet_service = SnippetService()
workflow = snippet_service.get_published_workflow(snippet=snippet)
if not workflow:
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if workflow:
return workflow.user_input_form()
return []
@classmethod
def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
"""
Generate XLSX template file content.
Creates a workbook with:
- First row as header row with "index" and input field names
- Styled header with background color and borders
- Empty data rows ready for user input
:param input_fields: List of input field definitions
:param target_name: Name of the target (for sheet name)
:return: XLSX file content as bytes
"""
wb = Workbook()
ws = wb.active
sheet_name = "Evaluation Dataset"
ws.title = sheet_name
header_font = Font(bold=True, color="FFFFFF")
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
header_alignment = Alignment(horizontal="center", vertical="center")
thin_border = Border(
left=Side(style="thin"),
right=Side(style="thin"),
top=Side(style="thin"),
bottom=Side(style="thin"),
)
# Build header row
headers = ["index"]
for field in input_fields:
field_label = field.get("label") or field.get("variable")
headers.append(field_label)
# Write header row
for col_idx, header in enumerate(headers, start=1):
cell = ws.cell(row=1, column=col_idx, value=header)
cell.font = header_font
cell.fill = header_fill
cell.alignment = header_alignment
cell.border = thin_border
# Set column widths
ws.column_dimensions["A"].width = 10 # index column
for col_idx in range(2, len(headers) + 1):
ws.column_dimensions[get_column_letter(col_idx)].width = 20
# Add one empty row with row number for user reference
for col_idx in range(1, len(headers) + 1):
cell = ws.cell(row=2, column=col_idx, value="")
cell.border = thin_border
if col_idx == 1:
cell.value = 1
cell.alignment = Alignment(horizontal="center")
# Save to bytes
output = io.BytesIO()
wb.save(output)
output.seek(0)
return output.getvalue()

View File

@@ -0,0 +1,374 @@
"""
Service for generating snippet workflow executions.
Uses an adapter pattern to bridge CustomizedSnippet with the App-based
WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like
interface needed by the generator, avoiding modifications to core workflow
infrastructure.
Key invariants:
- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT).
- The adapter maps snippet.id to app_id in workflow execution records.
- Snippet debugging has no rate limiting (max_active_requests = 0).
Supported execution modes:
- Full workflow run (generate): Runs the entire draft workflow as SSE stream.
- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes.
- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes.
- Single loop run (generate_single_loop): SSE stream for loop container nodes.
"""
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import Any, Union
from sqlalchemy.orm import make_transient
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.file.models import File
from factories import file_factory
from models import Account
from models.model import AppMode, EndUser
from models.snippet import CustomizedSnippet
from models.workflow import Workflow, WorkflowNodeExecutionModel
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService
logger = logging.getLogger(__name__)
class _SnippetAsApp:
"""
Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like
interface required by WorkflowAppGenerator, WorkflowAppConfigManager,
and WorkflowService.run_draft_workflow_node.
Used properties:
- id: maps to snippet.id (stored as app_id in workflows table)
- tenant_id: maps to snippet.tenant_id
- mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows
- max_active_requests: defaults to 0 (no limit) for snippet debugging
- app_model_config_id: None (snippets don't have app model configs)
"""
id: str
tenant_id: str
mode: str
max_active_requests: int
app_model_config_id: str | None
def __init__(self, snippet: CustomizedSnippet) -> None:
self.id = snippet.id
self.tenant_id = snippet.tenant_id
self.mode = AppMode.WORKFLOW.value
self.max_active_requests = 0
self.app_model_config_id = None
class SnippetGenerateService:
"""
Service for running snippet workflow executions.
Adapts CustomizedSnippet to work with the existing App-based
WorkflowAppGenerator infrastructure, avoiding duplication of the
complex workflow execution pipeline.
"""
# Specific ID for the injected virtual Start node so it can be recognised
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
@classmethod
def generate(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a snippet's draft workflow.
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
then delegates execution to WorkflowAppGenerator.
If the workflow graph has no Start node, a virtual Start node is injected
in-memory so that:
1. Graph validation passes (root node must have execution_type=ROOT).
2. User inputs are processed into the variable pool by the StartNode logic.
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param args: Workflow inputs (must include "inputs" key)
:param invoke_from: Source of invocation (typically DEBUGGER)
:param streaming: Whether to stream the response
:return: Blocking response mapping or SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
# Inject a virtual Start node when the graph doesn't have one.
workflow = cls._ensure_start_node(workflow, snippet)
# Adapt snippet to App-like interface for WorkflowAppGenerator
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
user=user,
args=args,
invoke_from=invoke_from,
streaming=streaming,
call_depth=0,
)
)
@classmethod
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
"""
Return *workflow* with a Start node.
If the graph already contains a Start node, the original workflow is
returned unchanged. Otherwise a virtual Start node is injected and the
workflow object is detached from the SQLAlchemy session so the in-memory
change is never flushed to the database.
"""
graph_dict = workflow.graph_dict
nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
if has_start:
return workflow
modified_graph = cls._inject_virtual_start_node(
graph_dict=graph_dict,
input_fields=snippet.input_fields_list,
)
# Detach from session to prevent accidental DB persistence of the
# modified graph. All attributes remain accessible for read.
make_transient(workflow)
workflow.graph = json.dumps(modified_graph)
return workflow
@classmethod
def _inject_virtual_start_node(
cls,
graph_dict: Mapping[str, Any],
input_fields: list[dict[str, Any]],
) -> dict[str, Any]:
"""
Build a new graph dict with a virtual Start node prepended.
The virtual Start node is wired to every existing node that has no
incoming edges (i.e. the current root candidates). This guarantees:
:param graph_dict: Original graph configuration.
:param input_fields: Snippet input field definitions from
``CustomizedSnippet.input_fields_list``.
:return: New graph dict containing the virtual Start node and edges.
"""
nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
# Identify nodes with no incoming edges.
nodes_with_incoming: set[str] = set()
for edge in edges:
target = edge.get("target")
if isinstance(target, str):
nodes_with_incoming.add(target)
root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
# Build Start node ``variables`` from snippet input fields.
start_variables: list[dict[str, Any]] = []
for field in input_fields:
var: dict[str, Any] = {
"variable": field.get("variable", ""),
"label": field.get("label", field.get("variable", "")),
"type": field.get("type", "text-input"),
"required": field.get("required", False),
"options": field.get("options", []),
}
if field.get("max_length") is not None:
var["max_length"] = field["max_length"]
start_variables.append(var)
virtual_start_node: dict[str, Any] = {
"id": cls._VIRTUAL_START_NODE_ID,
"data": {
"type": "start",
"title": "Start",
"variables": start_variables,
},
}
# Create edges from virtual Start to each root candidate.
new_edges: list[dict[str, Any]] = [
{
"source": cls._VIRTUAL_START_NODE_ID,
"sourceHandle": "source",
"target": root_id,
"targetHandle": "target",
}
for root_id in root_candidate_ids
]
return {
**graph_dict,
"nodes": [virtual_start_node, *nodes],
"edges": [*edges, *new_edges],
}
@classmethod
def run_draft_node(
cls,
snippet: CustomizedSnippet,
node_id: str,
user_inputs: Mapping[str, Any],
account: Account,
query: str = "",
files: Sequence[File] | None = None,
) -> WorkflowNodeExecutionModel:
"""
Run a single node in a snippet's draft workflow (single-step debugging).
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
parses file inputs, then delegates to WorkflowService.run_draft_workflow_node.
:param snippet: CustomizedSnippet instance
:param node_id: ID of the node to run
:param user_inputs: User input values for the node
:param account: Account initiating the run
:param query: Optional query string
:param files: Optional parsed file objects
:return: WorkflowNodeExecutionModel with execution results
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not draft_workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
workflow_service = WorkflowService()
return workflow_service.run_draft_workflow_node(
app_model=app_proxy, # type: ignore[arg-type]
draft_workflow=draft_workflow,
node_id=node_id,
user_inputs=user_inputs,
account=account,
query=query,
files=files,
)
@classmethod
def generate_single_iteration(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Mapping[str, Any],
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single iteration node in a snippet's draft workflow.
Iteration nodes are container nodes that execute their sub-graph multiple
times, producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the iteration node to run
:param args: Dict containing 'inputs' key with iteration input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_iteration_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args,
streaming=streaming,
)
)
@classmethod
def generate_single_loop(
cls,
snippet: CustomizedSnippet,
user: Union[Account, EndUser],
node_id: str,
args: Any,
streaming: bool = True,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
"""
Run a single loop node in a snippet's draft workflow.
Loop nodes are container nodes that execute their sub-graph repeatedly,
producing many events. Therefore, this uses the full WorkflowAppGenerator
pipeline with SSE streaming (unlike regular single-step node run).
:param snippet: CustomizedSnippet instance
:param user: Account or EndUser initiating the run
:param node_id: ID of the loop node to run
:param args: Pydantic model with 'inputs' attribute containing loop input data
:param streaming: Whether to stream the response (should be True)
:return: SSE streaming generator
:raises ValueError: If the snippet has no draft workflow
"""
snippet_service = SnippetService()
workflow = snippet_service.get_draft_workflow(snippet=snippet)
if not workflow:
raise ValueError("Workflow not initialized")
app_proxy = _SnippetAsApp(snippet)
return WorkflowAppGenerator.convert_to_event_stream(
WorkflowAppGenerator().single_loop_generate(
app_model=app_proxy, # type: ignore[arg-type]
workflow=workflow,
node_id=node_id,
user=user,
args=args, # type: ignore[arg-type]
streaming=streaming,
)
)
@staticmethod
def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
"""
Parse file mappings into File objects based on workflow configuration.
:param workflow: Workflow instance for file upload config
:param files: Raw file mapping dicts
:return: Parsed File objects
"""
files = files or []
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
if file_extra_config is None:
return []
return file_factory.build_from_mappings(
mappings=files,
tenant_id=workflow.tenant_id,
config=file_extra_config,
)

View File

@@ -0,0 +1,566 @@
import json
import logging
from collections.abc import Mapping, Sequence
from datetime import UTC, datetime
from typing import Any
from sqlalchemy import func, select
from sqlalchemy.orm import Session, sessionmaker
from core.variables.variables import VariableBase
from core.workflow.enums import NodeType
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import WorkflowRunTriggeredFrom
from models.snippet import CustomizedSnippet, SnippetType
from models.workflow import (
Workflow,
WorkflowNodeExecutionModel,
WorkflowRun,
WorkflowType,
)
from repositories.factory import DifyAPIRepositoryFactory
from services.errors.app import WorkflowHashNotEqualError
logger = logging.getLogger(__name__)
class SnippetService:
"""Service for managing customized snippets."""
def __init__(self, session_maker: sessionmaker | None = None):
"""Initialize SnippetService with repository dependencies."""
if session_maker is None:
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
session_maker
)
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
# --- CRUD Operations ---
@staticmethod
def get_snippets(
*,
tenant_id: str,
page: int = 1,
limit: int = 20,
keyword: str | None = None,
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
"""
Get paginated list of snippets with optional search.
:param tenant_id: Tenant ID
:param page: Page number (1-indexed)
:param limit: Number of items per page
:param keyword: Optional search keyword for name/description
:return: Tuple of (snippets list, total count, has_more flag)
"""
stmt = (
select(CustomizedSnippet)
.where(CustomizedSnippet.tenant_id == tenant_id)
.order_by(CustomizedSnippet.created_at.desc())
)
if keyword:
stmt = stmt.where(
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
)
# Get total count
count_stmt = select(func.count()).select_from(stmt.subquery())
total = db.session.scalar(count_stmt) or 0
# Apply pagination
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
snippets = list(db.session.scalars(stmt).all())
has_more = len(snippets) > limit
if has_more:
snippets = snippets[:-1]
return snippets, total, has_more
@staticmethod
def get_snippet_by_id(
*,
snippet_id: str,
tenant_id: str,
) -> CustomizedSnippet | None:
"""
Get snippet by ID with tenant isolation.
:param snippet_id: Snippet ID
:param tenant_id: Tenant ID
:return: CustomizedSnippet or None
"""
return (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.id == snippet_id,
CustomizedSnippet.tenant_id == tenant_id,
)
.first()
)
@staticmethod
def create_snippet(
*,
tenant_id: str,
name: str,
description: str | None,
snippet_type: SnippetType,
icon_info: dict | None,
graph: dict | None,
input_fields: list[dict] | None,
account: Account,
) -> CustomizedSnippet:
"""
Create a new snippet.
:param tenant_id: Tenant ID
:param name: Snippet name (must be unique per tenant)
:param description: Snippet description
:param snippet_type: Type of snippet (node or group)
:param icon_info: Icon information
:param graph: Workflow graph structure
:param input_fields: Input field definitions
:param account: Creator account
:return: Created CustomizedSnippet
:raises ValueError: If name already exists
"""
# Check if name already exists for this tenant
existing = (
db.session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == tenant_id,
CustomizedSnippet.name == name,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{name}' already exists")
snippet = CustomizedSnippet(
tenant_id=tenant_id,
name=name,
description=description or "",
type=snippet_type.value,
icon_info=icon_info,
graph=json.dumps(graph) if graph else None,
input_fields=json.dumps(input_fields) if input_fields else None,
created_by=account.id,
)
db.session.add(snippet)
db.session.commit()
return snippet
@staticmethod
def update_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
account_id: str,
data: dict,
) -> CustomizedSnippet:
"""
Update snippet attributes.
:param session: Database session
:param snippet: Snippet to update
:param account_id: ID of account making the update
:param data: Dictionary of fields to update
:return: Updated CustomizedSnippet
"""
if "name" in data:
# Check if new name already exists for this tenant
existing = (
session.query(CustomizedSnippet)
.where(
CustomizedSnippet.tenant_id == snippet.tenant_id,
CustomizedSnippet.name == data["name"],
CustomizedSnippet.id != snippet.id,
)
.first()
)
if existing:
raise ValueError(f"Snippet with name '{data['name']}' already exists")
snippet.name = data["name"]
if "description" in data:
snippet.description = data["description"]
if "icon_info" in data:
snippet.icon_info = data["icon_info"]
snippet.updated_by = account_id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
session.add(snippet)
return snippet
@staticmethod
def delete_snippet(
*,
session: Session,
snippet: CustomizedSnippet,
) -> bool:
"""
Delete a snippet.
:param session: Database session
:param snippet: Snippet to delete
:return: True if deleted successfully
"""
session.delete(snippet)
return True
# --- Workflow Operations ---
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get draft workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Draft Workflow or None
"""
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
.first()
)
return workflow
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
"""
Get published workflow for snippet.
:param snippet: CustomizedSnippet instance
:return: Published Workflow or None
"""
if not snippet.workflow_id:
return None
workflow = (
db.session.query(Workflow)
.where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.id == snippet.workflow_id,
)
.first()
)
return workflow
def sync_draft_workflow(
self,
*,
snippet: CustomizedSnippet,
graph: dict,
unique_hash: str | None,
account: Account,
environment_variables: Sequence[VariableBase],
conversation_variables: Sequence[VariableBase],
input_variables: list[dict] | None = None,
) -> Workflow:
"""
Sync draft workflow for snippet.
:param snippet: CustomizedSnippet instance
:param graph: Workflow graph configuration
:param unique_hash: Hash for conflict detection
:param account: Account making the change
:param environment_variables: Environment variables
:param conversation_variables: Conversation variables
:param input_variables: Input variables for snippet
:return: Synced Workflow
:raises WorkflowHashNotEqualError: If hash mismatch
"""
workflow = self.get_draft_workflow(snippet=snippet)
if workflow and workflow.unique_hash != unique_hash:
raise WorkflowHashNotEqualError()
# Create draft workflow if not found
if not workflow:
workflow = Workflow(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
features="{}",
type=WorkflowType.SNIPPET.value,
version="draft",
graph=json.dumps(graph),
created_by=account.id,
environment_variables=environment_variables,
conversation_variables=conversation_variables,
)
db.session.add(workflow)
db.session.flush()
else:
# Update existing draft workflow
workflow.graph = json.dumps(graph)
workflow.updated_by = account.id
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
workflow.environment_variables = environment_variables
workflow.conversation_variables = conversation_variables
# Update snippet's input_fields if provided
if input_variables is not None:
snippet.input_fields = json.dumps(input_variables)
snippet.updated_by = account.id
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
db.session.commit()
return workflow
def publish_workflow(
self,
*,
session: Session,
snippet: CustomizedSnippet,
account: Account,
) -> Workflow:
"""
Publish the draft workflow as a new version.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param account: Account making the change
:return: Published Workflow
:raises ValueError: If no draft workflow exists
"""
draft_workflow_stmt = select(Workflow).where(
Workflow.tenant_id == snippet.tenant_id,
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version == "draft",
)
draft_workflow = session.scalar(draft_workflow_stmt)
if not draft_workflow:
raise ValueError("No valid workflow found.")
# Create new published workflow
workflow = Workflow.new(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
type=draft_workflow.type,
version=str(datetime.now(UTC).replace(tzinfo=None)),
graph=draft_workflow.graph,
features=draft_workflow.features,
created_by=account.id,
environment_variables=draft_workflow.environment_variables,
conversation_variables=draft_workflow.conversation_variables,
marked_name="",
marked_comment="",
)
session.add(workflow)
# Update snippet version
snippet.version += 1
snippet.is_published = True
snippet.workflow_id = workflow.id
snippet.updated_by = account.id
session.add(snippet)
return workflow
def get_all_published_workflows(
self,
*,
session: Session,
snippet: CustomizedSnippet,
page: int,
limit: int,
) -> tuple[Sequence[Workflow], bool]:
"""
Get all published workflow versions for snippet.
:param session: Database session
:param snippet: CustomizedSnippet instance
:param page: Page number
:param limit: Items per page
:return: Tuple of (workflows list, has_more flag)
"""
if not snippet.workflow_id:
return [], False
stmt = (
select(Workflow)
.where(
Workflow.app_id == snippet.id,
Workflow.type == WorkflowType.SNIPPET.value,
Workflow.version != "draft",
)
.order_by(Workflow.version.desc())
.limit(limit + 1)
.offset((page - 1) * limit)
)
workflows = list(session.scalars(stmt).all())
has_more = len(workflows) > limit
if has_more:
workflows = workflows[:-1]
return workflows, has_more
# --- Default Block Configs ---
def get_default_block_configs(self) -> list[dict]:
"""
Get default block configurations for all node types.
:return: List of default configurations
"""
default_block_configs: list[dict[str, Any]] = []
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
node_class = node_class_mapping[LATEST_VERSION]
default_config = node_class.get_default_config()
if default_config:
default_block_configs.append(dict(default_config))
return default_block_configs
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
"""
Get default config for specific node type.
:param node_type: Node type string
:param filters: Optional filters
:return: Default configuration or None
"""
node_type_enum = NodeType(node_type)
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
return None
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
default_config = node_class.get_default_config(filters=filters)
if not default_config:
return None
return default_config
# --- Workflow Run Operations ---
def get_snippet_workflow_runs(
self,
*,
snippet: CustomizedSnippet,
args: dict,
) -> InfiniteScrollPagination:
"""
Get paginated workflow runs for snippet.
:param snippet: CustomizedSnippet instance
:param args: Request arguments (last_id, limit)
:return: InfiniteScrollPagination result
"""
limit = int(args.get("limit", 20))
last_id = args.get("last_id")
triggered_from_values = [
WorkflowRunTriggeredFrom.DEBUGGING,
]
return self._workflow_run_repo.get_paginated_workflow_runs(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
triggered_from=triggered_from_values,
limit=limit,
last_id=last_id,
)
def get_snippet_workflow_run(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> WorkflowRun | None:
"""
Get workflow run details.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: WorkflowRun or None
"""
return self._workflow_run_repo.get_workflow_run_by_id(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
run_id=run_id,
)
def get_snippet_workflow_run_node_executions(
self,
*,
snippet: CustomizedSnippet,
run_id: str,
) -> Sequence[WorkflowNodeExecutionModel]:
"""
Get workflow run node execution list.
:param snippet: CustomizedSnippet instance
:param run_id: Workflow run ID
:return: List of WorkflowNodeExecutionModel
"""
workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
if not workflow_run:
return []
node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_run_id=workflow_run.id,
)
return node_executions
# --- Node Execution Operations ---
def get_snippet_node_last_run(
self,
*,
snippet: CustomizedSnippet,
workflow: Workflow,
node_id: str,
) -> WorkflowNodeExecutionModel | None:
"""
Get the most recent execution for a specific node in a snippet workflow.
:param snippet: CustomizedSnippet instance
:param workflow: Workflow instance
:param node_id: Node identifier
:return: WorkflowNodeExecutionModel or None
"""
return self._node_execution_service_repo.get_node_last_execution(
tenant_id=snippet.tenant_id,
app_id=snippet.id,
workflow_id=workflow.id,
node_id=node_id,
)
# --- Use Count ---
@staticmethod
def increment_use_count(
*,
session: Session,
snippet: CustomizedSnippet,
) -> None:
"""
Increment the use_count when snippet is used.
:param session: Database session
:param snippet: CustomizedSnippet instance
"""
snippet.use_count += 1
session.add(snippet)

View File

@@ -277,7 +277,10 @@ describe('App Card Operations Flow', () => {
}
})
// -- Basic rendering --
afterEach(() => {
vi.restoreAllMocks()
})
describe('Card Rendering', () => {
it('should render app name and description', () => {
renderAppCard({ name: 'My AI Bot', description: 'An intelligent assistant' })

View File

@@ -187,7 +187,10 @@ describe('App List Browsing Flow', () => {
mockShowTagManagementModal = false
})
// -- Loading and Empty states --
afterEach(() => {
vi.restoreAllMocks()
})
describe('Loading and Empty States', () => {
it('should show skeleton cards during initial loading', () => {
mockIsLoading = true

View File

@@ -237,7 +237,6 @@ describe('Create App Flow', () => {
mockShowTagManagementModal = false
})
// -- NewAppCard rendering --
describe('NewAppCard Rendering', () => {
it('should render the "Create App" card with all options', () => {
renderList()

View File

@@ -12,7 +12,6 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
import DevelopMain from '@/app/components/develop'
import { AppModeEnum, Theme } from '@/types/app'
// ---------- fake timers ----------
beforeEach(() => {
vi.useFakeTimers({ shouldAdvanceTime: true })
})
@@ -28,8 +27,6 @@ async function flushUI() {
})
}
// ---------- store mock ----------
let storeAppDetail: unknown
vi.mock('@/app/components/app/store', () => ({
@@ -38,8 +35,6 @@ vi.mock('@/app/components/app/store', () => ({
},
}))
// ---------- Doc dependencies ----------
vi.mock('@/context/i18n', () => ({
useLocale: () => 'en-US',
}))
@@ -48,11 +43,12 @@ vi.mock('@/hooks/use-theme', () => ({
default: () => ({ theme: Theme.light }),
}))
vi.mock('@/i18n-config/language', () => ({
LanguagesSupported: ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP'],
}))
// ---------- SecretKeyModal dependencies ----------
vi.mock('@/i18n-config/language', async (importOriginal) => {
const actual = await importOriginal<typeof import('@/i18n-config/language')>()
return {
...actual,
}
})
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({

View File

@@ -11,7 +11,7 @@ import { RETRIEVE_METHOD } from '@/types/app'
import Item from './index'
vi.mock('../settings-modal', () => ({
default: ({ onSave, onCancel, currentDataset }: any) => (
default: ({ onSave, onCancel, currentDataset }: { currentDataset: DataSet, onCancel: () => void, onSave: (newDataset: DataSet) => void }) => (
<div>
<div>Mock settings modal</div>
<button onClick={() => onSave({ ...currentDataset, name: 'Updated dataset' })}>Save changes</button>
@@ -177,7 +177,7 @@ describe('dataset-config/card-item', () => {
expect(screen.getByRole('dialog')).toBeVisible()
})
await user.click(screen.getByText('Save changes'))
fireEvent.click(screen.getByText('Save changes'))
await waitFor(() => {
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' }))

View File

@@ -53,6 +53,10 @@ vi.mock('@/hooks/use-theme', () => ({
vi.mock('@/i18n-config/language', () => ({
LanguagesSupported: ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP'],
getDocLanguage: (locale: string) => {
const map: Record<string, string> = { 'zh-Hans': 'zh', 'ja-JP': 'ja' }
return map[locale] || 'en'
},
}))
describe('Doc', () => {
@@ -63,7 +67,7 @@ describe('Doc', () => {
prompt_variables: variables,
},
},
})
}) as unknown as Parameters<typeof Doc>[0]['appDetail']
beforeEach(() => {
vi.clearAllMocks()
@@ -123,13 +127,13 @@ describe('Doc', () => {
describe('null/undefined appDetail', () => {
it('should render nothing when appDetail has no mode', () => {
render(<Doc appDetail={{}} />)
render(<Doc appDetail={{} as unknown as Parameters<typeof Doc>[0]['appDetail']} />)
expect(screen.queryByTestId('template-completion-en')).not.toBeInTheDocument()
expect(screen.queryByTestId('template-chat-en')).not.toBeInTheDocument()
})
it('should render nothing when appDetail is null', () => {
render(<Doc appDetail={null} />)
render(<Doc appDetail={null as unknown as Parameters<typeof Doc>[0]['appDetail']} />)
expect(screen.queryByTestId('template-completion-en')).not.toBeInTheDocument()
})
})

View File

@@ -0,0 +1,199 @@
import type { TocItem } from '../hooks/use-doc-toc'
import { fireEvent, render, screen } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import TocPanel from '../toc-panel'
/**
* Unit tests for the TocPanel presentational component.
* Covers collapsed/expanded states, item rendering, active section, and callbacks.
*/
describe('TocPanel', () => {
const defaultProps = {
toc: [] as TocItem[],
activeSection: '',
isTocExpanded: false,
onToggle: vi.fn(),
onItemClick: vi.fn(),
}
const sampleToc: TocItem[] = [
{ href: '#introduction', text: 'Introduction' },
{ href: '#authentication', text: 'Authentication' },
{ href: '#endpoints', text: 'Endpoints' },
]
beforeEach(() => {
vi.clearAllMocks()
})
// Covers collapsed state rendering
describe('collapsed state', () => {
it('should render expand button when collapsed', () => {
render(<TocPanel {...defaultProps} />)
expect(screen.getByLabelText('Open table of contents')).toBeInTheDocument()
})
it('should not render nav or toc items when collapsed', () => {
render(<TocPanel {...defaultProps} toc={sampleToc} />)
expect(screen.queryByRole('navigation')).not.toBeInTheDocument()
expect(screen.queryByText('Introduction')).not.toBeInTheDocument()
})
it('should call onToggle(true) when expand button is clicked', () => {
const onToggle = vi.fn()
render(<TocPanel {...defaultProps} onToggle={onToggle} />)
fireEvent.click(screen.getByLabelText('Open table of contents'))
expect(onToggle).toHaveBeenCalledWith(true)
})
})
// Covers expanded state with empty toc
describe('expanded state - empty', () => {
it('should render nav with empty message when toc is empty', () => {
render(<TocPanel {...defaultProps} isTocExpanded />)
expect(screen.getByRole('navigation')).toBeInTheDocument()
expect(screen.getByText('appApi.develop.noContent')).toBeInTheDocument()
})
it('should render TOC header with title', () => {
render(<TocPanel {...defaultProps} isTocExpanded />)
expect(screen.getByText('appApi.develop.toc')).toBeInTheDocument()
})
it('should call onToggle(false) when close button is clicked', () => {
const onToggle = vi.fn()
render(<TocPanel {...defaultProps} isTocExpanded onToggle={onToggle} />)
fireEvent.click(screen.getByLabelText('Close'))
expect(onToggle).toHaveBeenCalledWith(false)
})
})
// Covers expanded state with toc items
describe('expanded state - with items', () => {
it('should render all toc items as links', () => {
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
expect(screen.getByText('Introduction')).toBeInTheDocument()
expect(screen.getByText('Authentication')).toBeInTheDocument()
expect(screen.getByText('Endpoints')).toBeInTheDocument()
})
it('should render links with correct href attributes', () => {
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
const links = screen.getAllByRole('link')
expect(links).toHaveLength(3)
expect(links[0]).toHaveAttribute('href', '#introduction')
expect(links[1]).toHaveAttribute('href', '#authentication')
expect(links[2]).toHaveAttribute('href', '#endpoints')
})
it('should not render empty message when toc has items', () => {
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
expect(screen.queryByText('appApi.develop.noContent')).not.toBeInTheDocument()
})
})
// Covers active section highlighting
describe('active section', () => {
it('should apply active style to the matching toc item', () => {
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="authentication" />,
)
const activeLink = screen.getByText('Authentication').closest('a')
expect(activeLink?.className).toContain('font-medium')
expect(activeLink?.className).toContain('text-text-primary')
})
it('should apply inactive style to non-matching items', () => {
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="authentication" />,
)
const inactiveLink = screen.getByText('Introduction').closest('a')
expect(inactiveLink?.className).toContain('text-text-tertiary')
expect(inactiveLink?.className).not.toContain('font-medium')
})
it('should apply active indicator dot to active item', () => {
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="endpoints" />,
)
const activeLink = screen.getByText('Endpoints').closest('a')
const activeDot = activeLink?.querySelector('span:first-child')
expect(activeDot?.className).toContain('bg-text-accent')
})
})
// Covers click event delegation
describe('item click handling', () => {
it('should call onItemClick with the event and item when a link is clicked', () => {
const onItemClick = vi.fn()
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} onItemClick={onItemClick} />,
)
fireEvent.click(screen.getByText('Authentication'))
expect(onItemClick).toHaveBeenCalledTimes(1)
expect(onItemClick).toHaveBeenCalledWith(
expect.any(Object),
{ href: '#authentication', text: 'Authentication' },
)
})
it('should call onItemClick for each clicked item independently', () => {
const onItemClick = vi.fn()
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} onItemClick={onItemClick} />,
)
fireEvent.click(screen.getByText('Introduction'))
fireEvent.click(screen.getByText('Endpoints'))
expect(onItemClick).toHaveBeenCalledTimes(2)
})
})
// Covers edge cases
describe('edge cases', () => {
it('should handle single item toc', () => {
const singleItem = [{ href: '#only', text: 'Only Section' }]
render(<TocPanel {...defaultProps} isTocExpanded toc={singleItem} activeSection="only" />)
expect(screen.getByText('Only Section')).toBeInTheDocument()
expect(screen.getAllByRole('link')).toHaveLength(1)
})
it('should handle toc items with empty text', () => {
const emptyTextItem = [{ href: '#empty', text: '' }]
render(<TocPanel {...defaultProps} isTocExpanded toc={emptyTextItem} />)
expect(screen.getAllByRole('link')).toHaveLength(1)
})
it('should handle active section that does not match any item', () => {
render(
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="nonexistent" />,
)
// All items should be in inactive style
const links = screen.getAllByRole('link')
links.forEach((link) => {
expect(link.className).toContain('text-text-tertiary')
expect(link.className).not.toContain('font-medium')
})
})
})
})

View File

@@ -0,0 +1,425 @@
import { act, renderHook } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { useDocToc } from '../hooks/use-doc-toc'
/**
* Unit tests for the useDocToc custom hook.
* Covers TOC extraction, viewport-based expansion, scroll tracking, and click handling.
*/
describe('useDocToc', () => {
const defaultOptions = { appDetail: { mode: 'chat' }, locale: 'en-US' }
beforeEach(() => {
vi.clearAllMocks()
vi.useRealTimers()
Object.defineProperty(window, 'matchMedia', {
writable: true,
value: vi.fn().mockReturnValue({ matches: false }),
})
})
// Covers initial state values based on viewport width
describe('initial state', () => {
it('should set isTocExpanded to false on narrow viewport', () => {
const { result } = renderHook(() => useDocToc(defaultOptions))
expect(result.current.isTocExpanded).toBe(false)
expect(result.current.toc).toEqual([])
expect(result.current.activeSection).toBe('')
})
it('should set isTocExpanded to true on wide viewport', () => {
Object.defineProperty(window, 'matchMedia', {
writable: true,
value: vi.fn().mockReturnValue({ matches: true }),
})
const { result } = renderHook(() => useDocToc(defaultOptions))
expect(result.current.isTocExpanded).toBe(true)
})
})
// Covers TOC extraction from DOM article headings
describe('TOC extraction', () => {
it('should extract toc items from article h2 anchors', async () => {
vi.useFakeTimers()
const article = document.createElement('article')
const h2 = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = '#section-1'
anchor.textContent = 'Section 1'
h2.appendChild(anchor)
article.appendChild(h2)
document.body.appendChild(article)
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toEqual([
{ href: '#section-1', text: 'Section 1' },
])
expect(result.current.activeSection).toBe('section-1')
document.body.removeChild(article)
vi.useRealTimers()
})
it('should return empty toc when no article exists', async () => {
vi.useFakeTimers()
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toEqual([])
expect(result.current.activeSection).toBe('')
vi.useRealTimers()
})
it('should skip h2 headings without anchors', async () => {
vi.useFakeTimers()
const article = document.createElement('article')
const h2NoAnchor = document.createElement('h2')
h2NoAnchor.textContent = 'No Anchor'
article.appendChild(h2NoAnchor)
const h2WithAnchor = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = '#valid'
anchor.textContent = 'Valid'
h2WithAnchor.appendChild(anchor)
article.appendChild(h2WithAnchor)
document.body.appendChild(article)
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toHaveLength(1)
expect(result.current.toc[0]).toEqual({ href: '#valid', text: 'Valid' })
document.body.removeChild(article)
vi.useRealTimers()
})
it('should re-extract toc when appDetail changes', async () => {
vi.useFakeTimers()
const article = document.createElement('article')
document.body.appendChild(article)
const { result, rerender } = renderHook(
props => useDocToc(props),
{ initialProps: defaultOptions },
)
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toEqual([])
// Add a heading, then change appDetail to trigger re-extraction
const h2 = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = '#new-section'
anchor.textContent = 'New Section'
h2.appendChild(anchor)
article.appendChild(h2)
rerender({ appDetail: { mode: 'workflow' }, locale: 'en-US' })
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toHaveLength(1)
document.body.removeChild(article)
vi.useRealTimers()
})
it('should re-extract toc when locale changes', async () => {
vi.useFakeTimers()
const article = document.createElement('article')
const h2 = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = '#sec'
anchor.textContent = 'Sec'
h2.appendChild(anchor)
article.appendChild(h2)
document.body.appendChild(article)
const { result, rerender } = renderHook(
props => useDocToc(props),
{ initialProps: defaultOptions },
)
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toHaveLength(1)
rerender({ appDetail: defaultOptions.appDetail, locale: 'zh-Hans' })
await act(async () => {
vi.runAllTimers()
})
// Should still have the toc item after re-extraction
expect(result.current.toc).toHaveLength(1)
document.body.removeChild(article)
vi.useRealTimers()
})
})
// Covers manual toggle via setIsTocExpanded
describe('setIsTocExpanded', () => {
it('should toggle isTocExpanded state', () => {
const { result } = renderHook(() => useDocToc(defaultOptions))
expect(result.current.isTocExpanded).toBe(false)
act(() => {
result.current.setIsTocExpanded(true)
})
expect(result.current.isTocExpanded).toBe(true)
act(() => {
result.current.setIsTocExpanded(false)
})
expect(result.current.isTocExpanded).toBe(false)
})
})
// Covers smooth-scroll click handler
describe('handleTocClick', () => {
it('should prevent default and scroll to target element', () => {
const scrollContainer = document.createElement('div')
scrollContainer.className = 'overflow-auto'
scrollContainer.scrollTo = vi.fn()
document.body.appendChild(scrollContainer)
const target = document.createElement('div')
target.id = 'target-section'
Object.defineProperty(target, 'offsetTop', { value: 500 })
scrollContainer.appendChild(target)
const { result } = renderHook(() => useDocToc(defaultOptions))
const mockEvent = { preventDefault: vi.fn() } as unknown as React.MouseEvent<HTMLAnchorElement>
act(() => {
result.current.handleTocClick(mockEvent, { href: '#target-section', text: 'Target' })
})
expect(mockEvent.preventDefault).toHaveBeenCalled()
expect(scrollContainer.scrollTo).toHaveBeenCalledWith({
top: 420, // 500 - 80 (HEADER_OFFSET)
behavior: 'smooth',
})
document.body.removeChild(scrollContainer)
})
it('should do nothing when target element does not exist', () => {
const { result } = renderHook(() => useDocToc(defaultOptions))
const mockEvent = { preventDefault: vi.fn() } as unknown as React.MouseEvent<HTMLAnchorElement>
act(() => {
result.current.handleTocClick(mockEvent, { href: '#nonexistent', text: 'Missing' })
})
expect(mockEvent.preventDefault).toHaveBeenCalled()
})
})
// Covers scroll-based active section tracking
describe('scroll tracking', () => {
// Helper: set up DOM with scroll container, article headings, and matching target elements
const setupScrollDOM = (sections: Array<{ id: string, text: string, top: number }>) => {
const scrollContainer = document.createElement('div')
scrollContainer.className = 'overflow-auto'
document.body.appendChild(scrollContainer)
const article = document.createElement('article')
sections.forEach(({ id, text, top }) => {
// Heading with anchor for TOC extraction
const h2 = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = `#${id}`
anchor.textContent = text
h2.appendChild(anchor)
article.appendChild(h2)
// Target element for scroll tracking
const target = document.createElement('div')
target.id = id
target.getBoundingClientRect = vi.fn().mockReturnValue({ top })
scrollContainer.appendChild(target)
})
document.body.appendChild(article)
return {
scrollContainer,
article,
cleanup: () => {
document.body.removeChild(scrollContainer)
document.body.removeChild(article)
},
}
}
it('should register scroll listener when toc has items', async () => {
vi.useFakeTimers()
const { scrollContainer, cleanup } = setupScrollDOM([
{ id: 'sec-a', text: 'Section A', top: 0 },
])
const addSpy = vi.spyOn(scrollContainer, 'addEventListener')
const removeSpy = vi.spyOn(scrollContainer, 'removeEventListener')
const { unmount } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
expect(addSpy).toHaveBeenCalledWith('scroll', expect.any(Function))
unmount()
expect(removeSpy).toHaveBeenCalledWith('scroll', expect.any(Function))
cleanup()
vi.useRealTimers()
})
it('should update activeSection when scrolling past a section', async () => {
vi.useFakeTimers()
// innerHeight/2 = 384 in jsdom (default 768), so top <= 384 means "scrolled past"
const { scrollContainer, cleanup } = setupScrollDOM([
{ id: 'intro', text: 'Intro', top: 100 },
{ id: 'details', text: 'Details', top: 600 },
])
const { result } = renderHook(() => useDocToc(defaultOptions))
// Extract TOC items
await act(async () => {
vi.runAllTimers()
})
expect(result.current.toc).toHaveLength(2)
expect(result.current.activeSection).toBe('intro')
// Fire scroll — 'intro' (top=100) is above midpoint, 'details' (top=600) is below
await act(async () => {
scrollContainer.dispatchEvent(new Event('scroll'))
})
expect(result.current.activeSection).toBe('intro')
cleanup()
vi.useRealTimers()
})
it('should track the last section above the viewport midpoint', async () => {
vi.useFakeTimers()
const { scrollContainer, cleanup } = setupScrollDOM([
{ id: 'sec-1', text: 'Section 1', top: 50 },
{ id: 'sec-2', text: 'Section 2', top: 200 },
{ id: 'sec-3', text: 'Section 3', top: 800 },
])
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
// Fire scroll — sec-1 (top=50) and sec-2 (top=200) are above midpoint (384),
// sec-3 (top=800) is below. The last one above midpoint wins.
await act(async () => {
scrollContainer.dispatchEvent(new Event('scroll'))
})
expect(result.current.activeSection).toBe('sec-2')
cleanup()
vi.useRealTimers()
})
it('should not update activeSection when no section is above midpoint', async () => {
vi.useFakeTimers()
const { scrollContainer, cleanup } = setupScrollDOM([
{ id: 'far-away', text: 'Far Away', top: 1000 },
])
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
// Initial activeSection is set by extraction
const initialSection = result.current.activeSection
await act(async () => {
scrollContainer.dispatchEvent(new Event('scroll'))
})
// Should not change since the element is below midpoint
expect(result.current.activeSection).toBe(initialSection)
cleanup()
vi.useRealTimers()
})
it('should handle elements not found in DOM during scroll', async () => {
vi.useFakeTimers()
const scrollContainer = document.createElement('div')
scrollContainer.className = 'overflow-auto'
document.body.appendChild(scrollContainer)
// Article with heading but NO matching target element by id
const article = document.createElement('article')
const h2 = document.createElement('h2')
const anchor = document.createElement('a')
anchor.href = '#missing-target'
anchor.textContent = 'Missing'
h2.appendChild(anchor)
article.appendChild(h2)
document.body.appendChild(article)
const { result } = renderHook(() => useDocToc(defaultOptions))
await act(async () => {
vi.runAllTimers()
})
const initialSection = result.current.activeSection
// Scroll fires but getElementById returns null — no crash, no change
await act(async () => {
scrollContainer.dispatchEvent(new Event('scroll'))
})
expect(result.current.activeSection).toBe(initialSection)
document.body.removeChild(scrollContainer)
document.body.removeChild(article)
vi.useRealTimers()
})
})
})

View File

@@ -1,12 +1,13 @@
'use client'
import { RiCloseLine, RiListUnordered } from '@remixicon/react'
import { useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import type { ComponentType } from 'react'
import type { App, AppSSO } from '@/types/app'
import { useMemo } from 'react'
import { useLocale } from '@/context/i18n'
import useTheme from '@/hooks/use-theme'
import { LanguagesSupported } from '@/i18n-config/language'
import { getDocLanguage } from '@/i18n-config/language'
import { AppModeEnum, Theme } from '@/types/app'
import { cn } from '@/utils/classnames'
import { useDocToc } from './hooks/use-doc-toc'
import TemplateEn from './template/template.en.mdx'
import TemplateJa from './template/template.ja.mdx'
import TemplateZh from './template/template.zh.mdx'
@@ -19,225 +20,75 @@ import TemplateChatZh from './template/template_chat.zh.mdx'
import TemplateWorkflowEn from './template/template_workflow.en.mdx'
import TemplateWorkflowJa from './template/template_workflow.ja.mdx'
import TemplateWorkflowZh from './template/template_workflow.zh.mdx'
import TocPanel from './toc-panel'
type AppDetail = App & Partial<AppSSO>
type PromptVariable = { key: string, name: string }
type IDocProps = {
appDetail: any
appDetail: AppDetail
}
// Shared props shape for all MDX template components
type TemplateProps = {
appDetail: AppDetail
variables: PromptVariable[]
inputs: Record<string, string>
}
// Lookup table: [appMode][docLanguage] → template component
// MDX components accept arbitrary props at runtime but expose a narrow static type,
// so we assert the map type to allow passing TemplateProps when rendering.
const TEMPLATE_MAP = {
[AppModeEnum.CHAT]: { zh: TemplateChatZh, ja: TemplateChatJa, en: TemplateChatEn },
[AppModeEnum.AGENT_CHAT]: { zh: TemplateChatZh, ja: TemplateChatJa, en: TemplateChatEn },
[AppModeEnum.ADVANCED_CHAT]: { zh: TemplateAdvancedChatZh, ja: TemplateAdvancedChatJa, en: TemplateAdvancedChatEn },
[AppModeEnum.WORKFLOW]: { zh: TemplateWorkflowZh, ja: TemplateWorkflowJa, en: TemplateWorkflowEn },
[AppModeEnum.COMPLETION]: { zh: TemplateZh, ja: TemplateJa, en: TemplateEn },
} as Record<string, Record<string, ComponentType<TemplateProps>>>
const resolveTemplate = (mode: string | undefined, locale: string): ComponentType<TemplateProps> | null => {
if (!mode)
return null
const langTemplates = TEMPLATE_MAP[mode]
if (!langTemplates)
return null
const docLang = getDocLanguage(locale)
return langTemplates[docLang] ?? langTemplates.en ?? null
}
const Doc = ({ appDetail }: IDocProps) => {
const locale = useLocale()
const { t } = useTranslation()
const [toc, setToc] = useState<Array<{ href: string, text: string }>>([])
const [isTocExpanded, setIsTocExpanded] = useState(false)
const [activeSection, setActiveSection] = useState<string>('')
const { theme } = useTheme()
const { toc, isTocExpanded, setIsTocExpanded, activeSection, handleTocClick } = useDocToc({ appDetail, locale })
const variables = appDetail?.model_config?.configs?.prompt_variables || []
const inputs = variables.reduce((res: any, variable: any) => {
// model_config.configs.prompt_variables exists in the raw API response but is not modeled in ModelConfig type
const variables: PromptVariable[] = (
appDetail?.model_config as unknown as Record<string, Record<string, PromptVariable[]>> | undefined
)?.configs?.prompt_variables ?? []
const inputs = variables.reduce<Record<string, string>>((res, variable) => {
res[variable.key] = variable.name || ''
return res
}, {})
useEffect(() => {
const mediaQuery = window.matchMedia('(min-width: 1280px)')
setIsTocExpanded(mediaQuery.matches)
}, [])
useEffect(() => {
const extractTOC = () => {
const article = document.querySelector('article')
if (article) {
const headings = article.querySelectorAll('h2')
const tocItems = Array.from(headings).map((heading) => {
const anchor = heading.querySelector('a')
if (anchor) {
return {
href: anchor.getAttribute('href') || '',
text: anchor.textContent || '',
}
}
return null
}).filter((item): item is { href: string, text: string } => item !== null)
setToc(tocItems)
if (tocItems.length > 0)
setActiveSection(tocItems[0].href.replace('#', ''))
}
}
setTimeout(extractTOC, 0)
}, [appDetail, locale])
useEffect(() => {
const handleScroll = () => {
const scrollContainer = document.querySelector('.overflow-auto')
if (!scrollContainer || toc.length === 0)
return
let currentSection = ''
toc.forEach((item) => {
const targetId = item.href.replace('#', '')
const element = document.getElementById(targetId)
if (element) {
const rect = element.getBoundingClientRect()
if (rect.top <= window.innerHeight / 2)
currentSection = targetId
}
})
if (currentSection && currentSection !== activeSection)
setActiveSection(currentSection)
}
const scrollContainer = document.querySelector('.overflow-auto')
if (scrollContainer) {
scrollContainer.addEventListener('scroll', handleScroll)
handleScroll()
return () => scrollContainer.removeEventListener('scroll', handleScroll)
}
}, [toc, activeSection])
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string, text: string }) => {
e.preventDefault()
const targetId = item.href.replace('#', '')
const element = document.getElementById(targetId)
if (element) {
const scrollContainer = document.querySelector('.overflow-auto')
if (scrollContainer) {
const headerOffset = 80
const elementTop = element.offsetTop - headerOffset
scrollContainer.scrollTo({
top: elementTop,
behavior: 'smooth',
})
}
}
}
const Template = useMemo(() => {
if (appDetail?.mode === AppModeEnum.CHAT || appDetail?.mode === AppModeEnum.AGENT_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
case LanguagesSupported[7]:
return <TemplateChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
default:
return <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
}
}
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateAdvancedChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
case LanguagesSupported[7]:
return <TemplateAdvancedChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
default:
return <TemplateAdvancedChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
}
}
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateWorkflowZh appDetail={appDetail} variables={variables} inputs={inputs} />
case LanguagesSupported[7]:
return <TemplateWorkflowJa appDetail={appDetail} variables={variables} inputs={inputs} />
default:
return <TemplateWorkflowEn appDetail={appDetail} variables={variables} inputs={inputs} />
}
}
if (appDetail?.mode === AppModeEnum.COMPLETION) {
switch (locale) {
case LanguagesSupported[1]:
return <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
case LanguagesSupported[7]:
return <TemplateJa appDetail={appDetail} variables={variables} inputs={inputs} />
default:
return <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} />
}
}
return null
}, [appDetail, locale, variables, inputs])
const TemplateComponent = useMemo(
() => resolveTemplate(appDetail?.mode, locale),
[appDetail?.mode, locale],
)
return (
<div className="flex">
<div className={`fixed right-20 top-32 z-10 transition-all duration-150 ease-out ${isTocExpanded ? 'w-[280px]' : 'w-11'}`}>
{isTocExpanded
? (
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
{t('develop.toc', { ns: 'appApi' })}
</span>
<button
type="button"
onClick={() => setIsTocExpanded(false)}
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
aria-label="Close"
>
<RiCloseLine className="h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
</button>
</div>
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
{toc.length === 0
? (
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
{t('develop.noContent', { ns: 'appApi' })}
</div>
)
: (
<ul className="space-y-0.5">
{toc.map((item, index) => {
const isActive = activeSection === item.href.replace('#', '')
return (
<li key={index}>
<a
href={item.href}
onClick={e => handleTocClick(e, item)}
className={cn(
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
isActive
? 'bg-state-base-hover font-medium text-text-primary'
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
)}
>
<span
className={cn(
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
isActive
? 'scale-100 bg-text-accent'
: 'scale-75 bg-components-panel-border',
)}
/>
<span className="flex-1 truncate">
{item.text}
</span>
</a>
</li>
)
})}
</ul>
)}
</div>
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
</nav>
)
: (
<button
type="button"
onClick={() => setIsTocExpanded(true)}
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
aria-label="Open table of contents"
>
<RiListUnordered className="h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
</button>
)}
<TocPanel
toc={toc}
activeSection={activeSection}
isTocExpanded={isTocExpanded}
onToggle={setIsTocExpanded}
onItemClick={handleTocClick}
/>
</div>
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')}>
{Template}
{TemplateComponent && <TemplateComponent appDetail={appDetail} variables={variables} inputs={inputs} />}
</article>
</div>
)

View File

@@ -0,0 +1,115 @@
import { useCallback, useEffect, useState } from 'react'
export type TocItem = {
href: string
text: string
}
type UseDocTocOptions = {
appDetail: Record<string, unknown> | null
locale: string
}
const HEADER_OFFSET = 80
const SCROLL_CONTAINER_SELECTOR = '.overflow-auto'
const getTargetId = (href: string) => href.replace('#', '')
/**
* Extract heading anchors from the rendered <article> as TOC items.
*/
const extractTocFromArticle = (): TocItem[] => {
const article = document.querySelector('article')
if (!article)
return []
return Array.from(article.querySelectorAll('h2'))
.map((heading) => {
const anchor = heading.querySelector('a')
if (!anchor)
return null
return {
href: anchor.getAttribute('href') || '',
text: anchor.textContent || '',
}
})
.filter((item): item is TocItem => item !== null)
}
/**
* Custom hook that manages table-of-contents state:
* - Extracts TOC items from rendered headings
* - Tracks the active section on scroll
* - Auto-expands the panel on wide viewports
*/
export const useDocToc = ({ appDetail, locale }: UseDocTocOptions) => {
const [toc, setToc] = useState<TocItem[]>([])
const [isTocExpanded, setIsTocExpanded] = useState(() => {
if (typeof window === 'undefined')
return false
return window.matchMedia('(min-width: 1280px)').matches
})
const [activeSection, setActiveSection] = useState<string>('')
// Re-extract TOC items whenever the doc content changes
useEffect(() => {
const timer = setTimeout(() => {
const tocItems = extractTocFromArticle()
setToc(tocItems)
if (tocItems.length > 0)
setActiveSection(getTargetId(tocItems[0].href))
}, 0)
return () => clearTimeout(timer)
}, [appDetail, locale])
// Track active section based on scroll position
useEffect(() => {
const scrollContainer = document.querySelector(SCROLL_CONTAINER_SELECTOR)
if (!scrollContainer || toc.length === 0)
return
const handleScroll = () => {
let currentSection = ''
for (const item of toc) {
const targetId = getTargetId(item.href)
const element = document.getElementById(targetId)
if (element) {
const rect = element.getBoundingClientRect()
if (rect.top <= window.innerHeight / 2)
currentSection = targetId
}
}
if (currentSection && currentSection !== activeSection)
setActiveSection(currentSection)
}
scrollContainer.addEventListener('scroll', handleScroll)
return () => scrollContainer.removeEventListener('scroll', handleScroll)
}, [toc, activeSection])
// Smooth-scroll to a TOC target on click
const handleTocClick = useCallback((e: React.MouseEvent<HTMLAnchorElement>, item: TocItem) => {
e.preventDefault()
const targetId = getTargetId(item.href)
const element = document.getElementById(targetId)
if (!element)
return
const scrollContainer = document.querySelector(SCROLL_CONTAINER_SELECTOR)
if (scrollContainer) {
scrollContainer.scrollTo({
top: element.offsetTop - HEADER_OFFSET,
behavior: 'smooth',
})
}
}, [])
return {
toc,
isTocExpanded,
setIsTocExpanded,
activeSection,
handleTocClick,
}
}

View File

@@ -0,0 +1,96 @@
'use client'
import type { TocItem } from './hooks/use-doc-toc'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'
type TocPanelProps = {
toc: TocItem[]
activeSection: string
isTocExpanded: boolean
onToggle: (expanded: boolean) => void
onItemClick: (e: React.MouseEvent<HTMLAnchorElement>, item: TocItem) => void
}
const TocPanel = ({ toc, activeSection, isTocExpanded, onToggle, onItemClick }: TocPanelProps) => {
const { t } = useTranslation()
if (!isTocExpanded) {
return (
<button
type="button"
onClick={() => onToggle(true)}
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
aria-label="Open table of contents"
>
<span className="i-ri-list-unordered h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
</button>
)
}
return (
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
{t('develop.toc', { ns: 'appApi' })}
</span>
<button
type="button"
onClick={() => onToggle(false)}
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
aria-label="Close"
>
<span className="i-ri-close-line h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
</button>
</div>
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
{toc.length === 0
? (
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
{t('develop.noContent', { ns: 'appApi' })}
</div>
)
: (
<ul className="space-y-0.5">
{toc.map((item) => {
const isActive = activeSection === item.href.replace('#', '')
return (
<li key={item.href}>
<a
href={item.href}
onClick={e => onItemClick(e, item)}
className={cn(
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
isActive
? 'bg-state-base-hover font-medium text-text-primary'
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
)}
>
<span
className={cn(
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
isActive
? 'scale-100 bg-text-accent'
: 'scale-75 bg-components-panel-border',
)}
/>
<span className="flex-1 truncate">
{item.text}
</span>
</a>
</li>
)
})}
</ul>
)}
</div>
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
</nav>
)
}
export default TocPanel

View File

@@ -61,8 +61,8 @@ vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: () => ({}),
}))
// Mock pluginInstallLimit
vi.mock('../../../hooks/use-install-plugin-limit', () => ({
// Mock pluginInstallLimit (imported by the useInstallMultiState hook via @/ path)
vi.mock('@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit', () => ({
pluginInstallLimit: () => ({ canInstall: true }),
}))

View File

@@ -0,0 +1,568 @@
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '@/app/components/plugins/types'
import { act, renderHook, waitFor } from '@testing-library/react'
import { beforeEach, describe, expect, it, vi } from 'vitest'
import { PluginCategoryEnum } from '@/app/components/plugins/types'
import { getPluginKey, useInstallMultiState } from '../use-install-multi-state'
let mockMarketplaceData: ReturnType<typeof createMarketplaceApiData> | null = null
let mockMarketplaceError: Error | null = null
let mockInstalledInfo: Record<string, VersionInfo> = {}
let mockCanInstall = true
vi.mock('@/service/use-plugins', () => ({
useFetchPluginsInMarketPlaceByInfo: () => ({
isLoading: false,
data: mockMarketplaceData,
error: mockMarketplaceError,
}),
}))
vi.mock('@/app/components/plugins/install-plugin/hooks/use-check-installed', () => ({
default: () => ({
installedInfo: mockInstalledInfo,
}),
}))
vi.mock('@/context/global-public-context', () => ({
useGlobalPublicStore: () => ({}),
}))
vi.mock('@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit', () => ({
pluginInstallLimit: () => ({ canInstall: mockCanInstall }),
}))
const createMockPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
type: 'plugin',
org: 'test-org',
name: 'Test Plugin',
plugin_id: 'test-plugin-id',
version: '1.0.0',
latest_version: '1.0.0',
latest_package_identifier: 'test-pkg-id',
icon: 'icon.png',
verified: true,
label: { 'en-US': 'Test Plugin' },
brief: { 'en-US': 'Brief' },
description: { 'en-US': 'Description' },
introduction: 'Intro',
repository: 'https://github.com/test/plugin',
category: PluginCategoryEnum.tool,
install_count: 100,
endpoint: { settings: [] },
tags: [],
badges: [],
verification: { authorized_category: 'community' },
from: 'marketplace',
...overrides,
})
const createPackageDependency = (index: number) => ({
type: 'package',
value: {
unique_identifier: `package-plugin-${index}-uid`,
manifest: {
plugin_unique_identifier: `package-plugin-${index}-uid`,
version: '1.0.0',
author: 'test-author',
icon: 'icon.png',
name: `Package Plugin ${index}`,
category: PluginCategoryEnum.tool,
label: { 'en-US': `Package Plugin ${index}` },
description: { 'en-US': 'Test package plugin' },
created_at: '2024-01-01',
resource: {},
plugins: [],
verified: true,
endpoint: { settings: [], endpoints: [] },
model: null,
tags: [],
agent_strategy: null,
meta: { version: '1.0.0' },
trigger: {},
},
},
} as unknown as PackageDependency)
const createMarketplaceDependency = (index: number): GitHubItemAndMarketPlaceDependency => ({
type: 'marketplace',
value: {
marketplace_plugin_unique_identifier: `test-org/plugin-${index}:1.0.0`,
plugin_unique_identifier: `plugin-${index}`,
version: '1.0.0',
},
})
const createGitHubDependency = (index: number): GitHubItemAndMarketPlaceDependency => ({
type: 'github',
value: {
repo: `test-org/plugin-${index}`,
version: 'v1.0.0',
package: `plugin-${index}.zip`,
},
})
const createMarketplaceApiData = (indexes: number[]) => ({
data: {
list: indexes.map(i => ({
plugin: {
plugin_id: `test-org/plugin-${i}`,
org: 'test-org',
name: `Test Plugin ${i}`,
version: '1.0.0',
latest_version: '1.0.0',
},
version: {
unique_identifier: `plugin-${i}-uid`,
},
})),
},
})
const createDefaultParams = (overrides = {}) => ({
allPlugins: [createPackageDependency(0)] as Dependency[],
selectedPlugins: [] as Plugin[],
onSelect: vi.fn(),
onLoadedAllPlugin: vi.fn(),
...overrides,
})
// ==================== getPluginKey Tests ====================
describe('getPluginKey', () => {
it('should return org/name when org is available', () => {
const plugin = createMockPlugin({ org: 'my-org', name: 'my-plugin' })
expect(getPluginKey(plugin)).toBe('my-org/my-plugin')
})
it('should fall back to author when org is not available', () => {
const plugin = createMockPlugin({ org: undefined, author: 'my-author', name: 'my-plugin' })
expect(getPluginKey(plugin)).toBe('my-author/my-plugin')
})
it('should prefer org over author when both exist', () => {
const plugin = createMockPlugin({ org: 'my-org', author: 'my-author', name: 'my-plugin' })
expect(getPluginKey(plugin)).toBe('my-org/my-plugin')
})
it('should handle undefined plugin', () => {
expect(getPluginKey(undefined)).toBe('undefined/undefined')
})
})
// ==================== useInstallMultiState Tests ====================
describe('useInstallMultiState', () => {
beforeEach(() => {
vi.clearAllMocks()
mockMarketplaceData = null
mockMarketplaceError = null
mockInstalledInfo = {}
mockCanInstall = true
})
// ==================== Initial State ====================
describe('Initial State', () => {
it('should initialize plugins from package dependencies', () => {
const params = createDefaultParams()
const { result } = renderHook(() => useInstallMultiState(params))
expect(result.current.plugins).toHaveLength(1)
expect(result.current.plugins[0]).toBeDefined()
expect(result.current.plugins[0]?.plugin_id).toBe('package-plugin-0-uid')
})
it('should have slots for all dependencies even when no packages exist', () => {
const params = createDefaultParams({
allPlugins: [createGitHubDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
// Array has slots for all dependencies, but unresolved ones are undefined
expect(result.current.plugins).toHaveLength(1)
expect(result.current.plugins[0]).toBeUndefined()
})
it('should return undefined for non-package items in mixed dependencies', () => {
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createGitHubDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
expect(result.current.plugins).toHaveLength(2)
expect(result.current.plugins[0]).toBeDefined()
expect(result.current.plugins[1]).toBeUndefined()
})
it('should start with empty errorIndexes', () => {
const params = createDefaultParams()
const { result } = renderHook(() => useInstallMultiState(params))
expect(result.current.errorIndexes).toEqual([])
})
})
// ==================== Marketplace Data Sync ====================
describe('Marketplace Data Sync', () => {
it('should update plugins when marketplace data loads by ID', async () => {
mockMarketplaceData = createMarketplaceApiData([0])
const params = createDefaultParams({
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(result.current.plugins[0]).toBeDefined()
expect(result.current.plugins[0]?.version).toBe('1.0.0')
})
})
it('should update plugins when marketplace data loads by meta', async () => {
mockMarketplaceData = createMarketplaceApiData([0])
const params = createDefaultParams({
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
// The "by meta" effect sets plugin_id from version.unique_identifier
await waitFor(() => {
expect(result.current.plugins[0]).toBeDefined()
})
})
it('should add to errorIndexes when marketplace item not found in response', async () => {
mockMarketplaceData = { data: { list: [] } }
const params = createDefaultParams({
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(result.current.errorIndexes).toContain(0)
})
})
it('should handle multiple marketplace plugins', async () => {
mockMarketplaceData = createMarketplaceApiData([0, 1])
const params = createDefaultParams({
allPlugins: [
createMarketplaceDependency(0),
createMarketplaceDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(result.current.plugins[0]).toBeDefined()
expect(result.current.plugins[1]).toBeDefined()
})
})
})
// ==================== Error Handling ====================
describe('Error Handling', () => {
it('should mark all marketplace indexes as errors on fetch failure', async () => {
mockMarketplaceError = new Error('Fetch failed')
const params = createDefaultParams({
allPlugins: [
createMarketplaceDependency(0),
createMarketplaceDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(result.current.errorIndexes).toContain(0)
expect(result.current.errorIndexes).toContain(1)
})
})
it('should not affect non-marketplace indexes on marketplace fetch error', async () => {
mockMarketplaceError = new Error('Fetch failed')
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createMarketplaceDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(result.current.errorIndexes).toContain(1)
expect(result.current.errorIndexes).not.toContain(0)
})
})
})
// ==================== Loaded All Data Notification ====================
describe('Loaded All Data Notification', () => {
it('should call onLoadedAllPlugin when all data loaded', async () => {
const params = createDefaultParams()
renderHook(() => useInstallMultiState(params))
await waitFor(() => {
expect(params.onLoadedAllPlugin).toHaveBeenCalledWith(mockInstalledInfo)
})
})
it('should not call onLoadedAllPlugin when not all plugins resolved', () => {
// GitHub plugin not fetched yet → isLoadedAllData = false
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createGitHubDependency(1),
] as Dependency[],
})
renderHook(() => useInstallMultiState(params))
expect(params.onLoadedAllPlugin).not.toHaveBeenCalled()
})
it('should call onLoadedAllPlugin after all errors are counted', async () => {
mockMarketplaceError = new Error('Fetch failed')
const params = createDefaultParams({
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
})
renderHook(() => useInstallMultiState(params))
// Error fills errorIndexes → isLoadedAllData becomes true
await waitFor(() => {
expect(params.onLoadedAllPlugin).toHaveBeenCalled()
})
})
})
// ==================== handleGitHubPluginFetched ====================
describe('handleGitHubPluginFetched', () => {
it('should update plugin at the specified index', async () => {
const params = createDefaultParams({
allPlugins: [createGitHubDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
const mockPlugin = createMockPlugin({ plugin_id: 'github-plugin-0' })
await act(async () => {
result.current.handleGitHubPluginFetched(0)(mockPlugin)
})
expect(result.current.plugins[0]).toEqual(mockPlugin)
})
it('should not affect other plugin slots', async () => {
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createGitHubDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
const originalPlugin0 = result.current.plugins[0]
const mockPlugin = createMockPlugin({ plugin_id: 'github-plugin-1' })
await act(async () => {
result.current.handleGitHubPluginFetched(1)(mockPlugin)
})
expect(result.current.plugins[0]).toEqual(originalPlugin0)
expect(result.current.plugins[1]).toEqual(mockPlugin)
})
})
// ==================== handleGitHubPluginFetchError ====================
describe('handleGitHubPluginFetchError', () => {
it('should add index to errorIndexes', async () => {
const params = createDefaultParams({
allPlugins: [createGitHubDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await act(async () => {
result.current.handleGitHubPluginFetchError(0)()
})
expect(result.current.errorIndexes).toContain(0)
})
it('should accumulate multiple error indexes without stale closure', async () => {
const params = createDefaultParams({
allPlugins: [
createGitHubDependency(0),
createGitHubDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await act(async () => {
result.current.handleGitHubPluginFetchError(0)()
})
await act(async () => {
result.current.handleGitHubPluginFetchError(1)()
})
expect(result.current.errorIndexes).toContain(0)
expect(result.current.errorIndexes).toContain(1)
})
})
// ==================== getVersionInfo ====================
describe('getVersionInfo', () => {
it('should return hasInstalled false when plugin not installed', () => {
const params = createDefaultParams()
const { result } = renderHook(() => useInstallMultiState(params))
const info = result.current.getVersionInfo('unknown/plugin')
expect(info.hasInstalled).toBe(false)
expect(info.installedVersion).toBeUndefined()
expect(info.toInstallVersion).toBe('')
})
it('should return hasInstalled true with version when installed', () => {
mockInstalledInfo = {
'test-author/Package Plugin 0': {
installedId: 'installed-1',
installedVersion: '0.9.0',
uniqueIdentifier: 'uid-1',
},
}
const params = createDefaultParams()
const { result } = renderHook(() => useInstallMultiState(params))
const info = result.current.getVersionInfo('test-author/Package Plugin 0')
expect(info.hasInstalled).toBe(true)
expect(info.installedVersion).toBe('0.9.0')
})
})
// ==================== handleSelect ====================
describe('handleSelect', () => {
it('should call onSelect with plugin, index, and installable count', async () => {
const params = createDefaultParams()
const { result } = renderHook(() => useInstallMultiState(params))
await act(async () => {
result.current.handleSelect(0)()
})
expect(params.onSelect).toHaveBeenCalledWith(
result.current.plugins[0],
0,
expect.any(Number),
)
})
it('should filter installable plugins using pluginInstallLimit', async () => {
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createPackageDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
await act(async () => {
result.current.handleSelect(0)()
})
// mockCanInstall is true, so all 2 plugins are installable
expect(params.onSelect).toHaveBeenCalledWith(
expect.anything(),
0,
2,
)
})
})
// ==================== isPluginSelected ====================
describe('isPluginSelected', () => {
it('should return true when plugin is in selectedPlugins', () => {
const selectedPlugin = createMockPlugin({ plugin_id: 'package-plugin-0-uid' })
const params = createDefaultParams({
selectedPlugins: [selectedPlugin],
})
const { result } = renderHook(() => useInstallMultiState(params))
expect(result.current.isPluginSelected(0)).toBe(true)
})
it('should return false when plugin is not in selectedPlugins', () => {
const params = createDefaultParams({ selectedPlugins: [] })
const { result } = renderHook(() => useInstallMultiState(params))
expect(result.current.isPluginSelected(0)).toBe(false)
})
it('should return false when plugin at index is undefined', () => {
const params = createDefaultParams({
allPlugins: [createGitHubDependency(0)] as Dependency[],
selectedPlugins: [createMockPlugin()],
})
const { result } = renderHook(() => useInstallMultiState(params))
// plugins[0] is undefined (GitHub not yet fetched)
expect(result.current.isPluginSelected(0)).toBe(false)
})
})
// ==================== getInstallablePlugins ====================
describe('getInstallablePlugins', () => {
it('should return all plugins when canInstall is true', () => {
mockCanInstall = true
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createPackageDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
expect(installablePlugins).toHaveLength(2)
expect(selectedIndexes).toEqual([0, 1])
})
it('should return empty arrays when canInstall is false', () => {
mockCanInstall = false
const params = createDefaultParams({
allPlugins: [createPackageDependency(0)] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
expect(installablePlugins).toHaveLength(0)
expect(selectedIndexes).toEqual([])
})
it('should skip unloaded (undefined) plugins', () => {
mockCanInstall = true
const params = createDefaultParams({
allPlugins: [
createPackageDependency(0),
createGitHubDependency(1),
] as Dependency[],
})
const { result } = renderHook(() => useInstallMultiState(params))
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
// Only package plugin is loaded; GitHub not yet fetched
expect(installablePlugins).toHaveLength(1)
expect(selectedIndexes).toEqual([0])
})
})
})

View File

@@ -0,0 +1,230 @@
'use client'
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '@/app/components/plugins/types'
import { useCallback, useEffect, useMemo, useState } from 'react'
import useCheckInstalled from '@/app/components/plugins/install-plugin/hooks/use-check-installed'
import { pluginInstallLimit } from '@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useFetchPluginsInMarketPlaceByInfo } from '@/service/use-plugins'
type UseInstallMultiStateParams = {
allPlugins: Dependency[]
selectedPlugins: Plugin[]
onSelect: (plugin: Plugin, selectedIndex: number, allCanInstallPluginsLength: number) => void
onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void
}
export function getPluginKey(plugin: Plugin | undefined): string {
return `${plugin?.org || plugin?.author}/${plugin?.name}`
}
function parseMarketplaceIdentifier(identifier: string) {
const [orgPart, nameAndVersionPart] = identifier.split('@')[0].split('/')
const [name, version] = nameAndVersionPart.split(':')
return { organization: orgPart, plugin: name, version }
}
function initPluginsFromDependencies(allPlugins: Dependency[]): (Plugin | undefined)[] {
if (!allPlugins.some(d => d.type === 'package'))
return []
return allPlugins.map((d) => {
if (d.type !== 'package')
return undefined
const { manifest, unique_identifier } = (d as PackageDependency).value
return {
...manifest,
plugin_id: unique_identifier,
} as unknown as Plugin
})
}
export function useInstallMultiState({
allPlugins,
selectedPlugins,
onSelect,
onLoadedAllPlugin,
}: UseInstallMultiStateParams) {
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
// Marketplace plugins filtering and index mapping
const marketplacePlugins = useMemo(
() => allPlugins.filter((d): d is GitHubItemAndMarketPlaceDependency => d.type === 'marketplace'),
[allPlugins],
)
const marketPlaceInDSLIndex = useMemo(() => {
return allPlugins.reduce<number[]>((acc, d, index) => {
if (d.type === 'marketplace')
acc.push(index)
return acc
}, [])
}, [allPlugins])
// Marketplace data fetching: by unique identifier and by meta info
const {
isLoading: isFetchingById,
data: infoGetById,
error: infoByIdError,
} = useFetchPluginsInMarketPlaceByInfo(
marketplacePlugins.map(d => parseMarketplaceIdentifier(d.value.marketplace_plugin_unique_identifier!)),
)
const {
isLoading: isFetchingByMeta,
data: infoByMeta,
error: infoByMetaError,
} = useFetchPluginsInMarketPlaceByInfo(
marketplacePlugins.map(d => d.value!),
)
// Derive marketplace plugin data and errors from API responses
const { marketplacePluginMap, marketplaceErrorIndexes } = useMemo(() => {
const pluginMap = new Map<number, Plugin>()
const errorSet = new Set<number>()
// Process "by ID" response
if (!isFetchingById && infoGetById?.data.list) {
const sortedList = marketplacePlugins.map((d) => {
const id = d.value.marketplace_plugin_unique_identifier?.split(':')[0]
const retPluginInfo = infoGetById.data.list.find(item => item.plugin.plugin_id === id)?.plugin
return { ...retPluginInfo, from: d.type } as Plugin
})
marketPlaceInDSLIndex.forEach((index, i) => {
if (sortedList[i]) {
pluginMap.set(index, {
...sortedList[i],
version: sortedList[i]!.version || sortedList[i]!.latest_version,
})
}
else { errorSet.add(index) }
})
}
// Process "by meta" response (may overwrite "by ID" results)
if (!isFetchingByMeta && infoByMeta?.data.list) {
const payloads = infoByMeta.data.list
marketPlaceInDSLIndex.forEach((index, i) => {
if (payloads[i]) {
const item = payloads[i]
pluginMap.set(index, {
...item.plugin,
plugin_id: item.version.unique_identifier,
} as Plugin)
}
else { errorSet.add(index) }
})
}
// Mark all marketplace indexes as errors on fetch failure
if (infoByMetaError || infoByIdError)
marketPlaceInDSLIndex.forEach(index => errorSet.add(index))
return { marketplacePluginMap: pluginMap, marketplaceErrorIndexes: errorSet }
}, [isFetchingById, isFetchingByMeta, infoGetById, infoByMeta, infoByMetaError, infoByIdError, marketPlaceInDSLIndex, marketplacePlugins])
// GitHub-fetched plugins and errors (imperative state from child callbacks)
const [githubPluginMap, setGithubPluginMap] = useState<Map<number, Plugin>>(() => new Map())
const [githubErrorIndexes, setGithubErrorIndexes] = useState<number[]>([])
// Merge all plugin sources into a single array
const plugins = useMemo(() => {
const initial = initPluginsFromDependencies(allPlugins)
const result: (Plugin | undefined)[] = allPlugins.map((_, i) => initial[i])
marketplacePluginMap.forEach((plugin, index) => {
result[index] = plugin
})
githubPluginMap.forEach((plugin, index) => {
result[index] = plugin
})
return result
}, [allPlugins, marketplacePluginMap, githubPluginMap])
// Merge all error sources
const errorIndexes = useMemo(() => {
return [...marketplaceErrorIndexes, ...githubErrorIndexes]
}, [marketplaceErrorIndexes, githubErrorIndexes])
// Check installed status after all data is loaded
const isLoadedAllData = (plugins.filter(Boolean).length + errorIndexes.length) === allPlugins.length
const { installedInfo } = useCheckInstalled({
pluginIds: plugins.filter(Boolean).map(d => getPluginKey(d)) || [],
enabled: isLoadedAllData,
})
// Notify parent when all plugin data and install info is ready
useEffect(() => {
if (isLoadedAllData && installedInfo)
onLoadedAllPlugin(installedInfo!)
// eslint-disable-next-line react-hooks/exhaustive-deps
}, [isLoadedAllData, installedInfo])
// Callback: handle GitHub plugin fetch success
const handleGitHubPluginFetched = useCallback((index: number) => {
return (p: Plugin) => {
setGithubPluginMap(prev => new Map(prev).set(index, p))
}
}, [])
// Callback: handle GitHub plugin fetch error
const handleGitHubPluginFetchError = useCallback((index: number) => {
return () => {
setGithubErrorIndexes(prev => [...prev, index])
}
}, [])
// Callback: get version info for a plugin by its key
const getVersionInfo = useCallback((pluginId: string) => {
const pluginDetail = installedInfo?.[pluginId]
return {
hasInstalled: !!pluginDetail,
installedVersion: pluginDetail?.installedVersion,
toInstallVersion: '',
}
}, [installedInfo])
// Callback: handle plugin selection
const handleSelect = useCallback((index: number) => {
return () => {
const canSelectPlugins = plugins.filter((p) => {
const { canInstall } = pluginInstallLimit(p!, systemFeatures)
return canInstall
})
onSelect(plugins[index]!, index, canSelectPlugins.length)
}
}, [onSelect, plugins, systemFeatures])
// Callback: check if a plugin at given index is selected
const isPluginSelected = useCallback((index: number) => {
return !!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)
}, [selectedPlugins, plugins])
// Callback: get all installable plugins with their indexes
const getInstallablePlugins = useCallback(() => {
const selectedIndexes: number[] = []
const installablePlugins: Plugin[] = []
allPlugins.forEach((_d, index) => {
const p = plugins[index]
if (!p)
return
const { canInstall } = pluginInstallLimit(p, systemFeatures)
if (canInstall) {
selectedIndexes.push(index)
installablePlugins.push(p)
}
})
return { selectedIndexes, installablePlugins }
}, [allPlugins, plugins, systemFeatures])
return {
plugins,
errorIndexes,
handleGitHubPluginFetched,
handleGitHubPluginFetchError,
getVersionInfo,
handleSelect,
isPluginSelected,
getInstallablePlugins,
}
}

View File

@@ -1,16 +1,12 @@
'use client'
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useImperativeHandle, useMemo, useState } from 'react'
import useCheckInstalled from '@/app/components/plugins/install-plugin/hooks/use-check-installed'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useFetchPluginsInMarketPlaceByInfo } from '@/service/use-plugins'
import { useImperativeHandle } from 'react'
import LoadingError from '../../base/loading-error'
import { pluginInstallLimit } from '../../hooks/use-install-plugin-limit'
import GithubItem from '../item/github-item'
import MarketplaceItem from '../item/marketplace-item'
import PackageItem from '../item/package-item'
import { getPluginKey, useInstallMultiState } from './hooks/use-install-multi-state'
type Props = {
allPlugins: Dependency[]
@@ -38,206 +34,50 @@ const InstallByDSLList = ({
isFromMarketPlace,
ref,
}: Props) => {
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
// DSL has id, to get plugin info to show more info
const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => {
const dependecy = (d as GitHubItemAndMarketPlaceDependency).value
// split org, name, version by / and :
// and remove @ and its suffix
const [orgPart, nameAndVersionPart] = dependecy.marketplace_plugin_unique_identifier!.split('@')[0].split('/')
const [name, version] = nameAndVersionPart.split(':')
return {
organization: orgPart,
plugin: name,
version,
}
}))
// has meta(org,name,version), to get id
const { isLoading: isFetchingDataByMeta, data: infoByMeta, error: infoByMetaError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map(d => (d as GitHubItemAndMarketPlaceDependency).value!))
const [plugins, doSetPlugins] = useState<(Plugin | undefined)[]>((() => {
const hasLocalPackage = allPlugins.some(d => d.type === 'package')
if (!hasLocalPackage)
return []
const _plugins = allPlugins.map((d) => {
if (d.type === 'package') {
return {
...(d as any).value.manifest,
plugin_id: (d as any).value.unique_identifier,
}
}
return undefined
})
return _plugins
})())
const pluginsRef = React.useRef<(Plugin | undefined)[]>(plugins)
const setPlugins = useCallback((p: (Plugin | undefined)[]) => {
doSetPlugins(p)
pluginsRef.current = p
}, [])
const [errorIndexes, setErrorIndexes] = useState<number[]>([])
const handleGitHubPluginFetched = useCallback((index: number) => {
return (p: Plugin) => {
const nextPlugins = produce(pluginsRef.current, (draft) => {
draft[index] = p
})
setPlugins(nextPlugins)
}
}, [setPlugins])
const handleGitHubPluginFetchError = useCallback((index: number) => {
return () => {
setErrorIndexes([...errorIndexes, index])
}
}, [errorIndexes])
const marketPlaceInDSLIndex = useMemo(() => {
const res: number[] = []
allPlugins.forEach((d, index) => {
if (d.type === 'marketplace')
res.push(index)
})
return res
}, [allPlugins])
useEffect(() => {
if (!isFetchingMarketplaceDataById && infoGetById?.data.list) {
const sortedList = allPlugins.filter(d => d.type === 'marketplace').map((d) => {
const p = d as GitHubItemAndMarketPlaceDependency
const id = p.value.marketplace_plugin_unique_identifier?.split(':')[0]
const retPluginInfo = infoGetById.data.list.find(item => item.plugin.plugin_id === id)?.plugin
return { ...retPluginInfo, from: d.type } as Plugin
})
const payloads = sortedList
const failedIndex: number[] = []
const nextPlugins = produce(pluginsRef.current, (draft) => {
marketPlaceInDSLIndex.forEach((index, i) => {
if (payloads[i]) {
draft[index] = {
...payloads[i],
version: payloads[i]!.version || payloads[i]!.latest_version,
}
}
else { failedIndex.push(index) }
})
})
setPlugins(nextPlugins)
if (failedIndex.length > 0)
setErrorIndexes([...errorIndexes, ...failedIndex])
}
}, [isFetchingMarketplaceDataById])
useEffect(() => {
if (!isFetchingDataByMeta && infoByMeta?.data.list) {
const payloads = infoByMeta?.data.list
const failedIndex: number[] = []
const nextPlugins = produce(pluginsRef.current, (draft) => {
marketPlaceInDSLIndex.forEach((index, i) => {
if (payloads[i]) {
const item = payloads[i]
draft[index] = {
...item.plugin,
plugin_id: item.version.unique_identifier,
}
}
else {
failedIndex.push(index)
}
})
})
setPlugins(nextPlugins)
if (failedIndex.length > 0)
setErrorIndexes([...errorIndexes, ...failedIndex])
}
}, [isFetchingDataByMeta])
useEffect(() => {
// get info all failed
if (infoByMetaError || infoByIdError)
setErrorIndexes([...errorIndexes, ...marketPlaceInDSLIndex])
}, [infoByMetaError, infoByIdError])
const isLoadedAllData = (plugins.filter(p => !!p).length + errorIndexes.length) === allPlugins.length
const { installedInfo } = useCheckInstalled({
pluginIds: plugins?.filter(p => !!p).map((d) => {
return `${d?.org || d?.author}/${d?.name}`
}) || [],
enabled: isLoadedAllData,
const {
plugins,
errorIndexes,
handleGitHubPluginFetched,
handleGitHubPluginFetchError,
getVersionInfo,
handleSelect,
isPluginSelected,
getInstallablePlugins,
} = useInstallMultiState({
allPlugins,
selectedPlugins,
onSelect,
onLoadedAllPlugin,
})
const getVersionInfo = useCallback((pluginId: string) => {
const pluginDetail = installedInfo?.[pluginId]
const hasInstalled = !!pluginDetail
return {
hasInstalled,
installedVersion: pluginDetail?.installedVersion,
toInstallVersion: '',
}
}, [installedInfo])
useEffect(() => {
if (isLoadedAllData && installedInfo)
onLoadedAllPlugin(installedInfo!)
}, [isLoadedAllData, installedInfo])
const handleSelect = useCallback((index: number) => {
return () => {
const canSelectPlugins = plugins.filter((p) => {
const { canInstall } = pluginInstallLimit(p!, systemFeatures)
return canInstall
})
onSelect(plugins[index]!, index, canSelectPlugins.length)
}
}, [onSelect, plugins, systemFeatures])
useImperativeHandle(ref, () => ({
selectAllPlugins: () => {
const selectedIndexes: number[] = []
const selectedPlugins: Plugin[] = []
allPlugins.forEach((d, index) => {
const p = plugins[index]
if (!p)
return
const { canInstall } = pluginInstallLimit(p, systemFeatures)
if (canInstall) {
selectedIndexes.push(index)
selectedPlugins.push(p)
}
})
onSelectAll(selectedPlugins, selectedIndexes)
},
deSelectAllPlugins: () => {
onDeSelectAll()
const { installablePlugins, selectedIndexes } = getInstallablePlugins()
onSelectAll(installablePlugins, selectedIndexes)
},
deSelectAllPlugins: onDeSelectAll,
}))
return (
<>
{allPlugins.map((d, index) => {
if (errorIndexes.includes(index)) {
return (
<LoadingError key={index} />
)
}
if (errorIndexes.includes(index))
return <LoadingError key={index} />
const plugin = plugins[index]
const checked = isPluginSelected(index)
const versionInfo = getVersionInfo(getPluginKey(plugin))
if (d.type === 'github') {
return (
<GithubItem
key={index}
checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
checked={checked}
onCheckedChange={handleSelect(index)}
dependency={d as GitHubItemAndMarketPlaceDependency}
onFetchedPayload={handleGitHubPluginFetched(index)}
onFetchError={handleGitHubPluginFetchError(index)}
versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
versionInfo={versionInfo}
/>
)
}
@@ -246,24 +86,23 @@ const InstallByDSLList = ({
return (
<MarketplaceItem
key={index}
checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
checked={checked}
onCheckedChange={handleSelect(index)}
payload={{ ...plugin, from: d.type } as Plugin}
version={(d as GitHubItemAndMarketPlaceDependency).value.version! || plugin?.version || ''}
versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
versionInfo={versionInfo}
/>
)
}
// Local package
return (
<PackageItem
key={index}
checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
checked={checked}
onCheckedChange={handleSelect(index)}
payload={d as PackageDependency}
isFromMarketPlace={isFromMarketPlace}
versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
versionInfo={versionInfo}
/>
)
})}

View File

@@ -30,19 +30,21 @@ vi.mock('@/context/app-context', () => ({
}))
// Mock API services - only mock external services
const mockFetchWorkflowToolDetailByAppID = vi.fn()
const mockCreateWorkflowToolProvider = vi.fn()
const mockSaveWorkflowToolProvider = vi.fn()
vi.mock('@/service/tools', () => ({
fetchWorkflowToolDetailByAppID: (...args: unknown[]) => mockFetchWorkflowToolDetailByAppID(...args),
createWorkflowToolProvider: (...args: unknown[]) => mockCreateWorkflowToolProvider(...args),
saveWorkflowToolProvider: (...args: unknown[]) => mockSaveWorkflowToolProvider(...args),
}))
// Mock invalidate workflow tools hook
// Mock service hooks
const mockInvalidateAllWorkflowTools = vi.fn()
const mockInvalidateWorkflowToolDetailByAppID = vi.fn()
const mockUseWorkflowToolDetailByAppID = vi.fn()
vi.mock('@/service/use-tools', () => ({
useInvalidateAllWorkflowTools: () => mockInvalidateAllWorkflowTools,
useInvalidateWorkflowToolDetailByAppID: () => mockInvalidateWorkflowToolDetailByAppID,
useWorkflowToolDetailByAppID: (...args: unknown[]) => mockUseWorkflowToolDetailByAppID(...args),
}))
// Mock Toast - need to verify notification calls
@@ -242,7 +244,10 @@ describe('WorkflowToolConfigureButton', () => {
vi.clearAllMocks()
mockPortalOpenState = false
mockIsCurrentWorkspaceManager.mockReturnValue(true)
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(createMockWorkflowToolDetail())
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
data: enabled ? createMockWorkflowToolDetail() : undefined,
isLoading: false,
}))
})
// Rendering Tests (REQUIRED)
@@ -307,19 +312,17 @@ describe('WorkflowToolConfigureButton', () => {
expect(screen.getByText('Please save the workflow first')).toBeInTheDocument()
})
it('should render loading state when published and fetching details', async () => {
it('should render loading state when published and fetching details', () => {
// Arrange
mockFetchWorkflowToolDetailByAppID.mockImplementation(() => new Promise(() => { })) // Never resolves
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: true })
const props = createDefaultConfigureButtonProps({ published: true })
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
const loadingElement = document.querySelector('.pt-2')
expect(loadingElement).toBeInTheDocument()
})
const loadingElement = document.querySelector('.pt-2')
expect(loadingElement).toBeInTheDocument()
})
it('should render configure and manage buttons when published', async () => {
@@ -381,76 +384,10 @@ describe('WorkflowToolConfigureButton', () => {
// Act & Assert
expect(() => render(<WorkflowToolConfigureButton {...props} />)).not.toThrow()
})
it('should call handlePublish when updating workflow tool', async () => {
// Arrange
const user = userEvent.setup()
const handlePublish = vi.fn().mockResolvedValue(undefined)
mockSaveWorkflowToolProvider.mockResolvedValue({})
const props = createDefaultConfigureButtonProps({ published: true, handlePublish })
// Act
render(<WorkflowToolConfigureButton {...props} />)
await waitFor(() => {
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
})
await user.click(screen.getByText('workflow.common.configure'))
// Fill required fields and save
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const saveButton = screen.getByText('common.operation.save')
await user.click(saveButton)
// Confirm in modal
await waitFor(() => {
expect(screen.getByText('tools.createTool.confirmTitle')).toBeInTheDocument()
})
await user.click(screen.getByText('common.operation.confirm'))
// Assert
await waitFor(() => {
expect(handlePublish).toHaveBeenCalled()
})
})
})
// State Management Tests
describe('State Management', () => {
it('should fetch detail when published and mount', async () => {
// Arrange
const props = createDefaultConfigureButtonProps({ published: true })
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledWith('workflow-app-123')
})
})
it('should refetch detail when detailNeedUpdate changes to true', async () => {
// Arrange
const props = createDefaultConfigureButtonProps({ published: true, detailNeedUpdate: false })
// Act
const { rerender } = render(<WorkflowToolConfigureButton {...props} />)
await waitFor(() => {
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledTimes(1)
})
// Rerender with detailNeedUpdate true
rerender(<WorkflowToolConfigureButton {...props} detailNeedUpdate={true} />)
// Assert
await waitFor(() => {
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledTimes(2)
})
})
// Modal behavior tests
describe('Modal Behavior', () => {
it('should toggle modal visibility', async () => {
// Arrange
const user = userEvent.setup()
@@ -513,85 +450,6 @@ describe('WorkflowToolConfigureButton', () => {
})
})
// Memoization Tests
describe('Memoization - outdated detection', () => {
it('should detect outdated when parameter count differs', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({
published: true,
inputs: [
createMockInputVar({ variable: 'test_var' }),
createMockInputVar({ variable: 'extra_var' }),
],
})
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert - should show outdated warning
await waitFor(() => {
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
})
})
it('should detect outdated when parameter not found', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({
published: true,
inputs: [createMockInputVar({ variable: 'different_var' })],
})
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
})
})
it('should detect outdated when required property differs', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({
published: true,
inputs: [createMockInputVar({ variable: 'test_var', required: false })], // Detail has required: true
})
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
})
})
it('should not show outdated when parameters match', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({
published: true,
inputs: [createMockInputVar({ variable: 'test_var', required: true, type: InputVarType.textInput })],
})
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
})
expect(screen.queryByText('workflow.common.workflowAsToolTip')).not.toBeInTheDocument()
})
})
// User Interactions Tests
describe('User Interactions', () => {
it('should navigate to tools page when manage button clicked', async () => {
@@ -611,174 +469,10 @@ describe('WorkflowToolConfigureButton', () => {
// Assert
expect(mockPush).toHaveBeenCalledWith('/tools?category=workflow')
})
it('should create workflow tool provider on first publish', async () => {
// Arrange
const user = userEvent.setup()
mockCreateWorkflowToolProvider.mockResolvedValue({})
const props = createDefaultConfigureButtonProps()
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Open modal
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
await user.click(triggerArea!)
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
// Fill in required name field
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
await user.type(nameInput, 'my_tool')
// Click save
await user.click(screen.getByText('common.operation.save'))
// Assert
await waitFor(() => {
expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
})
})
it('should show success toast after creating workflow tool', async () => {
// Arrange
const user = userEvent.setup()
mockCreateWorkflowToolProvider.mockResolvedValue({})
const props = createDefaultConfigureButtonProps()
// Act
render(<WorkflowToolConfigureButton {...props} />)
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
await user.click(triggerArea!)
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
await user.type(nameInput, 'my_tool')
await user.click(screen.getByText('common.operation.save'))
// Assert
await waitFor(() => {
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'success',
message: 'common.api.actionSuccess',
})
})
})
it('should show error toast when create fails', async () => {
// Arrange
const user = userEvent.setup()
mockCreateWorkflowToolProvider.mockRejectedValue(new Error('Create failed'))
const props = createDefaultConfigureButtonProps()
// Act
render(<WorkflowToolConfigureButton {...props} />)
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
await user.click(triggerArea!)
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
await user.type(nameInput, 'my_tool')
await user.click(screen.getByText('common.operation.save'))
// Assert
await waitFor(() => {
expect(mockToastNotify).toHaveBeenCalledWith({
type: 'error',
message: 'Create failed',
})
})
})
it('should call onRefreshData after successful create', async () => {
// Arrange
const user = userEvent.setup()
const onRefreshData = vi.fn()
mockCreateWorkflowToolProvider.mockResolvedValue({})
const props = createDefaultConfigureButtonProps({ onRefreshData })
// Act
render(<WorkflowToolConfigureButton {...props} />)
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
await user.click(triggerArea!)
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
await user.type(nameInput, 'my_tool')
await user.click(screen.getByText('common.operation.save'))
// Assert
await waitFor(() => {
expect(onRefreshData).toHaveBeenCalled()
})
})
it('should invalidate all workflow tools after successful create', async () => {
// Arrange
const user = userEvent.setup()
mockCreateWorkflowToolProvider.mockResolvedValue({})
const props = createDefaultConfigureButtonProps()
// Act
render(<WorkflowToolConfigureButton {...props} />)
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
await user.click(triggerArea!)
await waitFor(() => {
expect(screen.getByTestId('drawer')).toBeInTheDocument()
})
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
await user.type(nameInput, 'my_tool')
await user.click(screen.getByText('common.operation.save'))
// Assert
await waitFor(() => {
expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
})
})
})
// Edge Cases (REQUIRED)
describe('Edge Cases', () => {
it('should handle API returning undefined', async () => {
// Arrange - API returns undefined (simulating empty response or handled error)
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(undefined)
const props = createDefaultConfigureButtonProps({ published: true })
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert - should not crash and wait for API call
await waitFor(() => {
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalled()
})
// Component should still render without crashing
await waitFor(() => {
expect(screen.getByText('workflow.common.workflowAsTool')).toBeInTheDocument()
})
})
it('should handle rapid publish/unpublish state changes', async () => {
// Arrange
const props = createDefaultConfigureButtonProps({ published: false })
@@ -798,35 +492,7 @@ describe('WorkflowToolConfigureButton', () => {
})
// Assert - should not crash
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalled()
})
it('should handle detail with empty parameters', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
detail.tool.parameters = []
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({ published: true, inputs: [] })
// Act
render(<WorkflowToolConfigureButton {...props} />)
// Assert
await waitFor(() => {
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
})
})
it('should handle detail with undefined output_schema', async () => {
// Arrange
const detail = createMockWorkflowToolDetail()
// @ts-expect-error - testing undefined case
detail.tool.output_schema = undefined
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
const props = createDefaultConfigureButtonProps({ published: true })
// Act & Assert
expect(() => render(<WorkflowToolConfigureButton {...props} />)).not.toThrow()
expect(screen.getByText('workflow.common.workflowAsTool')).toBeInTheDocument()
})
it('should handle paragraph type input conversion', async () => {
@@ -1853,7 +1519,10 @@ describe('Integration Tests', () => {
vi.clearAllMocks()
mockPortalOpenState = false
mockIsCurrentWorkspaceManager.mockReturnValue(true)
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(createMockWorkflowToolDetail())
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
data: enabled ? createMockWorkflowToolDetail() : undefined,
isLoading: false,
}))
})
// Complete workflow: open modal -> fill form -> save

View File

@@ -1,22 +1,16 @@
'use client'
import type { Emoji, WorkflowToolProviderOutputParameter, WorkflowToolProviderParameter, WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
import type { Emoji } from '@/app/components/tools/types'
import type { InputVar, Variable } from '@/app/components/workflow/types'
import type { PublishWorkflowParams } from '@/types/workflow'
import { RiArrowRightUpLine, RiHammerLine } from '@remixicon/react'
import { useRouter } from 'next/navigation'
import * as React from 'react'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Button from '@/app/components/base/button'
import Loading from '@/app/components/base/loading'
import Toast from '@/app/components/base/toast'
import Indicator from '@/app/components/header/indicator'
import WorkflowToolModal from '@/app/components/tools/workflow-tool'
import { useAppContext } from '@/context/app-context'
import { createWorkflowToolProvider, fetchWorkflowToolDetailByAppID, saveWorkflowToolProvider } from '@/service/tools'
import { useInvalidateAllWorkflowTools } from '@/service/use-tools'
import { cn } from '@/utils/classnames'
import Divider from '../../base/divider'
import { useConfigureButton } from './hooks/use-configure-button'
type Props = {
disabled: boolean
@@ -48,153 +42,29 @@ const WorkflowToolConfigureButton = ({
disabledReason,
}: Props) => {
const { t } = useTranslation()
const router = useRouter()
const [showModal, setShowModal] = useState(false)
const [isLoading, setIsLoading] = useState(false)
const [detail, setDetail] = useState<WorkflowToolProviderResponse>()
const { isCurrentWorkspaceManager } = useAppContext()
const invalidateAllWorkflowTools = useInvalidateAllWorkflowTools()
const outdated = useMemo(() => {
if (!detail)
return false
if (detail.tool.parameters.length !== inputs?.length) {
return true
}
else {
for (const item of inputs || []) {
const param = detail.tool.parameters.find(toolParam => toolParam.name === item.variable)
if (!param) {
return true
}
else if (param.required !== item.required) {
return true
}
else {
if (item.type === 'paragraph' && param.type !== 'string')
return true
if (item.type === 'text-input' && param.type !== 'string')
return true
}
}
}
return false
}, [detail, inputs])
const payload = useMemo(() => {
let parameters: WorkflowToolProviderParameter[] = []
let outputParameters: WorkflowToolProviderOutputParameter[] = []
if (!published) {
parameters = (inputs || []).map((item) => {
return {
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}
})
outputParameters = (outputs || []).map((item) => {
return {
name: item.variable,
description: '',
type: item.value_type,
}
})
}
else if (detail && detail.tool) {
parameters = (inputs || []).map((item) => {
return {
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: detail.tool.parameters.find(param => param.name === item.variable)?.llm_description || '',
form: detail.tool.parameters.find(param => param.name === item.variable)?.form || 'llm',
}
})
outputParameters = (outputs || []).map((item) => {
const found = detail.tool.output_schema?.properties?.[item.variable]
return {
name: item.variable,
description: found ? found.description : '',
type: item.value_type,
}
})
}
return {
icon: detail?.icon || icon,
label: detail?.label || name,
name: detail?.name || '',
description: detail?.description || description,
parameters,
outputParameters,
labels: detail?.tool?.labels || [],
privacy_policy: detail?.privacy_policy || '',
...(published
? {
workflow_tool_id: detail?.workflow_tool_id,
}
: {
workflow_app_id: workflowAppId,
}),
}
}, [detail, published, workflowAppId, icon, name, description, inputs])
const getDetail = useCallback(async (workflowAppId: string) => {
setIsLoading(true)
const res = await fetchWorkflowToolDetailByAppID(workflowAppId)
setDetail(res)
setIsLoading(false)
}, [])
useEffect(() => {
if (published)
getDetail(workflowAppId)
}, [getDetail, published, workflowAppId])
useEffect(() => {
if (detailNeedUpdate)
getDetail(workflowAppId)
}, [detailNeedUpdate, getDetail, workflowAppId])
const createHandle = async (data: WorkflowToolProviderRequest & { workflow_app_id: string }) => {
try {
await createWorkflowToolProvider(data)
invalidateAllWorkflowTools()
onRefreshData?.()
getDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
}
}
const updateWorkflowToolProvider = async (data: WorkflowToolProviderRequest & Partial<{
workflow_app_id: string
workflow_tool_id: string
}>) => {
try {
await handlePublish()
await saveWorkflowToolProvider(data)
onRefreshData?.()
invalidateAllWorkflowTools()
getDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
}
}
const {
showModal,
isLoading,
outdated,
payload,
isCurrentWorkspaceManager,
openModal,
closeModal,
handleCreate,
handleUpdate,
navigateToTools,
} = useConfigureButton({
published,
detailNeedUpdate,
workflowAppId,
icon,
name,
description,
inputs,
outputs,
handlePublish,
onRefreshData,
})
return (
<>
@@ -210,17 +80,17 @@ const WorkflowToolConfigureButton = ({
? (
<div
className="flex items-center justify-start gap-2 p-2 pl-2.5"
onClick={() => !disabled && !published && setShowModal(true)}
onClick={() => !disabled && !published && openModal()}
>
<RiHammerLine className={cn('relative h-4 w-4 text-text-secondary', !disabled && !published && 'group-hover:text-text-accent')} />
<div
title={t('common.workflowAsTool', { ns: 'workflow' }) || ''}
className={cn('system-sm-medium shrink grow basis-0 truncate text-text-secondary', !disabled && !published && 'group-hover:text-text-accent')}
className={cn('shrink grow basis-0 truncate text-text-secondary system-sm-medium', !disabled && !published && 'group-hover:text-text-accent')}
>
{t('common.workflowAsTool', { ns: 'workflow' })}
</div>
{!published && (
<span className="system-2xs-medium-uppercase shrink-0 rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-1 py-0.5 text-text-tertiary">
<span className="shrink-0 rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-1 py-0.5 text-text-tertiary system-2xs-medium-uppercase">
{t('common.configureRequired', { ns: 'workflow' })}
</span>
)}
@@ -233,7 +103,7 @@ const WorkflowToolConfigureButton = ({
<RiHammerLine className="h-4 w-4 text-text-tertiary" />
<div
title={t('common.workflowAsTool', { ns: 'workflow' }) || ''}
className="system-sm-medium shrink grow basis-0 truncate text-text-tertiary"
className="shrink grow basis-0 truncate text-text-tertiary system-sm-medium"
>
{t('common.workflowAsTool', { ns: 'workflow' })}
</div>
@@ -250,7 +120,7 @@ const WorkflowToolConfigureButton = ({
<Button
size="small"
className="w-[140px]"
onClick={() => setShowModal(true)}
onClick={openModal}
disabled={!isCurrentWorkspaceManager || disabled}
>
{t('common.configure', { ns: 'workflow' })}
@@ -259,7 +129,7 @@ const WorkflowToolConfigureButton = ({
<Button
size="small"
className="w-[140px]"
onClick={() => router.push('/tools?category=workflow')}
onClick={navigateToTools}
disabled={disabled}
>
{t('common.manageInTools', { ns: 'workflow' })}
@@ -280,9 +150,9 @@ const WorkflowToolConfigureButton = ({
<WorkflowToolModal
isAdd={!published}
payload={payload}
onHide={() => setShowModal(false)}
onCreate={createHandle}
onSave={updateWorkflowToolProvider}
onHide={closeModal}
onCreate={handleCreate}
onSave={handleUpdate}
/>
)}
</>

View File

@@ -0,0 +1,541 @@
import type { WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
import type { InputVar, Variable } from '@/app/components/workflow/types'
import { act, renderHook } from '@testing-library/react'
import { InputVarType } from '@/app/components/workflow/types'
import { isParametersOutdated, useConfigureButton } from '../use-configure-button'
const mockPush = vi.fn()
vi.mock('next/navigation', () => ({
useRouter: () => ({ push: mockPush }),
}))
const mockIsCurrentWorkspaceManager = vi.fn(() => true)
vi.mock('@/context/app-context', () => ({
useAppContext: () => ({
isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager(),
}),
}))
const mockCreateWorkflowToolProvider = vi.fn()
const mockSaveWorkflowToolProvider = vi.fn()
vi.mock('@/service/tools', () => ({
createWorkflowToolProvider: (...args: unknown[]) => mockCreateWorkflowToolProvider(...args),
saveWorkflowToolProvider: (...args: unknown[]) => mockSaveWorkflowToolProvider(...args),
}))
const mockInvalidateAllWorkflowTools = vi.fn()
const mockInvalidateWorkflowToolDetailByAppID = vi.fn()
const mockUseWorkflowToolDetailByAppID = vi.fn()
vi.mock('@/service/use-tools', () => ({
useInvalidateAllWorkflowTools: () => mockInvalidateAllWorkflowTools,
useInvalidateWorkflowToolDetailByAppID: () => mockInvalidateWorkflowToolDetailByAppID,
useWorkflowToolDetailByAppID: (...args: unknown[]) => mockUseWorkflowToolDetailByAppID(...args),
}))
const mockToastNotify = vi.fn()
vi.mock('@/app/components/base/toast', () => ({
default: {
notify: (options: { type: string, message: string }) => mockToastNotify(options),
},
}))
const createMockEmoji = () => ({ content: '🔧', background: '#ffffff' })
const createMockInputVar = (overrides: Partial<InputVar> = {}): InputVar => ({
variable: 'test_var',
label: 'Test Variable',
type: InputVarType.textInput,
required: true,
max_length: 100,
options: [],
...overrides,
} as InputVar)
const createMockVariable = (overrides: Partial<Variable> = {}): Variable => ({
variable: 'output_var',
value_type: 'string',
...overrides,
} as Variable)
const createMockDetail = (overrides: Partial<WorkflowToolProviderResponse> = {}): WorkflowToolProviderResponse => ({
workflow_app_id: 'app-123',
workflow_tool_id: 'tool-456',
label: 'Test Tool',
name: 'test_tool',
icon: createMockEmoji(),
description: 'A test workflow tool',
synced: true,
tool: {
author: 'test-author',
name: 'test_tool',
label: { en_US: 'Test Tool', zh_Hans: '测试工具' },
description: { en_US: 'Test description', zh_Hans: '测试描述' },
labels: ['label1'],
parameters: [
{
name: 'test_var',
label: { en_US: 'Test Variable', zh_Hans: '测试变量' },
human_description: { en_US: 'A test variable', zh_Hans: '测试变量' },
type: 'string',
form: 'llm',
llm_description: 'Test variable description',
required: true,
default: '',
},
],
output_schema: {
type: 'object',
properties: {
output_var: { type: 'string', description: 'Output description' },
},
},
},
privacy_policy: 'https://example.com/privacy',
...overrides,
})
const createDefaultOptions = (overrides = {}) => ({
published: false,
detailNeedUpdate: false,
workflowAppId: 'app-123',
icon: createMockEmoji(),
name: 'Test Workflow',
description: 'Test workflow description',
inputs: [createMockInputVar()],
outputs: [createMockVariable()],
handlePublish: vi.fn().mockResolvedValue(undefined),
onRefreshData: vi.fn(),
...overrides,
})
const createMockRequest = (extra: Record<string, string> = {}): WorkflowToolProviderRequest & Record<string, unknown> => ({
name: 'test_tool',
description: 'desc',
icon: createMockEmoji(),
label: 'Test Tool',
parameters: [{ name: 'test_var', description: '', form: 'llm' }],
labels: [],
privacy_policy: '',
...extra,
})
describe('isParametersOutdated', () => {
it('should return false when detail is undefined', () => {
expect(isParametersOutdated(undefined, [createMockInputVar()])).toBe(false)
})
it('should return true when parameter count differs', () => {
const detail = createMockDetail()
const inputs = [
createMockInputVar({ variable: 'test_var' }),
createMockInputVar({ variable: 'extra_var' }),
]
expect(isParametersOutdated(detail, inputs)).toBe(true)
})
it('should return true when parameter is not found in detail', () => {
const detail = createMockDetail()
const inputs = [createMockInputVar({ variable: 'unknown_var' })]
expect(isParametersOutdated(detail, inputs)).toBe(true)
})
it('should return true when required property differs', () => {
const detail = createMockDetail()
const inputs = [createMockInputVar({ variable: 'test_var', required: false })]
expect(isParametersOutdated(detail, inputs)).toBe(true)
})
it('should return true when paragraph type does not match string', () => {
const detail = createMockDetail()
detail.tool.parameters[0].type = 'number'
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })]
expect(isParametersOutdated(detail, inputs)).toBe(true)
})
it('should return true when text-input type does not match string', () => {
const detail = createMockDetail()
detail.tool.parameters[0].type = 'number'
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.textInput })]
expect(isParametersOutdated(detail, inputs)).toBe(true)
})
it('should return false when paragraph type matches string', () => {
const detail = createMockDetail()
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })]
expect(isParametersOutdated(detail, inputs)).toBe(false)
})
it('should return false when text-input type matches string', () => {
const detail = createMockDetail()
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.textInput })]
expect(isParametersOutdated(detail, inputs)).toBe(false)
})
it('should return false when all parameters match', () => {
const detail = createMockDetail()
const inputs = [createMockInputVar({ variable: 'test_var', required: true })]
expect(isParametersOutdated(detail, inputs)).toBe(false)
})
it('should handle undefined inputs with empty detail parameters', () => {
const detail = createMockDetail()
detail.tool.parameters = []
expect(isParametersOutdated(detail, undefined)).toBe(false)
})
it('should return true when inputs undefined but detail has parameters', () => {
const detail = createMockDetail()
expect(isParametersOutdated(detail, undefined)).toBe(true)
})
it('should handle empty inputs and empty detail parameters', () => {
const detail = createMockDetail()
detail.tool.parameters = []
expect(isParametersOutdated(detail, [])).toBe(false)
})
})
describe('useConfigureButton', () => {
beforeEach(() => {
vi.clearAllMocks()
mockIsCurrentWorkspaceManager.mockReturnValue(true)
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
data: enabled ? createMockDetail() : undefined,
isLoading: false,
}))
})
afterEach(() => {
vi.restoreAllMocks()
})
describe('Initialization', () => {
it('should return showModal as false by default', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
expect(result.current.showModal).toBe(false)
})
it('should forward isCurrentWorkspaceManager from context', () => {
mockIsCurrentWorkspaceManager.mockReturnValue(false)
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
expect(result.current.isCurrentWorkspaceManager).toBe(false)
})
it('should forward isLoading from query hook', () => {
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: true })
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.isLoading).toBe(true)
})
it('should call query hook with enabled=true when published', () => {
renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(mockUseWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123', true)
})
it('should call query hook with enabled=false when not published', () => {
renderHook(() => useConfigureButton(createDefaultOptions({ published: false })))
expect(mockUseWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123', false)
})
})
// Computed values
describe('Computed - outdated', () => {
it('should be false when not published (no detail)', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
expect(result.current.outdated).toBe(false)
})
it('should be true when parameters differ', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
inputs: [
createMockInputVar({ variable: 'test_var' }),
createMockInputVar({ variable: 'extra_var' }),
],
})))
expect(result.current.outdated).toBe(true)
})
it('should be false when parameters match', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
inputs: [createMockInputVar({ variable: 'test_var', required: true })],
})))
expect(result.current.outdated).toBe(false)
})
})
describe('Computed - payload', () => {
it('should use prop values when not published', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
expect(result.current.payload).toMatchObject({
icon: createMockEmoji(),
label: 'Test Workflow',
name: '',
description: 'Test workflow description',
workflow_app_id: 'app-123',
})
expect(result.current.payload.parameters).toHaveLength(1)
expect(result.current.payload.parameters[0]).toMatchObject({
name: 'test_var',
form: 'llm',
description: '',
})
})
it('should use detail values when published with detail', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.payload).toMatchObject({
icon: createMockEmoji(),
label: 'Test Tool',
name: 'test_tool',
description: 'A test workflow tool',
workflow_tool_id: 'tool-456',
privacy_policy: 'https://example.com/privacy',
labels: ['label1'],
})
expect(result.current.payload.parameters[0]).toMatchObject({
name: 'test_var',
description: 'Test variable description',
form: 'llm',
})
})
it('should return empty parameters when published without detail', () => {
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: false })
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.payload.parameters).toHaveLength(0)
expect(result.current.payload.outputParameters).toHaveLength(0)
})
it('should build output parameters from detail output_schema', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.payload.outputParameters).toHaveLength(1)
expect(result.current.payload.outputParameters[0]).toMatchObject({
name: 'output_var',
description: 'Output description',
})
})
it('should handle undefined output_schema in detail', () => {
const detail = createMockDetail()
// @ts-expect-error - testing undefined case
detail.tool.output_schema = undefined
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: detail, isLoading: false })
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.payload.outputParameters[0]).toMatchObject({
name: 'output_var',
description: '',
})
})
it('should convert paragraph type to string in existing parameters', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
inputs: [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })],
})))
expect(result.current.payload.parameters[0].type).toBe('string')
})
})
// Modal controls
describe('Modal Controls', () => {
it('should open modal via openModal', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
act(() => {
result.current.openModal()
})
expect(result.current.showModal).toBe(true)
})
it('should close modal via closeModal', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
act(() => {
result.current.openModal()
})
act(() => {
result.current.closeModal()
})
expect(result.current.showModal).toBe(false)
})
it('should navigate to tools page', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
act(() => {
result.current.navigateToTools()
})
expect(mockPush).toHaveBeenCalledWith('/tools?category=workflow')
})
})
// Mutation handlers
describe('handleCreate', () => {
it('should create provider, invalidate caches, refresh, and close modal', async () => {
mockCreateWorkflowToolProvider.mockResolvedValue({})
const onRefreshData = vi.fn()
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ onRefreshData })))
act(() => {
result.current.openModal()
})
await act(async () => {
await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
})
expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
expect(onRefreshData).toHaveBeenCalled()
expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: expect.any(String) })
expect(result.current.showModal).toBe(false)
})
it('should show error toast on failure', async () => {
mockCreateWorkflowToolProvider.mockRejectedValue(new Error('Create failed'))
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
await act(async () => {
await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
})
expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Create failed' })
})
})
describe('handleUpdate', () => {
it('should publish, save, invalidate caches, and close modal', async () => {
mockSaveWorkflowToolProvider.mockResolvedValue({})
const handlePublish = vi.fn().mockResolvedValue(undefined)
const onRefreshData = vi.fn()
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
handlePublish,
onRefreshData,
})))
act(() => {
result.current.openModal()
})
await act(async () => {
await result.current.handleUpdate(createMockRequest({ workflow_tool_id: 'tool-456' }) as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
})
expect(handlePublish).toHaveBeenCalled()
expect(mockSaveWorkflowToolProvider).toHaveBeenCalled()
expect(onRefreshData).toHaveBeenCalled()
expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: expect.any(String) })
expect(result.current.showModal).toBe(false)
})
it('should show error toast when publish fails', async () => {
const handlePublish = vi.fn().mockRejectedValue(new Error('Publish failed'))
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
handlePublish,
})))
await act(async () => {
await result.current.handleUpdate(createMockRequest() as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
})
expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Publish failed' })
})
it('should show error toast when save fails', async () => {
mockSaveWorkflowToolProvider.mockRejectedValue(new Error('Save failed'))
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
await act(async () => {
await result.current.handleUpdate(createMockRequest() as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
})
expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Save failed' })
})
})
// Effects
describe('Effects', () => {
it('should invalidate detail when detailNeedUpdate becomes true', () => {
const options = createDefaultOptions({ published: true, detailNeedUpdate: false })
const { rerender } = renderHook(
(props: ReturnType<typeof createDefaultOptions>) => useConfigureButton(props),
{ initialProps: options },
)
rerender({ ...options, detailNeedUpdate: true })
expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
})
it('should not invalidate when detailNeedUpdate stays false', () => {
const options = createDefaultOptions({ published: true, detailNeedUpdate: false })
const { rerender } = renderHook(
(props: ReturnType<typeof createDefaultOptions>) => useConfigureButton(props),
{ initialProps: options },
)
rerender({ ...options })
expect(mockInvalidateWorkflowToolDetailByAppID).not.toHaveBeenCalled()
})
})
// Edge cases
describe('Edge Cases', () => {
it('should handle undefined detail from query gracefully', () => {
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: false })
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
expect(result.current.outdated).toBe(false)
expect(result.current.payload.parameters).toHaveLength(0)
})
it('should handle detail with empty parameters', () => {
const detail = createMockDetail()
detail.tool.parameters = []
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: detail, isLoading: false })
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
published: true,
inputs: [],
})))
expect(result.current.outdated).toBe(false)
})
it('should handle undefined inputs and outputs', () => {
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
inputs: undefined,
outputs: undefined,
})))
expect(result.current.payload.parameters).toHaveLength(0)
expect(result.current.payload.outputParameters).toHaveLength(0)
})
it('should handle missing onRefreshData callback in create', async () => {
mockCreateWorkflowToolProvider.mockResolvedValue({})
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
onRefreshData: undefined,
})))
// Should not throw
await act(async () => {
await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
})
expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
})
})
})

View File

@@ -0,0 +1,235 @@
import type { Emoji, WorkflowToolProviderOutputParameter, WorkflowToolProviderParameter, WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
import type { InputVar, Variable } from '@/app/components/workflow/types'
import type { PublishWorkflowParams } from '@/types/workflow'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { useAppContext } from '@/context/app-context'
import { createWorkflowToolProvider, saveWorkflowToolProvider } from '@/service/tools'
import { useInvalidateAllWorkflowTools, useInvalidateWorkflowToolDetailByAppID, useWorkflowToolDetailByAppID } from '@/service/use-tools'
// region Pure helpers
/**
* Check if workflow tool parameters are outdated compared to current inputs.
* Uses flat early-return style to reduce cyclomatic complexity.
*/
export function isParametersOutdated(
detail: WorkflowToolProviderResponse | undefined,
inputs: InputVar[] | undefined,
): boolean {
if (!detail)
return false
if (detail.tool.parameters.length !== (inputs?.length ?? 0))
return true
for (const item of inputs || []) {
const param = detail.tool.parameters.find(p => p.name === item.variable)
if (!param)
return true
if (param.required !== item.required)
return true
const needsStringType = item.type === 'paragraph' || item.type === 'text-input'
if (needsStringType && param.type !== 'string')
return true
}
return false
}
function buildNewParameters(inputs?: InputVar[]): WorkflowToolProviderParameter[] {
return (inputs || []).map(item => ({
name: item.variable,
description: '',
form: 'llm',
required: item.required,
type: item.type,
}))
}
function buildExistingParameters(
inputs: InputVar[] | undefined,
detail: WorkflowToolProviderResponse,
): WorkflowToolProviderParameter[] {
return (inputs || []).map((item) => {
const matched = detail.tool.parameters.find(p => p.name === item.variable)
return {
name: item.variable,
required: item.required,
type: item.type === 'paragraph' ? 'string' : item.type,
description: matched?.llm_description || '',
form: matched?.form || 'llm',
}
})
}
function buildNewOutputParameters(outputs?: Variable[]): WorkflowToolProviderOutputParameter[] {
return (outputs || []).map(item => ({
name: item.variable,
description: '',
type: item.value_type,
}))
}
function buildExistingOutputParameters(
outputs: Variable[] | undefined,
detail: WorkflowToolProviderResponse,
): WorkflowToolProviderOutputParameter[] {
return (outputs || []).map((item) => {
const found = detail.tool.output_schema?.properties?.[item.variable]
return {
name: item.variable,
description: found ? found.description : '',
type: item.value_type,
}
})
}
// endregion
type UseConfigureButtonOptions = {
published: boolean
detailNeedUpdate: boolean
workflowAppId: string
icon: Emoji
name: string
description: string
inputs?: InputVar[]
outputs?: Variable[]
handlePublish: (params?: PublishWorkflowParams) => Promise<void>
onRefreshData?: () => void
}
export function useConfigureButton(options: UseConfigureButtonOptions) {
const {
published,
detailNeedUpdate,
workflowAppId,
icon,
name,
description,
inputs,
outputs,
handlePublish,
onRefreshData,
} = options
const { t } = useTranslation()
const router = useRouter()
const { isCurrentWorkspaceManager } = useAppContext()
const [showModal, setShowModal] = useState(false)
// Data fetching via React Query
const { data: detail, isLoading } = useWorkflowToolDetailByAppID(workflowAppId, published)
// Invalidation functions (store in ref for stable effect dependency)
const invalidateDetail = useInvalidateWorkflowToolDetailByAppID()
const invalidateAllWorkflowTools = useInvalidateAllWorkflowTools()
const invalidateDetailRef = useRef(invalidateDetail)
invalidateDetailRef.current = invalidateDetail
// Refetch when detailNeedUpdate becomes true
useEffect(() => {
if (detailNeedUpdate)
invalidateDetailRef.current(workflowAppId)
}, [detailNeedUpdate, workflowAppId])
// Computed values
const outdated = useMemo(
() => isParametersOutdated(detail, inputs),
[detail, inputs],
)
const payload = useMemo(() => {
const hasPublishedDetail = published && detail?.tool
const parameters = !published
? buildNewParameters(inputs)
: hasPublishedDetail
? buildExistingParameters(inputs, detail)
: []
const outputParameters = !published
? buildNewOutputParameters(outputs)
: hasPublishedDetail
? buildExistingOutputParameters(outputs, detail)
: []
return {
icon: detail?.icon || icon,
label: detail?.label || name,
name: detail?.name || '',
description: detail?.description || description,
parameters,
outputParameters,
labels: detail?.tool?.labels || [],
privacy_policy: detail?.privacy_policy || '',
...(published
? { workflow_tool_id: detail?.workflow_tool_id }
: { workflow_app_id: workflowAppId }),
}
}, [detail, published, workflowAppId, icon, name, description, inputs, outputs])
// Modal controls (stable callbacks)
const openModal = useCallback(() => setShowModal(true), [])
const closeModal = useCallback(() => setShowModal(false), [])
const navigateToTools = useCallback(
() => router.push('/tools?category=workflow'),
[router],
)
// Mutation handlers (not memoized — only used in conditionally-rendered modal)
const handleCreate = async (data: WorkflowToolProviderRequest & { workflow_app_id: string }) => {
try {
await createWorkflowToolProvider(data)
invalidateAllWorkflowTools()
onRefreshData?.()
invalidateDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
}
}
const handleUpdate = async (data: WorkflowToolProviderRequest & Partial<{
workflow_app_id: string
workflow_tool_id: string
}>) => {
try {
await handlePublish()
await saveWorkflowToolProvider(data)
onRefreshData?.()
invalidateAllWorkflowTools()
invalidateDetail(workflowAppId)
Toast.notify({
type: 'success',
message: t('api.actionSuccess', { ns: 'common' }),
})
setShowModal(false)
}
catch (e) {
Toast.notify({ type: 'error', message: (e as Error).message })
}
}
return {
showModal,
isLoading,
outdated,
payload,
isCurrentWorkspaceManager,
openModal,
closeModal,
handleCreate,
handleUpdate,
navigateToTools,
}
}

View File

@@ -784,11 +784,6 @@
"count": 1
}
},
"app/components/app/configuration/dataset-config/card-item/index.spec.tsx": {
"ts/no-explicit-any": {
"count": 1
}
},
"app/components/app/configuration/dataset-config/card-item/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 1
@@ -1231,11 +1226,6 @@
"count": 1
}
},
"app/components/apps/app-card.spec.tsx": {
"ts/no-explicit-any": {
"count": 22
}
},
"app/components/apps/app-card.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 1
@@ -1260,21 +1250,11 @@
"count": 1
}
},
"app/components/apps/list.spec.tsx": {
"ts/no-explicit-any": {
"count": 5
}
},
"app/components/apps/list.tsx": {
"unused-imports/no-unused-vars": {
"count": 1
}
},
"app/components/apps/new-app-card.spec.tsx": {
"ts/no-explicit-any": {
"count": 4
}
},
"app/components/apps/new-app-card.tsx": {
"ts/no-explicit-any": {
"count": 1
@@ -3042,11 +3022,6 @@
"count": 1
}
},
"app/components/custom/custom-web-app-brand/index.spec.tsx": {
"ts/no-explicit-any": {
"count": 7
}
},
"app/components/custom/custom-web-app-brand/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 12
@@ -4073,14 +4048,6 @@
"count": 9
}
},
"app/components/develop/doc.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 2
},
"ts/no-explicit-any": {
"count": 3
}
},
"app/components/develop/md.tsx": {
"ts/no-empty-object-type": {
"count": 1
@@ -4735,14 +4702,6 @@
"count": 1
}
},
"app/components/plugins/install-plugin/install-bundle/steps/install-multi.tsx": {
"react-hooks-extra/no-direct-set-state-in-use-effect": {
"count": 5
},
"ts/no-explicit-any": {
"count": 2
}
},
"app/components/plugins/install-plugin/install-bundle/steps/install.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 2
@@ -5766,11 +5725,6 @@
"count": 4
}
},
"app/components/tools/workflow-tool/configure-button.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 3
}
},
"app/components/tools/workflow-tool/index.tsx": {
"tailwindcss/enforce-consistent-class-order": {
"count": 7
@@ -5807,11 +5761,6 @@
"count": 2
}
},
"app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx": {
"ts/no-explicit-any": {
"count": 1
}
},
"app/components/workflow-app/hooks/use-DSL.ts": {
"ts/no-explicit-any": {
"count": 1

View File

@@ -3,6 +3,7 @@ import type {
Collection,
MCPServerDetail,
Tool,
WorkflowToolProviderResponse,
} from '@/app/components/tools/types'
import type { RAGRecommendedPlugins, ToolWithProvider } from '@/app/components/workflow/types'
import type { AppIconType } from '@/types/app'
@@ -402,3 +403,22 @@ export const useUpdateTriggerStatus = () => {
},
})
}
const workflowToolDetailByAppIDKey = (appId: string) => [NAME_SPACE, 'workflowToolDetailByAppID', appId]
export const useWorkflowToolDetailByAppID = (appId: string, enabled = true) => {
return useQuery<WorkflowToolProviderResponse>({
queryKey: workflowToolDetailByAppIDKey(appId),
queryFn: () => get<WorkflowToolProviderResponse>(`/workspaces/current/tool-provider/workflow/get?workflow_app_id=${appId}`),
enabled: enabled && !!appId,
})
}
export const useInvalidateWorkflowToolDetailByAppID = () => {
const queryClient = useQueryClient()
return (appId: string) => {
queryClient.invalidateQueries({
queryKey: workflowToolDetailByAppIDKey(appId),
})
}
}