Mirror of https://github.com/langgenius/dify.git (synced 2026-02-14 12:43:59 +00:00)

Compare commits: 1.14.0-rc1...feat/evalu (10 commits)

Commits (SHA1; author and date columns were not captured):
1ce0610c4c
b2b0be6b8a
fb4584b776
632d93f475
36dc948520
bad6fb3470
210710e76d
a49504bd5b
3dfc797645
bea428e308
@@ -116,6 +116,12 @@ from .explore import (
    trial,
)

# Import evaluation controllers
from .evaluation import evaluation

# Import snippet controllers
from .snippets import snippet_workflow

# Import tag controllers
from .tag import tags

@@ -129,6 +135,7 @@ from .workspace import (
    model_providers,
    models,
    plugin,
    snippets,
    tool_providers,
    trigger_providers,
    workspace,
@@ -166,6 +173,7 @@ __all__ = [
    "datasource_content_preview",
    "email_register",
    "endpoint",
    "evaluation",
    "extension",
    "external",
    "feature",
@@ -199,6 +207,8 @@ __all__ = [
    "saved_message",
    "setup",
    "site",
    "snippet_workflow",
    "snippets",
    "spec",
    "statistic",
    "tags",
1
api/controllers/console/evaluation/__init__.py
Normal file
@@ -0,0 +1 @@
# Evaluation controller module
320
api/controllers/console/evaluation/evaluation.py
Normal file
@@ -0,0 +1,320 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar, Union
from urllib.parse import quote

from flask import Response, request
from flask_restx import Resource, fields
from pydantic import BaseModel
from sqlalchemy import select
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.wraps import (
    account_initialization_required,
    edit_permission_required,
    setup_required,
)
from core.file import helpers as file_helpers
from extensions.ext_database import db
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models import App
from models.model import UploadFile
from models.snippet import CustomizedSnippet
from services.evaluation_service import EvaluationService

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

# Valid evaluation target types
EVALUATE_TARGET_TYPES = {"app", "snippets"}


class VersionQuery(BaseModel):
    """Query parameters for version endpoint."""

    version: str


register_schema_models(
    console_ns,
    VersionQuery,
)


# Response field definitions
file_info_fields = {
    "id": fields.String,
    "name": fields.String,
}

evaluation_log_fields = {
    "created_at": TimestampField,
    "created_by": fields.String,
    "test_file": fields.Nested(
        console_ns.model(
            "EvaluationTestFile",
            file_info_fields,
        )
    ),
    "result_file": fields.Nested(
        console_ns.model(
            "EvaluationResultFile",
            file_info_fields,
        ),
        allow_null=True,
    ),
    "version": fields.String,
}

evaluation_log_list_model = console_ns.model(
    "EvaluationLogList",
    {
        "data": fields.List(fields.Nested(console_ns.model("EvaluationLog", evaluation_log_fields))),
    },
)

customized_matrix_fields = {
    "evaluation_workflow_id": fields.String,
    "input_fields": fields.Raw,
    "output_fields": fields.Raw,
}

condition_fields = {
    "name": fields.List(fields.String),
    "comparison_operator": fields.String,
    "value": fields.String,
}

judgement_conditions_fields = {
    "logical_operator": fields.String,
    "conditions": fields.List(fields.Nested(console_ns.model("EvaluationCondition", condition_fields))),
}

evaluation_detail_fields = {
    "evaluation_model": fields.String,
    "evaluation_model_provider": fields.String,
    "customized_matrix": fields.Nested(
        console_ns.model("EvaluationCustomizedMatrix", customized_matrix_fields),
        allow_null=True,
    ),
    "judgement_conditions": fields.Nested(
        console_ns.model("EvaluationJudgementConditions", judgement_conditions_fields),
        allow_null=True,
    ),
}

evaluation_detail_model = console_ns.model("EvaluationDetail", evaluation_detail_fields)


def get_evaluation_target(view_func: Callable[P, R]):
    """
    Decorator to resolve polymorphic evaluation target (app or snippet).

    Validates the target_type parameter and fetches the corresponding
    model (App or CustomizedSnippet) with tenant isolation.
    """

    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        target_type = kwargs.get("evaluate_target_type")
        target_id = kwargs.get("evaluate_target_id")

        if target_type not in EVALUATE_TARGET_TYPES:
            raise NotFound(f"Invalid evaluation target type: {target_type}")

        _, current_tenant_id = current_account_with_tenant()

        target_id = str(target_id)

        # Remove path parameters
        del kwargs["evaluate_target_type"]
        del kwargs["evaluate_target_id"]

        target: Union[App, CustomizedSnippet] | None = None

        if target_type == "app":
            target = (
                db.session.query(App).where(App.id == target_id, App.tenant_id == current_tenant_id).first()
            )
        elif target_type == "snippets":
            target = (
                db.session.query(CustomizedSnippet)
                .where(CustomizedSnippet.id == target_id, CustomizedSnippet.tenant_id == current_tenant_id)
                .first()
            )

        if not target:
            raise NotFound(f"{str(target_type)} not found")

        kwargs["target"] = target
        kwargs["target_type"] = target_type

        return view_func(*args, **kwargs)

    return decorated_view


@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/dataset-template/download")
class EvaluationDatasetTemplateDownloadApi(Resource):
    @console_ns.doc("download_evaluation_dataset_template")
    @console_ns.response(200, "Template file streamed as XLSX attachment")
    @console_ns.response(400, "Invalid target type or excluded app mode")
    @console_ns.response(404, "Target not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    @edit_permission_required
    def post(self, target: Union[App, CustomizedSnippet], target_type: str):
        """
        Download evaluation dataset template.

        Generates an XLSX template based on the target's input parameters
        and streams it directly as a file attachment.
        """
        try:
            xlsx_content, filename = EvaluationService.generate_dataset_template(
                target=target,
                target_type=target_type,
            )
        except ValueError as e:
            return {"message": str(e)}, 400

        encoded_filename = quote(filename)
        response = Response(
            xlsx_content,
            mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        )
        response.headers["Content-Disposition"] = f"attachment; filename*=UTF-8''{encoded_filename}"
        response.headers["Content-Length"] = str(len(xlsx_content))
        return response


@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation")
class EvaluationDetailApi(Resource):
    @console_ns.doc("get_evaluation_detail")
    @console_ns.response(200, "Evaluation details retrieved successfully", evaluation_detail_model)
    @console_ns.response(404, "Target not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    def get(self, target: Union[App, CustomizedSnippet], target_type: str):
        """
        Get evaluation details for the target.

        Returns evaluation configuration including model settings,
        customized matrix, and judgement conditions.
        """
        # TODO: Implement actual evaluation detail retrieval
        # This is a placeholder implementation
        return {
            "evaluation_model": None,
            "evaluation_model_provider": None,
            "customized_matrix": None,
            "judgement_conditions": None,
        }


@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/logs")
class EvaluationLogsApi(Resource):
    @console_ns.doc("get_evaluation_logs")
    @console_ns.response(200, "Evaluation logs retrieved successfully", evaluation_log_list_model)
    @console_ns.response(404, "Target not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    def get(self, target: Union[App, CustomizedSnippet], target_type: str):
        """
        Get offline evaluation logs for the target.

        Returns a list of evaluation runs with test files,
        result files, and version information.
        """
        # TODO: Implement actual evaluation logs retrieval
        # This is a placeholder implementation
        return {
            "data": [],
        }


@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/files/<uuid:file_id>")
class EvaluationFileDownloadApi(Resource):
    @console_ns.doc("download_evaluation_file")
    @console_ns.response(200, "File download URL generated successfully")
    @console_ns.response(404, "Target or file not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    def get(self, target: Union[App, CustomizedSnippet], target_type: str, file_id: str):
        """
        Download evaluation test file or result file.

        Looks up the specified file, verifies it belongs to the same tenant,
        and returns file info and download URL.
        """
        file_id = str(file_id)
        _, current_tenant_id = current_account_with_tenant()

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(UploadFile).where(
                UploadFile.id == file_id,
                UploadFile.tenant_id == current_tenant_id,
            )
            upload_file = session.execute(stmt).scalar_one_or_none()

        if not upload_file:
            raise NotFound("File not found")

        download_url = file_helpers.get_signed_file_url(upload_file_id=upload_file.id, as_attachment=True)

        return {
            "id": upload_file.id,
            "name": upload_file.name,
            "size": upload_file.size,
            "extension": upload_file.extension,
            "mime_type": upload_file.mime_type,
            "created_at": int(upload_file.created_at.timestamp()) if upload_file.created_at else None,
            "download_url": download_url,
        }


@console_ns.route("/<string:evaluate_target_type>/<uuid:evaluate_target_id>/evaluation/version")
class EvaluationVersionApi(Resource):
    @console_ns.doc("get_evaluation_version_detail")
    @console_ns.expect(console_ns.models.get(VersionQuery.__name__))
    @console_ns.response(200, "Version details retrieved successfully")
    @console_ns.response(404, "Target or version not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_evaluation_target
    def get(self, target: Union[App, CustomizedSnippet], target_type: str):
        """
        Get evaluation target version details.

        Returns the workflow graph for the specified version.
        """
        version = request.args.get("version")

        if not version:
            return {"message": "version parameter is required"}, 400

        # TODO: Implement actual version detail retrieval
        # For now, return the current graph if available
        graph = {}
        if target_type == "snippets" and isinstance(target, CustomizedSnippet):
            graph = target.graph_dict

        return {
            "graph": graph,
        }
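The template download endpoint above percent-encodes the filename with urllib.parse.quote before placing it in the Content-Disposition header. A minimal sketch of that encoding outside of Flask (the sample filename is made up for illustration; it is not part of the commit):

from urllib.parse import quote

filename = "销售数据-evaluation-dataset.xlsx"  # hypothetical filename with non-ASCII characters
# quote() percent-encodes the UTF-8 bytes, producing an RFC 5987 filename* value
header_value = f"attachment; filename*=UTF-8''{quote(filename)}"
print(header_value)
# attachment; filename*=UTF-8''%E9%94%80%E5%94%AE%E6%95%B0%E6%8D%AE-evaluation-dataset.xlsx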
102
api/controllers/console/snippets/payloads.py
Normal file
@@ -0,0 +1,102 @@
from typing import Any, Literal

from pydantic import BaseModel, Field


class SnippetListQuery(BaseModel):
    """Query parameters for listing snippets."""

    page: int = Field(default=1, ge=1, le=99999)
    limit: int = Field(default=20, ge=1, le=100)
    keyword: str | None = None


class IconInfo(BaseModel):
    """Icon information model."""

    icon: str | None = None
    icon_type: Literal["emoji", "image"] | None = None
    icon_background: str | None = None
    icon_url: str | None = None


class InputFieldDefinition(BaseModel):
    """Input field definition for snippet parameters."""

    default: str | None = None
    hint: bool | None = None
    label: str | None = None
    max_length: int | None = None
    options: list[str] | None = None
    placeholder: str | None = None
    required: bool | None = None
    type: str | None = None  # e.g., "text-input"


class CreateSnippetPayload(BaseModel):
    """Payload for creating a new snippet."""

    name: str = Field(..., min_length=1, max_length=255)
    description: str | None = Field(default=None, max_length=2000)
    type: Literal["node", "group"] = "node"
    icon_info: IconInfo | None = None
    graph: dict[str, Any] | None = None
    input_fields: list[InputFieldDefinition] | None = Field(default_factory=list)


class UpdateSnippetPayload(BaseModel):
    """Payload for updating a snippet."""

    name: str | None = Field(default=None, min_length=1, max_length=255)
    description: str | None = Field(default=None, max_length=2000)
    icon_info: IconInfo | None = None


class SnippetDraftSyncPayload(BaseModel):
    """Payload for syncing snippet draft workflow."""

    graph: dict[str, Any]
    hash: str | None = None
    environment_variables: list[dict[str, Any]] | None = None
    conversation_variables: list[dict[str, Any]] | None = None
    input_variables: list[dict[str, Any]] | None = None


class WorkflowRunQuery(BaseModel):
    """Query parameters for workflow runs."""

    last_id: str | None = None
    limit: int = Field(default=20, ge=1, le=100)


class SnippetDraftRunPayload(BaseModel):
    """Payload for running snippet draft workflow."""

    inputs: dict[str, Any]
    files: list[dict[str, Any]] | None = None


class SnippetDraftNodeRunPayload(BaseModel):
    """Payload for running a single node in snippet draft workflow."""

    inputs: dict[str, Any]
    query: str = ""
    files: list[dict[str, Any]] | None = None


class SnippetIterationNodeRunPayload(BaseModel):
    """Payload for running an iteration node in snippet draft workflow."""

    inputs: dict[str, Any] | None = None


class SnippetLoopNodeRunPayload(BaseModel):
    """Payload for running a loop node in snippet draft workflow."""

    inputs: dict[str, Any] | None = None


class PublishWorkflowPayload(BaseModel):
    """Payload for publishing snippet workflow."""

    knowledge_base_setting: dict[str, Any] | None = None
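As a quick, hedged illustration of how these Pydantic payloads behave when validated (a sketch only, not part of the commit; it assumes the repository is on the import path):

from pydantic import ValidationError

from controllers.console.snippets.payloads import CreateSnippetPayload

# Defaults kick in for optional fields: type falls back to "node", input_fields to [].
payload = CreateSnippetPayload.model_validate({"name": "Summarize", "graph": {"nodes": [], "edges": []}})
assert payload.type == "node"
assert payload.input_fields == []

# Constraint violations (here min_length=1 on name) surface as a ValidationError.
try:
    CreateSnippetPayload.model_validate({"name": ""})
except ValidationError as exc:
    print(exc.error_count(), "validation error")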
540
api/controllers/console/snippets/snippet_workflow.py
Normal file
@@ -0,0 +1,540 @@
import logging
from collections.abc import Callable
from functools import wraps
from typing import ParamSpec, TypeVar

from flask import request
from flask_restx import Resource, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import InternalServerError, NotFound

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.app.error import DraftWorkflowNotExist, DraftWorkflowNotSync
from controllers.console.app.workflow import workflow_model
from controllers.console.app.workflow_run import (
    workflow_run_detail_model,
    workflow_run_node_execution_list_model,
    workflow_run_node_execution_model,
    workflow_run_pagination_model,
)
from controllers.console.snippets.payloads import (
    PublishWorkflowPayload,
    SnippetDraftNodeRunPayload,
    SnippetDraftRunPayload,
    SnippetDraftSyncPayload,
    SnippetIterationNodeRunPayload,
    SnippetLoopNodeRunPayload,
    WorkflowRunQuery,
)
from controllers.console.wraps import (
    account_initialization_required,
    edit_permission_required,
    setup_required,
)
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.entities.app_invoke_entities import InvokeFrom
from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from factories import variable_factory
from libs import helper
from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.snippet import CustomizedSnippet
from services.errors.app import WorkflowHashNotEqualError
from services.snippet_generate_service import SnippetGenerateService
from services.snippet_service import SnippetService

logger = logging.getLogger(__name__)

P = ParamSpec("P")
R = TypeVar("R")

# Register Pydantic models with Swagger
register_schema_models(
    console_ns,
    SnippetDraftSyncPayload,
    SnippetDraftNodeRunPayload,
    SnippetDraftRunPayload,
    SnippetIterationNodeRunPayload,
    SnippetLoopNodeRunPayload,
    WorkflowRunQuery,
    PublishWorkflowPayload,
)


class SnippetNotFoundError(Exception):
    """Snippet not found error."""

    pass


def get_snippet(view_func: Callable[P, R]):
    """Decorator to fetch and validate snippet access."""

    @wraps(view_func)
    def decorated_view(*args: P.args, **kwargs: P.kwargs):
        if not kwargs.get("snippet_id"):
            raise ValueError("missing snippet_id in path parameters")

        _, current_tenant_id = current_account_with_tenant()

        snippet_id = str(kwargs.get("snippet_id"))
        del kwargs["snippet_id"]

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=snippet_id,
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        kwargs["snippet"] = snippet

        return view_func(*args, **kwargs)

    return decorated_view


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft")
class SnippetDraftWorkflowApi(Resource):
    @console_ns.doc("get_snippet_draft_workflow")
    @console_ns.response(200, "Draft workflow retrieved successfully", workflow_model)
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    @marshal_with(workflow_model)
    def get(self, snippet: CustomizedSnippet):
        """Get draft workflow for snippet."""
        snippet_service = SnippetService()
        workflow = snippet_service.get_draft_workflow(snippet=snippet)

        if not workflow:
            raise DraftWorkflowNotExist()

        return workflow

    @console_ns.doc("sync_snippet_draft_workflow")
    @console_ns.expect(console_ns.models.get(SnippetDraftSyncPayload.__name__))
    @console_ns.response(200, "Draft workflow synced successfully")
    @console_ns.response(400, "Hash mismatch")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """Sync draft workflow for snippet."""
        current_user, _ = current_account_with_tenant()

        payload = SnippetDraftSyncPayload.model_validate(console_ns.payload or {})

        try:
            environment_variables_list = payload.environment_variables or []
            environment_variables = [
                variable_factory.build_environment_variable_from_mapping(obj) for obj in environment_variables_list
            ]
            conversation_variables_list = payload.conversation_variables or []
            conversation_variables = [
                variable_factory.build_conversation_variable_from_mapping(obj) for obj in conversation_variables_list
            ]
            snippet_service = SnippetService()
            workflow = snippet_service.sync_draft_workflow(
                snippet=snippet,
                graph=payload.graph,
                unique_hash=payload.hash,
                account=current_user,
                environment_variables=environment_variables,
                conversation_variables=conversation_variables,
                input_variables=payload.input_variables,
            )
        except WorkflowHashNotEqualError:
            raise DraftWorkflowNotSync()

        return {
            "result": "success",
            "hash": workflow.unique_hash,
            "updated_at": TimestampField().format(workflow.updated_at or workflow.created_at),
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/config")
class SnippetDraftConfigApi(Resource):
    @console_ns.doc("get_snippet_draft_config")
    @console_ns.response(200, "Draft config retrieved successfully")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def get(self, snippet: CustomizedSnippet):
        """Get snippet draft workflow configuration limits."""
        return {
            "parallel_depth_limit": 3,
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/publish")
class SnippetPublishedWorkflowApi(Resource):
    @console_ns.doc("get_snippet_published_workflow")
    @console_ns.response(200, "Published workflow retrieved successfully", workflow_model)
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    @marshal_with(workflow_model)
    def get(self, snippet: CustomizedSnippet):
        """Get published workflow for snippet."""
        if not snippet.is_published:
            return None

        snippet_service = SnippetService()
        workflow = snippet_service.get_published_workflow(snippet=snippet)

        return workflow

    @console_ns.doc("publish_snippet_workflow")
    @console_ns.expect(console_ns.models.get(PublishWorkflowPayload.__name__))
    @console_ns.response(200, "Workflow published successfully")
    @console_ns.response(400, "No draft workflow found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """Publish snippet workflow."""
        current_user, _ = current_account_with_tenant()
        snippet_service = SnippetService()

        with Session(db.engine) as session:
            snippet = session.merge(snippet)
            try:
                workflow = snippet_service.publish_workflow(
                    session=session,
                    snippet=snippet,
                    account=current_user,
                )
                workflow_created_at = TimestampField().format(workflow.created_at)
                session.commit()
            except ValueError as e:
                return {"message": str(e)}, 400

        return {
            "result": "success",
            "created_at": workflow_created_at,
        }


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/default-workflow-block-configs")
class SnippetDefaultBlockConfigsApi(Resource):
    @console_ns.doc("get_snippet_default_block_configs")
    @console_ns.response(200, "Default block configs retrieved successfully")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def get(self, snippet: CustomizedSnippet):
        """Get default block configurations for snippet workflow."""
        snippet_service = SnippetService()
        return snippet_service.get_default_block_configs()


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs")
class SnippetWorkflowRunsApi(Resource):
    @console_ns.doc("list_snippet_workflow_runs")
    @console_ns.response(200, "Workflow runs retrieved successfully", workflow_run_pagination_model)
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_pagination_model)
    def get(self, snippet: CustomizedSnippet):
        """List workflow runs for snippet."""
        query = WorkflowRunQuery.model_validate(
            {
                "last_id": request.args.get("last_id"),
                "limit": request.args.get("limit", type=int, default=20),
            }
        )
        args = {
            "last_id": query.last_id,
            "limit": query.limit,
        }

        snippet_service = SnippetService()
        result = snippet_service.get_snippet_workflow_runs(snippet=snippet, args=args)

        return result


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>")
class SnippetWorkflowRunDetailApi(Resource):
    @console_ns.doc("get_snippet_workflow_run_detail")
    @console_ns.response(200, "Workflow run detail retrieved successfully", workflow_run_detail_model)
    @console_ns.response(404, "Workflow run not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_detail_model)
    def get(self, snippet: CustomizedSnippet, run_id):
        """Get workflow run detail for snippet."""
        run_id = str(run_id)

        snippet_service = SnippetService()
        workflow_run = snippet_service.get_snippet_workflow_run(snippet=snippet, run_id=run_id)

        if not workflow_run:
            raise NotFound("Workflow run not found")

        return workflow_run


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/<uuid:run_id>/node-executions")
class SnippetWorkflowRunNodeExecutionsApi(Resource):
    @console_ns.doc("list_snippet_workflow_run_node_executions")
    @console_ns.response(200, "Node executions retrieved successfully", workflow_run_node_execution_list_model)
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_list_model)
    def get(self, snippet: CustomizedSnippet, run_id):
        """List node executions for a workflow run."""
        run_id = str(run_id)

        snippet_service = SnippetService()
        node_executions = snippet_service.get_snippet_workflow_run_node_executions(
            snippet=snippet,
            run_id=run_id,
        )

        return {"data": node_executions}


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/run")
class SnippetDraftNodeRunApi(Resource):
    @console_ns.doc("run_snippet_draft_node")
    @console_ns.doc(description="Run a single node in snippet draft workflow (single-step debugging)")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetDraftNodeRunPayload.__name__))
    @console_ns.response(200, "Node run completed successfully", workflow_run_node_execution_model)
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_model)
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a single node in snippet draft workflow.

        Executes a specific node with provided inputs for single-step debugging.
        Returns the node execution result including status, outputs, and timing.
        """
        current_user, _ = current_account_with_tenant()
        payload = SnippetDraftNodeRunPayload.model_validate(console_ns.payload or {})

        user_inputs = payload.inputs

        # Get draft workflow for file parsing
        snippet_service = SnippetService()
        draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not draft_workflow:
            raise NotFound("Draft workflow not found")

        files = SnippetGenerateService.parse_files(draft_workflow, payload.files)

        workflow_node_execution = SnippetGenerateService.run_draft_node(
            snippet=snippet,
            node_id=node_id,
            user_inputs=user_inputs,
            account=current_user,
            query=payload.query,
            files=files,
        )

        return workflow_node_execution


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/nodes/<string:node_id>/last-run")
class SnippetDraftNodeLastRunApi(Resource):
    @console_ns.doc("get_snippet_draft_node_last_run")
    @console_ns.doc(description="Get last run result for a node in snippet draft workflow")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.response(200, "Node last run retrieved successfully", workflow_run_node_execution_model)
    @console_ns.response(404, "Snippet, draft workflow, or node last run not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @marshal_with(workflow_run_node_execution_model)
    def get(self, snippet: CustomizedSnippet, node_id: str):
        """
        Get the last run result for a specific node in snippet draft workflow.

        Returns the most recent execution record for the given node,
        including status, inputs, outputs, and timing information.
        """
        snippet_service = SnippetService()
        draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
        if not draft_workflow:
            raise NotFound("Draft workflow not found")

        node_exec = snippet_service.get_snippet_node_last_run(
            snippet=snippet,
            workflow=draft_workflow,
            node_id=node_id,
        )
        if node_exec is None:
            raise NotFound("Node last run not found")

        return node_exec


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/iteration/nodes/<string:node_id>/run")
class SnippetDraftRunIterationNodeApi(Resource):
    @console_ns.doc("run_snippet_draft_iteration_node")
    @console_ns.doc(description="Run draft workflow iteration node for snippet")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetIterationNodeRunPayload.__name__))
    @console_ns.response(200, "Iteration node run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a draft workflow iteration node for snippet.

        Iteration nodes execute their internal sub-graph multiple times over an input list.
        Returns an SSE event stream with iteration progress and results.
        """
        current_user, _ = current_account_with_tenant()
        args = SnippetIterationNodeRunPayload.model_validate(console_ns.payload or {}).model_dump(exclude_none=True)

        try:
            response = SnippetGenerateService.generate_single_iteration(
                snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/loop/nodes/<string:node_id>/run")
class SnippetDraftRunLoopNodeApi(Resource):
    @console_ns.doc("run_snippet_draft_loop_node")
    @console_ns.doc(description="Run draft workflow loop node for snippet")
    @console_ns.doc(params={"snippet_id": "Snippet ID", "node_id": "Node ID"})
    @console_ns.expect(console_ns.models.get(SnippetLoopNodeRunPayload.__name__))
    @console_ns.response(200, "Loop node run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, node_id: str):
        """
        Run a draft workflow loop node for snippet.

        Loop nodes execute their internal sub-graph repeatedly until a condition is met.
        Returns an SSE event stream with loop progress and results.
        """
        current_user, _ = current_account_with_tenant()
        args = SnippetLoopNodeRunPayload.model_validate(console_ns.payload or {})

        try:
            response = SnippetGenerateService.generate_single_loop(
                snippet=snippet, user=current_user, node_id=node_id, args=args, streaming=True
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflows/draft/run")
class SnippetDraftWorkflowRunApi(Resource):
    @console_ns.doc("run_snippet_draft_workflow")
    @console_ns.expect(console_ns.models.get(SnippetDraftRunPayload.__name__))
    @console_ns.response(200, "Draft workflow run started successfully (SSE stream)")
    @console_ns.response(404, "Snippet or draft workflow not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet):
        """
        Run draft workflow for snippet.

        Executes the snippet's draft workflow with the provided inputs
        and returns an SSE event stream with execution progress and results.
        """
        current_user, _ = current_account_with_tenant()

        payload = SnippetDraftRunPayload.model_validate(console_ns.payload or {})
        args = payload.model_dump(exclude_none=True)

        try:
            response = SnippetGenerateService.generate(
                snippet=snippet,
                user=current_user,
                args=args,
                invoke_from=InvokeFrom.DEBUGGER,
                streaming=True,
            )

            return helper.compact_generate_response(response)
        except ValueError as e:
            raise e
        except Exception:
            logger.exception("internal server error.")
            raise InternalServerError()


@console_ns.route("/snippets/<uuid:snippet_id>/workflow-runs/tasks/<string:task_id>/stop")
class SnippetWorkflowTaskStopApi(Resource):
    @console_ns.doc("stop_snippet_workflow_task")
    @console_ns.response(200, "Task stopped successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @get_snippet
    @edit_permission_required
    def post(self, snippet: CustomizedSnippet, task_id: str):
        """
        Stop a running snippet workflow task.

        Uses both the legacy stop flag mechanism and the graph engine
        command channel for backward compatibility.
        """
        # Stop using both mechanisms for backward compatibility
        # Legacy stop flag mechanism (without user check)
        AppQueueManager.set_stop_flag_no_user_check(task_id)

        # New graph engine command channel mechanism
        GraphEngineManager.send_stop_command(task_id)

        return {"result": "success"}
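For orientation only, a hedged sketch of how a client might consume the SSE stream returned by the draft-run endpoint above. The http://localhost:5001 host, the /console/api prefix, and the bearer token are assumptions about a local deployment, not part of this diff:

import requests

snippet_id = "00000000-0000-0000-0000-000000000000"  # placeholder UUID
url = f"http://localhost:5001/console/api/snippets/{snippet_id}/workflows/draft/run"  # assumed URL prefix

with requests.post(
    url,
    json={"inputs": {"query": "hello"}},
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder credential
    stream=True,
) as resp:
    for line in resp.iter_lines():
        # SSE events arrive as "data: {...}" lines
        if line.startswith(b"data: "):
            print(line[len(b"data: "):].decode("utf-8"))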
202
api/controllers/console/workspace/snippets.py
Normal file
@@ -0,0 +1,202 @@
import logging

from flask import request
from flask_restx import Resource, marshal, marshal_with
from sqlalchemy.orm import Session
from werkzeug.exceptions import NotFound

from controllers.common.schema import register_schema_models
from controllers.console import console_ns
from controllers.console.snippets.payloads import (
    CreateSnippetPayload,
    SnippetListQuery,
    UpdateSnippetPayload,
)
from controllers.console.wraps import (
    account_initialization_required,
    edit_permission_required,
    setup_required,
)
from extensions.ext_database import db
from fields.snippet_fields import snippet_fields, snippet_list_fields, snippet_pagination_fields
from libs.login import current_account_with_tenant, login_required
from models.snippet import SnippetType
from services.snippet_service import SnippetService

logger = logging.getLogger(__name__)

# Register Pydantic models with Swagger
register_schema_models(
    console_ns,
    SnippetListQuery,
    CreateSnippetPayload,
    UpdateSnippetPayload,
)

# Create namespace models for marshaling
snippet_model = console_ns.model("Snippet", snippet_fields)
snippet_list_model = console_ns.model("SnippetList", snippet_list_fields)
snippet_pagination_model = console_ns.model("SnippetPagination", snippet_pagination_fields)


@console_ns.route("/workspaces/current/customized-snippets")
class CustomizedSnippetsApi(Resource):
    @console_ns.doc("list_customized_snippets")
    @console_ns.expect(console_ns.models.get(SnippetListQuery.__name__))
    @console_ns.response(200, "Snippets retrieved successfully", snippet_pagination_model)
    @setup_required
    @login_required
    @account_initialization_required
    def get(self):
        """List customized snippets with pagination and search."""
        _, current_tenant_id = current_account_with_tenant()

        query_params = request.args.to_dict()
        query = SnippetListQuery.model_validate(query_params)

        snippets, total, has_more = SnippetService.get_snippets(
            tenant_id=current_tenant_id,
            page=query.page,
            limit=query.limit,
            keyword=query.keyword,
        )

        return {
            "data": marshal(snippets, snippet_list_fields),
            "page": query.page,
            "limit": query.limit,
            "total": total,
            "has_more": has_more,
        }, 200

    @console_ns.doc("create_customized_snippet")
    @console_ns.expect(console_ns.models.get(CreateSnippetPayload.__name__))
    @console_ns.response(201, "Snippet created successfully", snippet_model)
    @console_ns.response(400, "Invalid request or name already exists")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def post(self):
        """Create a new customized snippet."""
        current_user, current_tenant_id = current_account_with_tenant()

        payload = CreateSnippetPayload.model_validate(console_ns.payload or {})

        try:
            snippet_type = SnippetType(payload.type)
        except ValueError:
            snippet_type = SnippetType.NODE

        try:
            snippet = SnippetService.create_snippet(
                tenant_id=current_tenant_id,
                name=payload.name,
                description=payload.description,
                snippet_type=snippet_type,
                icon_info=payload.icon_info.model_dump() if payload.icon_info else None,
                graph=payload.graph,
                input_fields=[f.model_dump() for f in payload.input_fields] if payload.input_fields else None,
                account=current_user,
            )
        except ValueError as e:
            return {"message": str(e)}, 400

        return marshal(snippet, snippet_fields), 201


@console_ns.route("/workspaces/current/customized-snippets/<uuid:snippet_id>")
class CustomizedSnippetDetailApi(Resource):
    @console_ns.doc("get_customized_snippet")
    @console_ns.response(200, "Snippet retrieved successfully", snippet_model)
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    def get(self, snippet_id: str):
        """Get customized snippet details."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        return marshal(snippet, snippet_fields), 200

    @console_ns.doc("update_customized_snippet")
    @console_ns.expect(console_ns.models.get(UpdateSnippetPayload.__name__))
    @console_ns.response(200, "Snippet updated successfully", snippet_model)
    @console_ns.response(400, "Invalid request or name already exists")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def patch(self, snippet_id: str):
        """Update customized snippet."""
        current_user, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        payload = UpdateSnippetPayload.model_validate(console_ns.payload or {})
        update_data = payload.model_dump(exclude_unset=True)

        if "icon_info" in update_data and update_data["icon_info"] is not None:
            update_data["icon_info"] = payload.icon_info.model_dump() if payload.icon_info else None

        if not update_data:
            return {"message": "No valid fields to update"}, 400

        try:
            with Session(db.engine, expire_on_commit=False) as session:
                snippet = session.merge(snippet)
                snippet = SnippetService.update_snippet(
                    session=session,
                    snippet=snippet,
                    account_id=current_user.id,
                    data=update_data,
                )
                session.commit()
        except ValueError as e:
            return {"message": str(e)}, 400

        return marshal(snippet, snippet_fields), 200

    @console_ns.doc("delete_customized_snippet")
    @console_ns.response(204, "Snippet deleted successfully")
    @console_ns.response(404, "Snippet not found")
    @setup_required
    @login_required
    @account_initialization_required
    @edit_permission_required
    def delete(self, snippet_id: str):
        """Delete customized snippet."""
        _, current_tenant_id = current_account_with_tenant()

        snippet = SnippetService.get_snippet_by_id(
            snippet_id=str(snippet_id),
            tenant_id=current_tenant_id,
        )

        if not snippet:
            raise NotFound("Snippet not found")

        with Session(db.engine) as session:
            snippet = session.merge(snippet)
            SnippetService.delete_snippet(
                session=session,
                snippet=snippet,
            )
            session.commit()

        return "", 204
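A similarly hedged sketch of calling the list endpoint defined above, driven by the SnippetListQuery parameters; the host, /console/api prefix, and token are again assumptions about a local deployment rather than part of this commit:

import requests

resp = requests.get(
    "http://localhost:5001/console/api/workspaces/current/customized-snippets",  # assumed URL prefix
    params={"page": 1, "limit": 20, "keyword": "summary"},
    headers={"Authorization": "Bearer <console-access-token>"},  # placeholder credential
)
body = resp.json()
print(resp.status_code, body.get("total"), body.get("has_more"))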
45
api/fields/snippet_fields.py
Normal file
@@ -0,0 +1,45 @@
from flask_restx import fields

from fields.member_fields import simple_account_fields
from libs.helper import TimestampField

# Snippet list item fields (lightweight for list display)
snippet_list_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "type": fields.String,
    "version": fields.Integer,
    "use_count": fields.Integer,
    "is_published": fields.Boolean,
    "icon_info": fields.Raw,
    "created_at": TimestampField,
    "updated_at": TimestampField,
}

# Full snippet fields (includes creator info and graph data)
snippet_fields = {
    "id": fields.String,
    "name": fields.String,
    "description": fields.String,
    "type": fields.String,
    "version": fields.Integer,
    "use_count": fields.Integer,
    "is_published": fields.Boolean,
    "icon_info": fields.Raw,
    "graph": fields.Raw(attribute="graph_dict"),
    "input_fields": fields.Raw(attribute="input_fields_list"),
    "created_by": fields.Nested(simple_account_fields, attribute="created_by_account", allow_null=True),
    "created_at": TimestampField,
    "updated_by": fields.Nested(simple_account_fields, attribute="updated_by_account", allow_null=True),
    "updated_at": TimestampField,
}

# Pagination response fields
snippet_pagination_fields = {
    "data": fields.List(fields.Nested(snippet_list_fields)),
    "page": fields.Integer,
    "limit": fields.Integer,
    "total": fields.Integer,
    "has_more": fields.Boolean,
}
@@ -0,0 +1,83 @@
"""add_customized_snippets_table

Revision ID: 1c05e80d2380
Revises: 788d3099ae3a
Create Date: 2026-01-29 12:00:00.000000

"""

import sqlalchemy as sa
from alembic import op
from sqlalchemy.dialects import postgresql

import models as models


def _is_pg(conn):
    return conn.dialect.name == "postgresql"


# revision identifiers, used by Alembic.
revision = "1c05e80d2380"
down_revision = "788d3099ae3a"
branch_labels = None
depends_on = None


def upgrade():
    conn = op.get_bind()

    if _is_pg(conn):
        op.create_table(
            "customized_snippets",
            sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
            sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", sa.Text(), nullable=True),
            sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
            sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
            sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
            sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
            sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
            sa.Column("icon_info", postgresql.JSONB(astext_type=sa.Text()), nullable=True),
            sa.Column("graph", sa.Text(), nullable=True),
            sa.Column("input_fields", sa.Text(), nullable=True),
            sa.Column("created_by", models.types.StringUUID(), nullable=True),
            sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.Column("updated_by", models.types.StringUUID(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
            sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
            sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
        )
    else:
        op.create_table(
            "customized_snippets",
            sa.Column("id", models.types.StringUUID(), nullable=False),
            sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
            sa.Column("name", sa.String(length=255), nullable=False),
            sa.Column("description", models.types.LongText(), nullable=True),
            sa.Column("type", sa.String(length=50), server_default=sa.text("'node'"), nullable=False),
            sa.Column("workflow_id", models.types.StringUUID(), nullable=True),
            sa.Column("is_published", sa.Boolean(), server_default=sa.text("false"), nullable=False),
            sa.Column("version", sa.Integer(), server_default=sa.text("1"), nullable=False),
            sa.Column("use_count", sa.Integer(), server_default=sa.text("0"), nullable=False),
            sa.Column("icon_info", models.types.AdjustedJSON(astext_type=sa.Text()), nullable=True),
            sa.Column("graph", models.types.LongText(), nullable=True),
            sa.Column("input_fields", models.types.LongText(), nullable=True),
            sa.Column("created_by", models.types.StringUUID(), nullable=True),
            sa.Column("created_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.Column("updated_by", models.types.StringUUID(), nullable=True),
            sa.Column("updated_at", sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
            sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
            sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
        )

    with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
        batch_op.create_index("customized_snippet_tenant_idx", ["tenant_id"], unique=False)


def downgrade():
    with op.batch_alter_table("customized_snippets", schema=None) as batch_op:
        batch_op.drop_index("customized_snippet_tenant_idx")

    op.drop_table("customized_snippets")
@@ -81,6 +81,7 @@ from .provider import (
    TenantDefaultModel,
    TenantPreferredModelProvider,
)
from .snippet import CustomizedSnippet, SnippetType
from .source import DataSourceApiKeyAuthBinding, DataSourceOauthBinding
from .task import CeleryTask, CeleryTaskSet
from .tools import (
@@ -140,6 +141,7 @@ __all__ = [
    "Conversation",
    "ConversationVariable",
    "CreatorUserRole",
    "CustomizedSnippet",
    "DataSourceApiKeyAuthBinding",
    "DataSourceOauthBinding",
    "Dataset",
@@ -184,6 +186,7 @@ __all__ = [
    "RecommendedApp",
    "SavedMessage",
    "Site",
    "SnippetType",
    "Tag",
    "TagBinding",
    "Tenant",
101
api/models/snippet.py
Normal file
@@ -0,0 +1,101 @@
import json
from datetime import datetime
from enum import StrEnum
from typing import Any

import sqlalchemy as sa
from sqlalchemy import DateTime, String, func
from sqlalchemy.orm import Mapped, mapped_column

from libs.uuid_utils import uuidv7

from .account import Account
from .base import Base
from .engine import db
from .types import AdjustedJSON, LongText, StringUUID


class SnippetType(StrEnum):
    """Snippet Type Enum"""

    NODE = "node"
    GROUP = "group"


class CustomizedSnippet(Base):
    """
    Customized Snippet Model

    Stores reusable workflow components (nodes or node groups) that can be
    shared across applications within a workspace.
    """

    __tablename__ = "customized_snippets"
    __table_args__ = (
        sa.PrimaryKeyConstraint("id", name="customized_snippet_pkey"),
        sa.Index("customized_snippet_tenant_idx", "tenant_id"),
        sa.UniqueConstraint("tenant_id", "name", name="customized_snippet_tenant_name_key"),
    )

    id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuidv7()))
    tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
    name: Mapped[str] = mapped_column(String(255), nullable=False)
    description: Mapped[str | None] = mapped_column(LongText, nullable=True)
    type: Mapped[str] = mapped_column(String(50), nullable=False, server_default=sa.text("'node'"))

    # Workflow reference for published version
    workflow_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True)

    # State flags
    is_published: Mapped[bool] = mapped_column(sa.Boolean, nullable=False, server_default=sa.text("false"))
    version: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("1"))
    use_count: Mapped[int] = mapped_column(sa.Integer, nullable=False, server_default=sa.text("0"))

    # Visual customization
    icon_info: Mapped[dict | None] = mapped_column(AdjustedJSON, nullable=True)

    # Snippet configuration (stored as JSON text)
    input_fields: Mapped[str | None] = mapped_column(LongText, nullable=True)

    # Audit fields
    created_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    created_at: Mapped[datetime] = mapped_column(DateTime, nullable=False, server_default=func.current_timestamp())
    updated_by: Mapped[str | None] = mapped_column(StringUUID, nullable=True)
    updated_at: Mapped[datetime] = mapped_column(
        DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
    )

    @property
    def graph_dict(self) -> dict[str, Any]:
        """Get graph from associated workflow."""
        if self.workflow_id:
            from .workflow import Workflow

            workflow = db.session.get(Workflow, self.workflow_id)
            if workflow:
                return json.loads(workflow.graph) if workflow.graph else {}
        return {}

    @property
    def input_fields_list(self) -> list[dict[str, Any]]:
        """Parse input_fields JSON to list."""
        return json.loads(self.input_fields) if self.input_fields else []

    @property
    def created_by_account(self) -> Account | None:
        """Get the account that created this snippet."""
        if self.created_by:
            return db.session.get(Account, self.created_by)
        return None

    @property
    def updated_by_account(self) -> Account | None:
        """Get the account that last updated this snippet."""
        if self.updated_by:
            return db.session.get(Account, self.updated_by)
        return None

    @property
    def version_str(self) -> str:
        """Get version as string for API response."""
        return str(self.version)
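A small sketch of how the JSON-text column on this model round-trips through the convenience properties (constructed in memory, no database session involved; the field values are made up for illustration):

import json

from models.snippet import CustomizedSnippet, SnippetType

snippet = CustomizedSnippet(
    tenant_id="00000000-0000-0000-0000-000000000000",  # placeholder tenant id
    name="Summarize",
    type=SnippetType.NODE,
)
snippet.input_fields = json.dumps([{"label": "Question", "type": "text-input", "required": True}])

# input_fields is stored as JSON text; input_fields_list parses it back into Python objects.
assert snippet.input_fields_list[0]["label"] == "Question"
# graph_dict stays empty until workflow_id points at a published workflow.
assert snippet.graph_dict == {}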
@@ -67,6 +67,7 @@ class WorkflowType(StrEnum):
    WORKFLOW = "workflow"
    CHAT = "chat"
    RAG_PIPELINE = "rag-pipeline"
    SNIPPET = "snippet"

    @classmethod
    def value_of(cls, value: str) -> "WorkflowType":
178
api/services/evaluation_service.py
Normal file
@@ -0,0 +1,178 @@
import io
import logging
from typing import Union

from openpyxl import Workbook
from openpyxl.styles import Alignment, Border, Font, PatternFill, Side
from openpyxl.utils import get_column_letter

from models.model import App, AppMode
from models.snippet import CustomizedSnippet
from services.snippet_service import SnippetService
from services.workflow_service import WorkflowService

logger = logging.getLogger(__name__)


class EvaluationService:
    """
    Service for evaluation-related operations.

    Provides functionality to generate evaluation dataset templates
    based on App or Snippet input parameters.
    """

    # Excluded app modes that don't support evaluation templates
    EXCLUDED_APP_MODES = {AppMode.RAG_PIPELINE}

    @classmethod
    def generate_dataset_template(
        cls,
        target: Union[App, CustomizedSnippet],
        target_type: str,
    ) -> tuple[bytes, str]:
        """
        Generate evaluation dataset template as XLSX bytes.

        Creates an XLSX file with headers based on the evaluation target's input parameters.
        The first column is index, followed by input parameter columns.

        :param target: App or CustomizedSnippet instance
        :param target_type: Target type string ("app" or "snippet")
        :return: Tuple of (xlsx_content_bytes, filename)
        :raises ValueError: If target type is not supported or app mode is excluded
        """
        # Validate target type
        if target_type == "app":
            if not isinstance(target, App):
                raise ValueError("Invalid target: expected App instance")
            if AppMode.value_of(target.mode) in cls.EXCLUDED_APP_MODES:
                raise ValueError(f"App mode '{target.mode}' does not support evaluation templates")
            input_fields = cls._get_app_input_fields(target)
        elif target_type == "snippet":
            if not isinstance(target, CustomizedSnippet):
                raise ValueError("Invalid target: expected CustomizedSnippet instance")
            input_fields = cls._get_snippet_input_fields(target)
        else:
            raise ValueError(f"Unsupported target type: {target_type}")

        # Generate XLSX template
        xlsx_content = cls._generate_xlsx_template(input_fields, target.name)

        # Build filename
        truncated_name = target.name[:10] + "..." if len(target.name) > 10 else target.name
        filename = f"{truncated_name}-evaluation-dataset.xlsx"

        return xlsx_content, filename

    @classmethod
    def _get_app_input_fields(cls, app: App) -> list[dict]:
        """
        Get input fields from App's workflow.

        :param app: App instance
        :return: List of input field definitions
        """
        workflow_service = WorkflowService()
        workflow = workflow_service.get_published_workflow(app_model=app)
        if not workflow:
            workflow = workflow_service.get_draft_workflow(app_model=app)

        if not workflow:
            return []

        # Get user input form from workflow
        user_input_form = workflow.user_input_form()
        return user_input_form

    @classmethod
    def _get_snippet_input_fields(cls, snippet: CustomizedSnippet) -> list[dict]:
        """
        Get input fields from Snippet.

        Tries to get from snippet's own input_fields first,
        then falls back to workflow's user_input_form.

        :param snippet: CustomizedSnippet instance
        :return: List of input field definitions
        """
        # Try snippet's own input_fields first
        input_fields = snippet.input_fields_list
|
||||
if input_fields:
|
||||
return input_fields
|
||||
|
||||
# Fallback to workflow's user_input_form
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_published_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if workflow:
|
||||
return workflow.user_input_form()
|
||||
|
||||
return []
|
||||
|
||||
@classmethod
|
||||
def _generate_xlsx_template(cls, input_fields: list[dict], target_name: str) -> bytes:
|
||||
"""
|
||||
Generate XLSX template file content.
|
||||
|
||||
Creates a workbook with:
|
||||
- First row as header row with "index" and input field names
|
||||
- Styled header with background color and borders
|
||||
- Empty data rows ready for user input
|
||||
|
||||
:param input_fields: List of input field definitions
|
||||
:param target_name: Name of the target (for sheet name)
|
||||
:return: XLSX file content as bytes
|
||||
"""
|
||||
wb = Workbook()
|
||||
ws = wb.active
|
||||
|
||||
sheet_name = "Evaluation Dataset"
|
||||
ws.title = sheet_name
|
||||
|
||||
header_font = Font(bold=True, color="FFFFFF")
|
||||
header_fill = PatternFill(start_color="4472C4", end_color="4472C4", fill_type="solid")
|
||||
header_alignment = Alignment(horizontal="center", vertical="center")
|
||||
thin_border = Border(
|
||||
left=Side(style="thin"),
|
||||
right=Side(style="thin"),
|
||||
top=Side(style="thin"),
|
||||
bottom=Side(style="thin"),
|
||||
)
|
||||
|
||||
# Build header row
|
||||
headers = ["index"]
|
||||
|
||||
for field in input_fields:
|
||||
field_label = field.get("label") or field.get("variable")
|
||||
headers.append(field_label)
|
||||
|
||||
# Write header row
|
||||
for col_idx, header in enumerate(headers, start=1):
|
||||
cell = ws.cell(row=1, column=col_idx, value=header)
|
||||
cell.font = header_font
|
||||
cell.fill = header_fill
|
||||
cell.alignment = header_alignment
|
||||
cell.border = thin_border
|
||||
|
||||
# Set column widths
|
||||
ws.column_dimensions["A"].width = 10 # index column
|
||||
for col_idx in range(2, len(headers) + 1):
|
||||
ws.column_dimensions[get_column_letter(col_idx)].width = 20
|
||||
|
||||
# Add one empty row with row number for user reference
|
||||
for col_idx in range(1, len(headers) + 1):
|
||||
cell = ws.cell(row=2, column=col_idx, value="")
|
||||
cell.border = thin_border
|
||||
if col_idx == 1:
|
||||
cell.value = 1
|
||||
cell.alignment = Alignment(horizontal="center")
|
||||
|
||||
# Save to bytes
|
||||
output = io.BytesIO()
|
||||
wb.save(output)
|
||||
output.seek(0)
|
||||
|
||||
return output.getvalue()
|
||||
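A minimal usage sketch of the service above (illustrative only; the app_model variable and the Flask response wiring are assumptions, not part of this diff):

# Hypothetical caller: stream the generated template as an XLSX download.
from urllib.parse import quote

from flask import Response

from services.evaluation_service import EvaluationService


def download_template(app_model):  # app_model: an App loaded elsewhere with tenant checks applied
    xlsx_bytes, filename = EvaluationService.generate_dataset_template(
        target=app_model,
        target_type="app",
    )
    return Response(
        xlsx_bytes,
        mimetype="application/vnd.openxmlformats-officedocument.spreadsheetml.sheet",
        headers={"Content-Disposition": f"attachment; filename*=UTF-8''{quote(filename)}"},
    )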
374
api/services/snippet_generate_service.py
Normal file
@@ -0,0 +1,374 @@
|
||||
"""
|
||||
Service for generating snippet workflow executions.
|
||||
|
||||
Uses an adapter pattern to bridge CustomizedSnippet with the App-based
|
||||
WorkflowAppGenerator. The adapter (_SnippetAsApp) provides the minimal App-like
|
||||
interface needed by the generator, avoiding modifications to core workflow
|
||||
infrastructure.
|
||||
|
||||
Key invariants:
|
||||
- Snippets always run as WORKFLOW mode (not CHAT or ADVANCED_CHAT).
|
||||
- The adapter maps snippet.id to app_id in workflow execution records.
|
||||
- Snippet debugging has no rate limiting (max_active_requests = 0).
|
||||
|
||||
Supported execution modes:
|
||||
- Full workflow run (generate): Runs the entire draft workflow as SSE stream.
|
||||
- Single node run (run_draft_node): Synchronous single-step debugging for regular nodes.
|
||||
- Single iteration run (generate_single_iteration): SSE stream for iteration container nodes.
|
||||
- Single loop run (generate_single_loop): SSE stream for loop container nodes.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, Union
|
||||
|
||||
from sqlalchemy.orm import make_transient
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.file.models import File
|
||||
from factories import file_factory
|
||||
from models import Account
|
||||
from models.model import AppMode, EndUser
|
||||
from models.snippet import CustomizedSnippet
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionModel
|
||||
from services.snippet_service import SnippetService
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class _SnippetAsApp:
|
||||
"""
|
||||
Minimal adapter that wraps a CustomizedSnippet to satisfy the App-like
|
||||
interface required by WorkflowAppGenerator, WorkflowAppConfigManager,
|
||||
and WorkflowService.run_draft_workflow_node.
|
||||
|
||||
Used properties:
|
||||
- id: maps to snippet.id (stored as app_id in workflows table)
|
||||
- tenant_id: maps to snippet.tenant_id
|
||||
- mode: hardcoded to AppMode.WORKFLOW since snippets always run as workflows
|
||||
- max_active_requests: defaults to 0 (no limit) for snippet debugging
|
||||
- app_model_config_id: None (snippets don't have app model configs)
|
||||
"""
|
||||
|
||||
id: str
|
||||
tenant_id: str
|
||||
mode: str
|
||||
max_active_requests: int
|
||||
app_model_config_id: str | None
|
||||
|
||||
def __init__(self, snippet: CustomizedSnippet) -> None:
|
||||
self.id = snippet.id
|
||||
self.tenant_id = snippet.tenant_id
|
||||
self.mode = AppMode.WORKFLOW.value
|
||||
self.max_active_requests = 0
|
||||
self.app_model_config_id = None
|
||||
|
||||
|
||||
class SnippetGenerateService:
|
||||
"""
|
||||
Service for running snippet workflow executions.
|
||||
|
||||
Adapts CustomizedSnippet to work with the existing App-based
|
||||
WorkflowAppGenerator infrastructure, avoiding duplication of the
|
||||
complex workflow execution pipeline.
|
||||
"""
|
||||
|
||||
# Specific ID for the injected virtual Start node so it can be recognised
|
||||
_VIRTUAL_START_NODE_ID = "__snippet_virtual_start__"
|
||||
|
||||
@classmethod
|
||||
def generate(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a snippet's draft workflow.
|
||||
|
||||
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
|
||||
then delegates execution to WorkflowAppGenerator.
|
||||
|
||||
If the workflow graph has no Start node, a virtual Start node is injected
|
||||
in-memory so that:
|
||||
1. Graph validation passes (root node must have execution_type=ROOT).
|
||||
2. User inputs are processed into the variable pool by the StartNode logic.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param args: Workflow inputs (must include "inputs" key)
|
||||
:param invoke_from: Source of invocation (typically DEBUGGER)
|
||||
:param streaming: Whether to stream the response
|
||||
:return: Blocking response mapping or SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
# Inject a virtual Start node when the graph doesn't have one.
|
||||
workflow = cls._ensure_start_node(workflow, snippet)
|
||||
|
||||
# Adapt snippet to App-like interface for WorkflowAppGenerator
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=streaming,
|
||||
call_depth=0,
|
||||
)
|
||||
)
|
||||
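A sketch of how a caller might consume this when streaming (the snippet and account values below are placeholders):

# Illustrative only: with streaming=True the result is a generator of
# SSE-ready chunks (mappings or strings) from convert_to_event_stream.
from core.app.entities.app_invoke_entities import InvokeFrom
from services.snippet_generate_service import SnippetGenerateService

stream = SnippetGenerateService.generate(
    snippet=snippet,                      # CustomizedSnippet loaded elsewhere
    user=account,                         # Account or EndUser
    args={"inputs": {"query": "hello"}},  # must include an "inputs" key
    invoke_from=InvokeFrom.DEBUGGER,
    streaming=True,
)
for chunk in stream:
    print(chunk)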
|
||||
@classmethod
|
||||
def _ensure_start_node(cls, workflow: Workflow, snippet: CustomizedSnippet) -> Workflow:
|
||||
"""
|
||||
Return *workflow* with a Start node.
|
||||
|
||||
If the graph already contains a Start node, the original workflow is
|
||||
returned unchanged. Otherwise a virtual Start node is injected and the
|
||||
workflow object is detached from the SQLAlchemy session so the in-memory
|
||||
change is never flushed to the database.
|
||||
"""
|
||||
graph_dict = workflow.graph_dict
|
||||
nodes: list[dict[str, Any]] = graph_dict.get("nodes", [])
|
||||
|
||||
has_start = any(node.get("data", {}).get("type") == "start" for node in nodes)
|
||||
if has_start:
|
||||
return workflow
|
||||
|
||||
modified_graph = cls._inject_virtual_start_node(
|
||||
graph_dict=graph_dict,
|
||||
input_fields=snippet.input_fields_list,
|
||||
)
|
||||
|
||||
# Detach from session to prevent accidental DB persistence of the
|
||||
# modified graph. All attributes remain accessible for read.
|
||||
make_transient(workflow)
|
||||
workflow.graph = json.dumps(modified_graph)
|
||||
return workflow
|
||||
|
||||
@classmethod
|
||||
def _inject_virtual_start_node(
|
||||
cls,
|
||||
graph_dict: Mapping[str, Any],
|
||||
input_fields: list[dict[str, Any]],
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Build a new graph dict with a virtual Start node prepended.
|
||||
|
||||
The virtual Start node is wired to every existing node that has no
|
||||
incoming edges (i.e. the current root candidates). This guarantees:

1. Graph validation passes (the root node has execution_type=ROOT).
2. User inputs are processed into the variable pool by the StartNode logic.
|
||||
:param graph_dict: Original graph configuration.
|
||||
:param input_fields: Snippet input field definitions from
|
||||
``CustomizedSnippet.input_fields_list``.
|
||||
:return: New graph dict containing the virtual Start node and edges.
|
||||
"""
|
||||
nodes: list[dict[str, Any]] = list(graph_dict.get("nodes", []))
|
||||
edges: list[dict[str, Any]] = list(graph_dict.get("edges", []))
|
||||
|
||||
# Identify nodes with no incoming edges.
|
||||
nodes_with_incoming: set[str] = set()
|
||||
for edge in edges:
|
||||
target = edge.get("target")
|
||||
if isinstance(target, str):
|
||||
nodes_with_incoming.add(target)
|
||||
root_candidate_ids = [n["id"] for n in nodes if n["id"] not in nodes_with_incoming]
|
||||
|
||||
# Build Start node ``variables`` from snippet input fields.
|
||||
start_variables: list[dict[str, Any]] = []
|
||||
for field in input_fields:
|
||||
var: dict[str, Any] = {
|
||||
"variable": field.get("variable", ""),
|
||||
"label": field.get("label", field.get("variable", "")),
|
||||
"type": field.get("type", "text-input"),
|
||||
"required": field.get("required", False),
|
||||
"options": field.get("options", []),
|
||||
}
|
||||
if field.get("max_length") is not None:
|
||||
var["max_length"] = field["max_length"]
|
||||
start_variables.append(var)
|
||||
|
||||
virtual_start_node: dict[str, Any] = {
|
||||
"id": cls._VIRTUAL_START_NODE_ID,
|
||||
"data": {
|
||||
"type": "start",
|
||||
"title": "Start",
|
||||
"variables": start_variables,
|
||||
},
|
||||
}
|
||||
|
||||
# Create edges from virtual Start to each root candidate.
|
||||
new_edges: list[dict[str, Any]] = [
|
||||
{
|
||||
"source": cls._VIRTUAL_START_NODE_ID,
|
||||
"sourceHandle": "source",
|
||||
"target": root_id,
|
||||
"targetHandle": "target",
|
||||
}
|
||||
for root_id in root_candidate_ids
|
||||
]
|
||||
|
||||
return {
|
||||
**graph_dict,
|
||||
"nodes": [virtual_start_node, *nodes],
|
||||
"edges": [*edges, *new_edges],
|
||||
}
|
||||
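To make the wiring concrete, a small illustrative input and its effect (the node ids are invented for the example):

# Illustrative only: neither node has incoming edges, so both are root
# candidates and each receives an edge from the virtual Start node.
graph = {
    "nodes": [
        {"id": "llm_1", "data": {"type": "llm", "title": "LLM"}},
        {"id": "code_1", "data": {"type": "code", "title": "Code"}},
    ],
    "edges": [],
}
fields = [{"variable": "query", "label": "Query", "type": "text-input", "required": True}]

new_graph = SnippetGenerateService._inject_virtual_start_node(graph_dict=graph, input_fields=fields)
# new_graph["nodes"][0]["id"] == "__snippet_virtual_start__"
# new_graph["edges"] gains __snippet_virtual_start__ -> llm_1 and __snippet_virtual_start__ -> code_1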
|
||||
@classmethod
|
||||
def run_draft_node(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
node_id: str,
|
||||
user_inputs: Mapping[str, Any],
|
||||
account: Account,
|
||||
query: str = "",
|
||||
files: Sequence[File] | None = None,
|
||||
) -> WorkflowNodeExecutionModel:
|
||||
"""
|
||||
Run a single node in a snippet's draft workflow (single-step debugging).
|
||||
|
||||
Retrieves the draft workflow, adapts the snippet to an App-like proxy,
|
||||
parses file inputs, then delegates to WorkflowService.run_draft_workflow_node.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param node_id: ID of the node to run
|
||||
:param user_inputs: User input values for the node
|
||||
:param account: Account initiating the run
|
||||
:param query: Optional query string
|
||||
:param files: Optional parsed file objects
|
||||
:return: WorkflowNodeExecutionModel with execution results
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
draft_workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not draft_workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
return workflow_service.run_draft_workflow_node(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
draft_workflow=draft_workflow,
|
||||
node_id=node_id,
|
||||
user_inputs=user_inputs,
|
||||
account=account,
|
||||
query=query,
|
||||
files=files,
|
||||
)
|
||||
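A single-step debugging call might look like this (all identifiers below are placeholders):

# Illustrative only: runs one node of the draft workflow synchronously and
# returns the persisted WorkflowNodeExecutionModel row.
execution = SnippetGenerateService.run_draft_node(
    snippet=snippet,
    node_id="llm_1",
    user_inputs={"query": "hello"},
    account=account,
)
print(execution.id)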
|
||||
@classmethod
|
||||
def generate_single_iteration(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Mapping[str, Any],
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a single iteration node in a snippet's draft workflow.
|
||||
|
||||
Iteration nodes are container nodes that execute their sub-graph multiple
|
||||
times, producing many events. Therefore, this uses the full WorkflowAppGenerator
|
||||
pipeline with SSE streaming (unlike regular single-step node run).
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param node_id: ID of the iteration node to run
|
||||
:param args: Dict containing 'inputs' key with iteration input data
|
||||
:param streaming: Whether to stream the response (should be True)
|
||||
:return: SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_iteration_generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args,
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def generate_single_loop(
|
||||
cls,
|
||||
snippet: CustomizedSnippet,
|
||||
user: Union[Account, EndUser],
|
||||
node_id: str,
|
||||
args: Any,
|
||||
streaming: bool = True,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
"""
|
||||
Run a single loop node in a snippet's draft workflow.
|
||||
|
||||
Loop nodes are container nodes that execute their sub-graph repeatedly,
|
||||
producing many events. Therefore, this uses the full WorkflowAppGenerator
|
||||
pipeline with SSE streaming (unlike regular single-step node run).
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param user: Account or EndUser initiating the run
|
||||
:param node_id: ID of the loop node to run
|
||||
:param args: Pydantic model with 'inputs' attribute containing loop input data
|
||||
:param streaming: Whether to stream the response (should be True)
|
||||
:return: SSE streaming generator
|
||||
:raises ValueError: If the snippet has no draft workflow
|
||||
"""
|
||||
snippet_service = SnippetService()
|
||||
workflow = snippet_service.get_draft_workflow(snippet=snippet)
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not initialized")
|
||||
|
||||
app_proxy = _SnippetAsApp(snippet)
|
||||
|
||||
return WorkflowAppGenerator.convert_to_event_stream(
|
||||
WorkflowAppGenerator().single_loop_generate(
|
||||
app_model=app_proxy, # type: ignore[arg-type]
|
||||
workflow=workflow,
|
||||
node_id=node_id,
|
||||
user=user,
|
||||
args=args, # type: ignore[arg-type]
|
||||
streaming=streaming,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def parse_files(workflow: Workflow, files: list[dict] | None = None) -> Sequence[File]:
|
||||
"""
|
||||
Parse file mappings into File objects based on workflow configuration.
|
||||
|
||||
:param workflow: Workflow instance for file upload config
|
||||
:param files: Raw file mapping dicts
|
||||
:return: Parsed File objects
|
||||
"""
|
||||
files = files or []
|
||||
file_extra_config = FileUploadConfigManager.convert(workflow.features_dict, is_vision=False)
|
||||
if file_extra_config is None:
|
||||
return []
|
||||
return file_factory.build_from_mappings(
|
||||
mappings=files,
|
||||
tenant_id=workflow.tenant_id,
|
||||
config=file_extra_config,
|
||||
)
|
||||
566
api/services/snippet_service.py
Normal file
@@ -0,0 +1,566 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import UTC, datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import func, select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.variables.variables import VariableBase
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.snippet import CustomizedSnippet, SnippetType
|
||||
from models.workflow import (
|
||||
Workflow,
|
||||
WorkflowNodeExecutionModel,
|
||||
WorkflowRun,
|
||||
WorkflowType,
|
||||
)
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.errors.app import WorkflowHashNotEqualError
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class SnippetService:
|
||||
"""Service for managing customized snippets."""
|
||||
|
||||
def __init__(self, session_maker: sessionmaker | None = None):
|
||||
"""Initialize SnippetService with repository dependencies."""
|
||||
if session_maker is None:
|
||||
session_maker = sessionmaker(bind=db.engine, expire_on_commit=False)
|
||||
self._node_execution_service_repo = DifyAPIRepositoryFactory.create_api_workflow_node_execution_repository(
|
||||
session_maker
|
||||
)
|
||||
self._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
|
||||
# --- CRUD Operations ---
|
||||
|
||||
@staticmethod
|
||||
def get_snippets(
|
||||
*,
|
||||
tenant_id: str,
|
||||
page: int = 1,
|
||||
limit: int = 20,
|
||||
keyword: str | None = None,
|
||||
) -> tuple[Sequence[CustomizedSnippet], int, bool]:
|
||||
"""
|
||||
Get paginated list of snippets with optional search.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param page: Page number (1-indexed)
|
||||
:param limit: Number of items per page
|
||||
:param keyword: Optional search keyword for name/description
|
||||
:return: Tuple of (snippets list, total count, has_more flag)
|
||||
"""
|
||||
stmt = (
|
||||
select(CustomizedSnippet)
|
||||
.where(CustomizedSnippet.tenant_id == tenant_id)
|
||||
.order_by(CustomizedSnippet.created_at.desc())
|
||||
)
|
||||
|
||||
if keyword:
|
||||
stmt = stmt.where(
|
||||
CustomizedSnippet.name.ilike(f"%{keyword}%") | CustomizedSnippet.description.ilike(f"%{keyword}%")
|
||||
)
|
||||
|
||||
# Get total count
|
||||
count_stmt = select(func.count()).select_from(stmt.subquery())
|
||||
total = db.session.scalar(count_stmt) or 0
|
||||
|
||||
# Apply pagination
|
||||
stmt = stmt.limit(limit + 1).offset((page - 1) * limit)
|
||||
snippets = list(db.session.scalars(stmt).all())
|
||||
|
||||
has_more = len(snippets) > limit
|
||||
if has_more:
|
||||
snippets = snippets[:-1]
|
||||
|
||||
return snippets, total, has_more
|
||||
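The limit + 1 fetch above is what drives has_more; a caller sketch (tenant_id is a placeholder):

# Illustrative only: one extra row is fetched so has_more can be set
# directly, without comparing offsets against the total count.
snippets, total, has_more = SnippetService.get_snippets(
    tenant_id=tenant_id,
    page=1,
    limit=20,
    keyword="summar",  # optional fuzzy match on name/description
)
if has_more:
    pass  # request page=2 with the same keyword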
|
||||
@staticmethod
|
||||
def get_snippet_by_id(
|
||||
*,
|
||||
snippet_id: str,
|
||||
tenant_id: str,
|
||||
) -> CustomizedSnippet | None:
|
||||
"""
|
||||
Get snippet by ID with tenant isolation.
|
||||
|
||||
:param snippet_id: Snippet ID
|
||||
:param tenant_id: Tenant ID
|
||||
:return: CustomizedSnippet or None
|
||||
"""
|
||||
return (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.id == snippet_id,
|
||||
CustomizedSnippet.tenant_id == tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def create_snippet(
|
||||
*,
|
||||
tenant_id: str,
|
||||
name: str,
|
||||
description: str | None,
|
||||
snippet_type: SnippetType,
|
||||
icon_info: dict | None,
|
||||
graph: dict | None,
|
||||
input_fields: list[dict] | None,
|
||||
account: Account,
|
||||
) -> CustomizedSnippet:
|
||||
"""
|
||||
Create a new snippet.
|
||||
|
||||
:param tenant_id: Tenant ID
|
||||
:param name: Snippet name (must be unique per tenant)
|
||||
:param description: Snippet description
|
||||
:param snippet_type: Type of snippet (node or group)
|
||||
:param icon_info: Icon information
|
||||
:param graph: Workflow graph structure
|
||||
:param input_fields: Input field definitions
|
||||
:param account: Creator account
|
||||
:return: Created CustomizedSnippet
|
||||
:raises ValueError: If name already exists
|
||||
"""
|
||||
# Check if name already exists for this tenant
|
||||
existing = (
|
||||
db.session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == tenant_id,
|
||||
CustomizedSnippet.name == name,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{name}' already exists")
|
||||
|
||||
snippet = CustomizedSnippet(
|
||||
tenant_id=tenant_id,
|
||||
name=name,
|
||||
description=description or "",
|
||||
type=snippet_type.value,
|
||||
icon_info=icon_info,
|
||||
graph=json.dumps(graph) if graph else None,
|
||||
input_fields=json.dumps(input_fields) if input_fields else None,
|
||||
created_by=account.id,
|
||||
)
|
||||
|
||||
db.session.add(snippet)
|
||||
db.session.commit()
|
||||
|
||||
return snippet
|
||||
|
||||
@staticmethod
|
||||
def update_snippet(
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
account_id: str,
|
||||
data: dict,
|
||||
) -> CustomizedSnippet:
|
||||
"""
|
||||
Update snippet attributes.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: Snippet to update
|
||||
:param account_id: ID of account making the update
|
||||
:param data: Dictionary of fields to update
|
||||
:return: Updated CustomizedSnippet
|
||||
"""
|
||||
if "name" in data:
|
||||
# Check if new name already exists for this tenant
|
||||
existing = (
|
||||
session.query(CustomizedSnippet)
|
||||
.where(
|
||||
CustomizedSnippet.tenant_id == snippet.tenant_id,
|
||||
CustomizedSnippet.name == data["name"],
|
||||
CustomizedSnippet.id != snippet.id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if existing:
|
||||
raise ValueError(f"Snippet with name '{data['name']}' already exists")
|
||||
snippet.name = data["name"]
|
||||
|
||||
if "description" in data:
|
||||
snippet.description = data["description"]
|
||||
|
||||
if "icon_info" in data:
|
||||
snippet.icon_info = data["icon_info"]
|
||||
|
||||
snippet.updated_by = account_id
|
||||
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
session.add(snippet)
|
||||
return snippet
|
||||
|
||||
@staticmethod
|
||||
def delete_snippet(
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
) -> bool:
|
||||
"""
|
||||
Delete a snippet.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: Snippet to delete
|
||||
:return: True if deleted successfully
|
||||
"""
|
||||
session.delete(snippet)
|
||||
return True
|
||||
|
||||
# --- Workflow Operations ---
|
||||
|
||||
def get_draft_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||
"""
|
||||
Get draft workflow for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:return: Draft Workflow or None
|
||||
"""
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == snippet.tenant_id,
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
def get_published_workflow(self, snippet: CustomizedSnippet) -> Workflow | None:
|
||||
"""
|
||||
Get published workflow for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:return: Published Workflow or None
|
||||
"""
|
||||
if not snippet.workflow_id:
|
||||
return None
|
||||
|
||||
workflow = (
|
||||
db.session.query(Workflow)
|
||||
.where(
|
||||
Workflow.tenant_id == snippet.tenant_id,
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.id == snippet.workflow_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
return workflow
|
||||
|
||||
def sync_draft_workflow(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
graph: dict,
|
||||
unique_hash: str | None,
|
||||
account: Account,
|
||||
environment_variables: Sequence[VariableBase],
|
||||
conversation_variables: Sequence[VariableBase],
|
||||
input_variables: list[dict] | None = None,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Sync draft workflow for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param graph: Workflow graph configuration
|
||||
:param unique_hash: Hash for conflict detection
|
||||
:param account: Account making the change
|
||||
:param environment_variables: Environment variables
|
||||
:param conversation_variables: Conversation variables
|
||||
:param input_variables: Input variables for snippet
|
||||
:return: Synced Workflow
|
||||
:raises WorkflowHashNotEqualError: If hash mismatch
|
||||
"""
|
||||
workflow = self.get_draft_workflow(snippet=snippet)
|
||||
|
||||
if workflow and workflow.unique_hash != unique_hash:
|
||||
raise WorkflowHashNotEqualError()
|
||||
|
||||
# Create draft workflow if not found
|
||||
if not workflow:
|
||||
workflow = Workflow(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
features="{}",
|
||||
type=WorkflowType.SNIPPET.value,
|
||||
version="draft",
|
||||
graph=json.dumps(graph),
|
||||
created_by=account.id,
|
||||
environment_variables=environment_variables,
|
||||
conversation_variables=conversation_variables,
|
||||
)
|
||||
db.session.add(workflow)
|
||||
db.session.flush()
|
||||
else:
|
||||
# Update existing draft workflow
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.updated_by = account.id
|
||||
workflow.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
workflow.environment_variables = environment_variables
|
||||
workflow.conversation_variables = conversation_variables
|
||||
|
||||
# Update snippet's input_fields if provided
|
||||
if input_variables is not None:
|
||||
snippet.input_fields = json.dumps(input_variables)
|
||||
snippet.updated_by = account.id
|
||||
snippet.updated_at = datetime.now(UTC).replace(tzinfo=None)
|
||||
|
||||
db.session.commit()
|
||||
return workflow
|
||||
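A caller sketch showing the optimistic-concurrency check (graph and account are placeholders):

# Illustrative only: the client echoes back the unique_hash of the draft it
# last loaded; a mismatch raises WorkflowHashNotEqualError instead of
# silently overwriting a concurrent edit.
service = SnippetService()
draft = service.get_draft_workflow(snippet=snippet)
workflow = service.sync_draft_workflow(
    snippet=snippet,
    graph=new_graph_dict,
    unique_hash=draft.unique_hash if draft else None,
    account=account,
    environment_variables=[],
    conversation_variables=[],
    input_variables=None,
)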
|
||||
def publish_workflow(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
account: Account,
|
||||
) -> Workflow:
|
||||
"""
|
||||
Publish the draft workflow as a new version.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param account: Account making the change
|
||||
:return: Published Workflow
|
||||
:raises ValueError: If no draft workflow exists
|
||||
"""
|
||||
draft_workflow_stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == snippet.tenant_id,
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
draft_workflow = session.scalar(draft_workflow_stmt)
|
||||
if not draft_workflow:
|
||||
raise ValueError("No valid workflow found.")
|
||||
|
||||
# Create new published workflow
|
||||
workflow = Workflow.new(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
type=draft_workflow.type,
|
||||
version=str(datetime.now(UTC).replace(tzinfo=None)),
|
||||
graph=draft_workflow.graph,
|
||||
features=draft_workflow.features,
|
||||
created_by=account.id,
|
||||
environment_variables=draft_workflow.environment_variables,
|
||||
conversation_variables=draft_workflow.conversation_variables,
|
||||
marked_name="",
|
||||
marked_comment="",
|
||||
)
|
||||
session.add(workflow)
|
||||
|
||||
# Update snippet version
|
||||
snippet.version += 1
|
||||
snippet.is_published = True
|
||||
snippet.workflow_id = workflow.id
|
||||
snippet.updated_by = account.id
|
||||
session.add(snippet)
|
||||
|
||||
return workflow
|
||||
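Because publish_workflow only adds to the session, the caller owns the commit; a minimal sketch (snippet and account are placeholders):

# Illustrative only.
from sqlalchemy.orm import Session

from extensions.ext_database import db
from services.snippet_service import SnippetService

with Session(db.engine, expire_on_commit=False) as session:
    published = SnippetService().publish_workflow(
        session=session,
        snippet=snippet,   # CustomizedSnippet attached to this session
        account=account,
    )
    session.commit()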
|
||||
def get_all_published_workflows(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
page: int,
|
||||
limit: int,
|
||||
) -> tuple[Sequence[Workflow], bool]:
|
||||
"""
|
||||
Get all published workflow versions for snippet.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param page: Page number
|
||||
:param limit: Items per page
|
||||
:return: Tuple of (workflows list, has_more flag)
|
||||
"""
|
||||
if not snippet.workflow_id:
|
||||
return [], False
|
||||
|
||||
stmt = (
|
||||
select(Workflow)
|
||||
.where(
|
||||
Workflow.app_id == snippet.id,
|
||||
Workflow.type == WorkflowType.SNIPPET.value,
|
||||
Workflow.version != "draft",
|
||||
)
|
||||
.order_by(Workflow.version.desc())
|
||||
.limit(limit + 1)
|
||||
.offset((page - 1) * limit)
|
||||
)
|
||||
|
||||
workflows = list(session.scalars(stmt).all())
|
||||
has_more = len(workflows) > limit
|
||||
if has_more:
|
||||
workflows = workflows[:-1]
|
||||
|
||||
return workflows, has_more
|
||||
|
||||
# --- Default Block Configs ---
|
||||
|
||||
def get_default_block_configs(self) -> list[dict]:
|
||||
"""
|
||||
Get default block configurations for all node types.
|
||||
|
||||
:return: List of default configurations
|
||||
"""
|
||||
default_block_configs: list[dict[str, Any]] = []
|
||||
for node_class_mapping in NODE_TYPE_CLASSES_MAPPING.values():
|
||||
node_class = node_class_mapping[LATEST_VERSION]
|
||||
default_config = node_class.get_default_config()
|
||||
if default_config:
|
||||
default_block_configs.append(dict(default_config))
|
||||
|
||||
return default_block_configs
|
||||
|
||||
def get_default_block_config(self, node_type: str, filters: dict | None = None) -> Mapping[str, object] | None:
|
||||
"""
|
||||
Get default config for specific node type.
|
||||
|
||||
:param node_type: Node type string
|
||||
:param filters: Optional filters
|
||||
:return: Default configuration or None
|
||||
"""
|
||||
node_type_enum = NodeType(node_type)
|
||||
|
||||
if node_type_enum not in NODE_TYPE_CLASSES_MAPPING:
|
||||
return None
|
||||
|
||||
node_class = NODE_TYPE_CLASSES_MAPPING[node_type_enum][LATEST_VERSION]
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
if not default_config:
|
||||
return None
|
||||
|
||||
return default_config
|
||||
|
||||
# --- Workflow Run Operations ---
|
||||
|
||||
def get_snippet_workflow_runs(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
args: dict,
|
||||
) -> InfiniteScrollPagination:
|
||||
"""
|
||||
Get paginated workflow runs for snippet.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param args: Request arguments (last_id, limit)
|
||||
:return: InfiniteScrollPagination result
|
||||
"""
|
||||
limit = int(args.get("limit", 20))
|
||||
last_id = args.get("last_id")
|
||||
|
||||
triggered_from_values = [
|
||||
WorkflowRunTriggeredFrom.DEBUGGING,
|
||||
]
|
||||
|
||||
return self._workflow_run_repo.get_paginated_workflow_runs(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
triggered_from=triggered_from_values,
|
||||
limit=limit,
|
||||
last_id=last_id,
|
||||
)
|
||||
|
||||
def get_snippet_workflow_run(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
run_id: str,
|
||||
) -> WorkflowRun | None:
|
||||
"""
|
||||
Get workflow run details.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param run_id: Workflow run ID
|
||||
:return: WorkflowRun or None
|
||||
"""
|
||||
return self._workflow_run_repo.get_workflow_run_by_id(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
run_id=run_id,
|
||||
)
|
||||
|
||||
def get_snippet_workflow_run_node_executions(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
run_id: str,
|
||||
) -> Sequence[WorkflowNodeExecutionModel]:
|
||||
"""
|
||||
Get workflow run node execution list.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param run_id: Workflow run ID
|
||||
:return: List of WorkflowNodeExecutionModel
|
||||
"""
|
||||
workflow_run = self.get_snippet_workflow_run(snippet=snippet, run_id=run_id)
|
||||
if not workflow_run:
|
||||
return []
|
||||
|
||||
node_executions = self._node_execution_service_repo.get_executions_by_workflow_run(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
workflow_run_id=workflow_run.id,
|
||||
)
|
||||
|
||||
return node_executions
|
||||
|
||||
# --- Node Execution Operations ---
|
||||
|
||||
def get_snippet_node_last_run(
|
||||
self,
|
||||
*,
|
||||
snippet: CustomizedSnippet,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
) -> WorkflowNodeExecutionModel | None:
|
||||
"""
|
||||
Get the most recent execution for a specific node in a snippet workflow.
|
||||
|
||||
:param snippet: CustomizedSnippet instance
|
||||
:param workflow: Workflow instance
|
||||
:param node_id: Node identifier
|
||||
:return: WorkflowNodeExecutionModel or None
|
||||
"""
|
||||
return self._node_execution_service_repo.get_node_last_execution(
|
||||
tenant_id=snippet.tenant_id,
|
||||
app_id=snippet.id,
|
||||
workflow_id=workflow.id,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# --- Use Count ---
|
||||
|
||||
@staticmethod
|
||||
def increment_use_count(
|
||||
*,
|
||||
session: Session,
|
||||
snippet: CustomizedSnippet,
|
||||
) -> None:
|
||||
"""
|
||||
Increment the use_count when snippet is used.
|
||||
|
||||
:param session: Database session
|
||||
:param snippet: CustomizedSnippet instance
|
||||
"""
|
||||
snippet.use_count += 1
|
||||
session.add(snippet)
|
||||
@@ -277,7 +277,10 @@ describe('App Card Operations Flow', () => {
|
||||
}
|
||||
})
|
||||
|
||||
// -- Basic rendering --
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Card Rendering', () => {
|
||||
it('should render app name and description', () => {
|
||||
renderAppCard({ name: 'My AI Bot', description: 'An intelligent assistant' })
|
||||
|
||||
@@ -187,7 +187,10 @@ describe('App List Browsing Flow', () => {
|
||||
mockShowTagManagementModal = false
|
||||
})
|
||||
|
||||
// -- Loading and Empty states --
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Loading and Empty States', () => {
|
||||
it('should show skeleton cards during initial loading', () => {
|
||||
mockIsLoading = true
|
||||
|
||||
@@ -237,7 +237,6 @@ describe('Create App Flow', () => {
|
||||
mockShowTagManagementModal = false
|
||||
})
|
||||
|
||||
// -- NewAppCard rendering --
|
||||
describe('NewAppCard Rendering', () => {
|
||||
it('should render the "Create App" card with all options', () => {
|
||||
renderList()
|
||||
|
||||
@@ -12,7 +12,6 @@ import { afterEach, beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import DevelopMain from '@/app/components/develop'
|
||||
import { AppModeEnum, Theme } from '@/types/app'
|
||||
|
||||
// ---------- fake timers ----------
|
||||
beforeEach(() => {
|
||||
vi.useFakeTimers({ shouldAdvanceTime: true })
|
||||
})
|
||||
@@ -28,8 +27,6 @@ async function flushUI() {
|
||||
})
|
||||
}
|
||||
|
||||
// ---------- store mock ----------
|
||||
|
||||
let storeAppDetail: unknown
|
||||
|
||||
vi.mock('@/app/components/app/store', () => ({
|
||||
@@ -38,8 +35,6 @@ vi.mock('@/app/components/app/store', () => ({
|
||||
},
|
||||
}))
|
||||
|
||||
// ---------- Doc dependencies ----------
|
||||
|
||||
vi.mock('@/context/i18n', () => ({
|
||||
useLocale: () => 'en-US',
|
||||
}))
|
||||
@@ -48,11 +43,12 @@ vi.mock('@/hooks/use-theme', () => ({
|
||||
default: () => ({ theme: Theme.light }),
|
||||
}))
|
||||
|
||||
vi.mock('@/i18n-config/language', () => ({
|
||||
LanguagesSupported: ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP'],
|
||||
}))
|
||||
|
||||
// ---------- SecretKeyModal dependencies ----------
|
||||
vi.mock('@/i18n-config/language', async (importOriginal) => {
|
||||
const actual = await importOriginal<typeof import('@/i18n-config/language')>()
|
||||
return {
|
||||
...actual,
|
||||
}
|
||||
})
|
||||
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
|
||||
@@ -11,7 +11,7 @@ import { RETRIEVE_METHOD } from '@/types/app'
|
||||
import Item from './index'
|
||||
|
||||
vi.mock('../settings-modal', () => ({
|
||||
default: ({ onSave, onCancel, currentDataset }: any) => (
|
||||
default: ({ onSave, onCancel, currentDataset }: { currentDataset: DataSet, onCancel: () => void, onSave: (newDataset: DataSet) => void }) => (
|
||||
<div>
|
||||
<div>Mock settings modal</div>
|
||||
<button onClick={() => onSave({ ...currentDataset, name: 'Updated dataset' })}>Save changes</button>
|
||||
@@ -177,7 +177,7 @@ describe('dataset-config/card-item', () => {
|
||||
expect(screen.getByRole('dialog')).toBeVisible()
|
||||
})
|
||||
|
||||
await user.click(screen.getByText('Save changes'))
|
||||
fireEvent.click(screen.getByText('Save changes'))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(onSave).toHaveBeenCalledWith(expect.objectContaining({ name: 'Updated dataset' }))
|
||||
|
||||
@@ -53,6 +53,10 @@ vi.mock('@/hooks/use-theme', () => ({
|
||||
|
||||
vi.mock('@/i18n-config/language', () => ({
|
||||
LanguagesSupported: ['en-US', 'zh-Hans', 'zh-Hant', 'pt-BR', 'es-ES', 'fr-FR', 'de-DE', 'ja-JP'],
|
||||
getDocLanguage: (locale: string) => {
|
||||
const map: Record<string, string> = { 'zh-Hans': 'zh', 'ja-JP': 'ja' }
|
||||
return map[locale] || 'en'
|
||||
},
|
||||
}))
|
||||
|
||||
describe('Doc', () => {
|
||||
@@ -63,7 +67,7 @@ describe('Doc', () => {
|
||||
prompt_variables: variables,
|
||||
},
|
||||
},
|
||||
})
|
||||
}) as unknown as Parameters<typeof Doc>[0]['appDetail']
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
@@ -123,13 +127,13 @@ describe('Doc', () => {
|
||||
|
||||
describe('null/undefined appDetail', () => {
|
||||
it('should render nothing when appDetail has no mode', () => {
|
||||
render(<Doc appDetail={{}} />)
|
||||
render(<Doc appDetail={{} as unknown as Parameters<typeof Doc>[0]['appDetail']} />)
|
||||
expect(screen.queryByTestId('template-completion-en')).not.toBeInTheDocument()
|
||||
expect(screen.queryByTestId('template-chat-en')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render nothing when appDetail is null', () => {
|
||||
render(<Doc appDetail={null} />)
|
||||
render(<Doc appDetail={null as unknown as Parameters<typeof Doc>[0]['appDetail']} />)
|
||||
expect(screen.queryByTestId('template-completion-en')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
199
web/app/components/develop/__tests__/toc-panel.spec.tsx
Normal file
@@ -0,0 +1,199 @@
|
||||
import type { TocItem } from '../hooks/use-doc-toc'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import TocPanel from '../toc-panel'
|
||||
|
||||
/**
|
||||
* Unit tests for the TocPanel presentational component.
|
||||
* Covers collapsed/expanded states, item rendering, active section, and callbacks.
|
||||
*/
|
||||
describe('TocPanel', () => {
|
||||
const defaultProps = {
|
||||
toc: [] as TocItem[],
|
||||
activeSection: '',
|
||||
isTocExpanded: false,
|
||||
onToggle: vi.fn(),
|
||||
onItemClick: vi.fn(),
|
||||
}
|
||||
|
||||
const sampleToc: TocItem[] = [
|
||||
{ href: '#introduction', text: 'Introduction' },
|
||||
{ href: '#authentication', text: 'Authentication' },
|
||||
{ href: '#endpoints', text: 'Endpoints' },
|
||||
]
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
// Covers collapsed state rendering
|
||||
describe('collapsed state', () => {
|
||||
it('should render expand button when collapsed', () => {
|
||||
render(<TocPanel {...defaultProps} />)
|
||||
|
||||
expect(screen.getByLabelText('Open table of contents')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should not render nav or toc items when collapsed', () => {
|
||||
render(<TocPanel {...defaultProps} toc={sampleToc} />)
|
||||
|
||||
expect(screen.queryByRole('navigation')).not.toBeInTheDocument()
|
||||
expect(screen.queryByText('Introduction')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onToggle(true) when expand button is clicked', () => {
|
||||
const onToggle = vi.fn()
|
||||
render(<TocPanel {...defaultProps} onToggle={onToggle} />)
|
||||
|
||||
fireEvent.click(screen.getByLabelText('Open table of contents'))
|
||||
|
||||
expect(onToggle).toHaveBeenCalledWith(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Covers expanded state with empty toc
|
||||
describe('expanded state - empty', () => {
|
||||
it('should render nav with empty message when toc is empty', () => {
|
||||
render(<TocPanel {...defaultProps} isTocExpanded />)
|
||||
|
||||
expect(screen.getByRole('navigation')).toBeInTheDocument()
|
||||
expect(screen.getByText('appApi.develop.noContent')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render TOC header with title', () => {
|
||||
render(<TocPanel {...defaultProps} isTocExpanded />)
|
||||
|
||||
expect(screen.getByText('appApi.develop.toc')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should call onToggle(false) when close button is clicked', () => {
|
||||
const onToggle = vi.fn()
|
||||
render(<TocPanel {...defaultProps} isTocExpanded onToggle={onToggle} />)
|
||||
|
||||
fireEvent.click(screen.getByLabelText('Close'))
|
||||
|
||||
expect(onToggle).toHaveBeenCalledWith(false)
|
||||
})
|
||||
})
|
||||
|
||||
// Covers expanded state with toc items
|
||||
describe('expanded state - with items', () => {
|
||||
it('should render all toc items as links', () => {
|
||||
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
|
||||
|
||||
expect(screen.getByText('Introduction')).toBeInTheDocument()
|
||||
expect(screen.getByText('Authentication')).toBeInTheDocument()
|
||||
expect(screen.getByText('Endpoints')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render links with correct href attributes', () => {
|
||||
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
|
||||
|
||||
const links = screen.getAllByRole('link')
|
||||
expect(links).toHaveLength(3)
|
||||
expect(links[0]).toHaveAttribute('href', '#introduction')
|
||||
expect(links[1]).toHaveAttribute('href', '#authentication')
|
||||
expect(links[2]).toHaveAttribute('href', '#endpoints')
|
||||
})
|
||||
|
||||
it('should not render empty message when toc has items', () => {
|
||||
render(<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} />)
|
||||
|
||||
expect(screen.queryByText('appApi.develop.noContent')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// Covers active section highlighting
|
||||
describe('active section', () => {
|
||||
it('should apply active style to the matching toc item', () => {
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="authentication" />,
|
||||
)
|
||||
|
||||
const activeLink = screen.getByText('Authentication').closest('a')
|
||||
expect(activeLink?.className).toContain('font-medium')
|
||||
expect(activeLink?.className).toContain('text-text-primary')
|
||||
})
|
||||
|
||||
it('should apply inactive style to non-matching items', () => {
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="authentication" />,
|
||||
)
|
||||
|
||||
const inactiveLink = screen.getByText('Introduction').closest('a')
|
||||
expect(inactiveLink?.className).toContain('text-text-tertiary')
|
||||
expect(inactiveLink?.className).not.toContain('font-medium')
|
||||
})
|
||||
|
||||
it('should apply active indicator dot to active item', () => {
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="endpoints" />,
|
||||
)
|
||||
|
||||
const activeLink = screen.getByText('Endpoints').closest('a')
|
||||
const activeDot = activeLink?.querySelector('span:first-child')
|
||||
expect(activeDot?.className).toContain('bg-text-accent')
|
||||
})
|
||||
})
|
||||
|
||||
// Covers click event delegation
|
||||
describe('item click handling', () => {
|
||||
it('should call onItemClick with the event and item when a link is clicked', () => {
|
||||
const onItemClick = vi.fn()
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} onItemClick={onItemClick} />,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Authentication'))
|
||||
|
||||
expect(onItemClick).toHaveBeenCalledTimes(1)
|
||||
expect(onItemClick).toHaveBeenCalledWith(
|
||||
expect.any(Object),
|
||||
{ href: '#authentication', text: 'Authentication' },
|
||||
)
|
||||
})
|
||||
|
||||
it('should call onItemClick for each clicked item independently', () => {
|
||||
const onItemClick = vi.fn()
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} onItemClick={onItemClick} />,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Introduction'))
|
||||
fireEvent.click(screen.getByText('Endpoints'))
|
||||
|
||||
expect(onItemClick).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
// Covers edge cases
|
||||
describe('edge cases', () => {
|
||||
it('should handle single item toc', () => {
|
||||
const singleItem = [{ href: '#only', text: 'Only Section' }]
|
||||
render(<TocPanel {...defaultProps} isTocExpanded toc={singleItem} activeSection="only" />)
|
||||
|
||||
expect(screen.getByText('Only Section')).toBeInTheDocument()
|
||||
expect(screen.getAllByRole('link')).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should handle toc items with empty text', () => {
|
||||
const emptyTextItem = [{ href: '#empty', text: '' }]
|
||||
render(<TocPanel {...defaultProps} isTocExpanded toc={emptyTextItem} />)
|
||||
|
||||
expect(screen.getAllByRole('link')).toHaveLength(1)
|
||||
})
|
||||
|
||||
it('should handle active section that does not match any item', () => {
|
||||
render(
|
||||
<TocPanel {...defaultProps} isTocExpanded toc={sampleToc} activeSection="nonexistent" />,
|
||||
)
|
||||
|
||||
// All items should be in inactive style
|
||||
const links = screen.getAllByRole('link')
|
||||
links.forEach((link) => {
|
||||
expect(link.className).toContain('text-text-tertiary')
|
||||
expect(link.className).not.toContain('font-medium')
|
||||
})
|
||||
})
|
||||
})
|
||||
})
|
||||
425
web/app/components/develop/__tests__/use-doc-toc.spec.ts
Normal file
@@ -0,0 +1,425 @@
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { useDocToc } from '../hooks/use-doc-toc'
|
||||
|
||||
/**
|
||||
* Unit tests for the useDocToc custom hook.
|
||||
* Covers TOC extraction, viewport-based expansion, scroll tracking, and click handling.
|
||||
*/
|
||||
describe('useDocToc', () => {
|
||||
const defaultOptions = { appDetail: { mode: 'chat' }, locale: 'en-US' }
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
vi.useRealTimers()
|
||||
|
||||
Object.defineProperty(window, 'matchMedia', {
|
||||
writable: true,
|
||||
value: vi.fn().mockReturnValue({ matches: false }),
|
||||
})
|
||||
})
|
||||
|
||||
// Covers initial state values based on viewport width
|
||||
describe('initial state', () => {
|
||||
it('should set isTocExpanded to false on narrow viewport', () => {
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
expect(result.current.isTocExpanded).toBe(false)
|
||||
expect(result.current.toc).toEqual([])
|
||||
expect(result.current.activeSection).toBe('')
|
||||
})
|
||||
|
||||
it('should set isTocExpanded to true on wide viewport', () => {
|
||||
Object.defineProperty(window, 'matchMedia', {
|
||||
writable: true,
|
||||
value: vi.fn().mockReturnValue({ matches: true }),
|
||||
})
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
expect(result.current.isTocExpanded).toBe(true)
|
||||
})
|
||||
})
|
||||
|
||||
// Covers TOC extraction from DOM article headings
|
||||
describe('TOC extraction', () => {
|
||||
it('should extract toc items from article h2 anchors', async () => {
|
||||
vi.useFakeTimers()
|
||||
const article = document.createElement('article')
|
||||
const h2 = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = '#section-1'
|
||||
anchor.textContent = 'Section 1'
|
||||
h2.appendChild(anchor)
|
||||
article.appendChild(h2)
|
||||
document.body.appendChild(article)
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toEqual([
|
||||
{ href: '#section-1', text: 'Section 1' },
|
||||
])
|
||||
expect(result.current.activeSection).toBe('section-1')
|
||||
|
||||
document.body.removeChild(article)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should return empty toc when no article exists', async () => {
|
||||
vi.useFakeTimers()
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toEqual([])
|
||||
expect(result.current.activeSection).toBe('')
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should skip h2 headings without anchors', async () => {
|
||||
vi.useFakeTimers()
|
||||
const article = document.createElement('article')
|
||||
const h2NoAnchor = document.createElement('h2')
|
||||
h2NoAnchor.textContent = 'No Anchor'
|
||||
article.appendChild(h2NoAnchor)
|
||||
|
||||
const h2WithAnchor = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = '#valid'
|
||||
anchor.textContent = 'Valid'
|
||||
h2WithAnchor.appendChild(anchor)
|
||||
article.appendChild(h2WithAnchor)
|
||||
|
||||
document.body.appendChild(article)
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toHaveLength(1)
|
||||
expect(result.current.toc[0]).toEqual({ href: '#valid', text: 'Valid' })
|
||||
|
||||
document.body.removeChild(article)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should re-extract toc when appDetail changes', async () => {
|
||||
vi.useFakeTimers()
|
||||
const article = document.createElement('article')
|
||||
document.body.appendChild(article)
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
props => useDocToc(props),
|
||||
{ initialProps: defaultOptions },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toEqual([])
|
||||
|
||||
// Add a heading, then change appDetail to trigger re-extraction
|
||||
const h2 = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = '#new-section'
|
||||
anchor.textContent = 'New Section'
|
||||
h2.appendChild(anchor)
|
||||
article.appendChild(h2)
|
||||
|
||||
rerender({ appDetail: { mode: 'workflow' }, locale: 'en-US' })
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toHaveLength(1)
|
||||
|
||||
document.body.removeChild(article)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should re-extract toc when locale changes', async () => {
|
||||
vi.useFakeTimers()
|
||||
const article = document.createElement('article')
|
||||
const h2 = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = '#sec'
|
||||
anchor.textContent = 'Sec'
|
||||
h2.appendChild(anchor)
|
||||
article.appendChild(h2)
|
||||
document.body.appendChild(article)
|
||||
|
||||
const { result, rerender } = renderHook(
|
||||
props => useDocToc(props),
|
||||
{ initialProps: defaultOptions },
|
||||
)
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toHaveLength(1)
|
||||
|
||||
rerender({ appDetail: defaultOptions.appDetail, locale: 'zh-Hans' })
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
// Should still have the toc item after re-extraction
|
||||
expect(result.current.toc).toHaveLength(1)
|
||||
|
||||
document.body.removeChild(article)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
|
||||
// Covers manual toggle via setIsTocExpanded
|
||||
describe('setIsTocExpanded', () => {
|
||||
it('should toggle isTocExpanded state', () => {
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
expect(result.current.isTocExpanded).toBe(false)
|
||||
|
||||
act(() => {
|
||||
result.current.setIsTocExpanded(true)
|
||||
})
|
||||
|
||||
expect(result.current.isTocExpanded).toBe(true)
|
||||
|
||||
act(() => {
|
||||
result.current.setIsTocExpanded(false)
|
||||
})
|
||||
|
||||
expect(result.current.isTocExpanded).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// Covers smooth-scroll click handler
|
||||
describe('handleTocClick', () => {
|
||||
it('should prevent default and scroll to target element', () => {
|
||||
const scrollContainer = document.createElement('div')
|
||||
scrollContainer.className = 'overflow-auto'
|
||||
scrollContainer.scrollTo = vi.fn()
|
||||
document.body.appendChild(scrollContainer)
|
||||
|
||||
const target = document.createElement('div')
|
||||
target.id = 'target-section'
|
||||
Object.defineProperty(target, 'offsetTop', { value: 500 })
|
||||
scrollContainer.appendChild(target)
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
const mockEvent = { preventDefault: vi.fn() } as unknown as React.MouseEvent<HTMLAnchorElement>
|
||||
act(() => {
|
||||
result.current.handleTocClick(mockEvent, { href: '#target-section', text: 'Target' })
|
||||
})
|
||||
|
||||
expect(mockEvent.preventDefault).toHaveBeenCalled()
|
||||
expect(scrollContainer.scrollTo).toHaveBeenCalledWith({
|
||||
top: 420, // 500 - 80 (HEADER_OFFSET)
|
||||
behavior: 'smooth',
|
||||
})
|
||||
|
||||
document.body.removeChild(scrollContainer)
|
||||
})
|
||||
|
||||
it('should do nothing when target element does not exist', () => {
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
const mockEvent = { preventDefault: vi.fn() } as unknown as React.MouseEvent<HTMLAnchorElement>
|
||||
act(() => {
|
||||
result.current.handleTocClick(mockEvent, { href: '#nonexistent', text: 'Missing' })
|
||||
})
|
||||
|
||||
expect(mockEvent.preventDefault).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
// Covers scroll-based active section tracking
|
||||
describe('scroll tracking', () => {
|
||||
// Helper: set up DOM with scroll container, article headings, and matching target elements
|
||||
const setupScrollDOM = (sections: Array<{ id: string, text: string, top: number }>) => {
|
||||
const scrollContainer = document.createElement('div')
|
||||
scrollContainer.className = 'overflow-auto'
|
||||
document.body.appendChild(scrollContainer)
|
||||
|
||||
const article = document.createElement('article')
|
||||
sections.forEach(({ id, text, top }) => {
|
||||
// Heading with anchor for TOC extraction
|
||||
const h2 = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = `#${id}`
|
||||
anchor.textContent = text
|
||||
h2.appendChild(anchor)
|
||||
article.appendChild(h2)
|
||||
|
||||
// Target element for scroll tracking
|
||||
const target = document.createElement('div')
|
||||
target.id = id
|
||||
target.getBoundingClientRect = vi.fn().mockReturnValue({ top })
|
||||
scrollContainer.appendChild(target)
|
||||
})
|
||||
document.body.appendChild(article)
|
||||
|
||||
return {
|
||||
scrollContainer,
|
||||
article,
|
||||
cleanup: () => {
|
||||
document.body.removeChild(scrollContainer)
|
||||
document.body.removeChild(article)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
it('should register scroll listener when toc has items', async () => {
|
||||
vi.useFakeTimers()
|
||||
const { scrollContainer, cleanup } = setupScrollDOM([
|
||||
{ id: 'sec-a', text: 'Section A', top: 0 },
|
||||
])
|
||||
const addSpy = vi.spyOn(scrollContainer, 'addEventListener')
|
||||
const removeSpy = vi.spyOn(scrollContainer, 'removeEventListener')
|
||||
|
||||
const { unmount } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(addSpy).toHaveBeenCalledWith('scroll', expect.any(Function))
|
||||
|
||||
unmount()
|
||||
|
||||
expect(removeSpy).toHaveBeenCalledWith('scroll', expect.any(Function))
|
||||
|
||||
cleanup()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should update activeSection when scrolling past a section', async () => {
|
||||
vi.useFakeTimers()
|
||||
// innerHeight/2 = 384 in jsdom (default 768), so top <= 384 means "scrolled past"
|
||||
const { scrollContainer, cleanup } = setupScrollDOM([
|
||||
{ id: 'intro', text: 'Intro', top: 100 },
|
||||
{ id: 'details', text: 'Details', top: 600 },
|
||||
])
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
// Extract TOC items
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
expect(result.current.toc).toHaveLength(2)
|
||||
expect(result.current.activeSection).toBe('intro')
|
||||
|
||||
// Fire scroll — 'intro' (top=100) is above midpoint, 'details' (top=600) is below
|
||||
await act(async () => {
|
||||
scrollContainer.dispatchEvent(new Event('scroll'))
|
||||
})
|
||||
|
||||
expect(result.current.activeSection).toBe('intro')
|
||||
|
||||
cleanup()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should track the last section above the viewport midpoint', async () => {
|
||||
vi.useFakeTimers()
|
||||
const { scrollContainer, cleanup } = setupScrollDOM([
|
||||
{ id: 'sec-1', text: 'Section 1', top: 50 },
|
||||
{ id: 'sec-2', text: 'Section 2', top: 200 },
|
||||
{ id: 'sec-3', text: 'Section 3', top: 800 },
|
||||
])
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
// Fire scroll — sec-1 (top=50) and sec-2 (top=200) are above midpoint (384),
|
||||
// sec-3 (top=800) is below. The last one above midpoint wins.
|
||||
await act(async () => {
|
||||
scrollContainer.dispatchEvent(new Event('scroll'))
|
||||
})
|
||||
|
||||
expect(result.current.activeSection).toBe('sec-2')
|
||||
|
||||
cleanup()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should not update activeSection when no section is above midpoint', async () => {
|
||||
vi.useFakeTimers()
|
||||
const { scrollContainer, cleanup } = setupScrollDOM([
|
||||
{ id: 'far-away', text: 'Far Away', top: 1000 },
|
||||
])
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
// Initial activeSection is set by extraction
|
||||
const initialSection = result.current.activeSection
|
||||
|
||||
await act(async () => {
|
||||
scrollContainer.dispatchEvent(new Event('scroll'))
|
||||
})
|
||||
|
||||
// Should not change since the element is below midpoint
|
||||
expect(result.current.activeSection).toBe(initialSection)
|
||||
|
||||
cleanup()
|
||||
vi.useRealTimers()
|
||||
})
|
||||
|
||||
it('should handle elements not found in DOM during scroll', async () => {
|
||||
vi.useFakeTimers()
|
||||
const scrollContainer = document.createElement('div')
|
||||
scrollContainer.className = 'overflow-auto'
|
||||
document.body.appendChild(scrollContainer)
|
||||
|
||||
// Article with heading but NO matching target element by id
|
||||
const article = document.createElement('article')
|
||||
const h2 = document.createElement('h2')
|
||||
const anchor = document.createElement('a')
|
||||
anchor.href = '#missing-target'
|
||||
anchor.textContent = 'Missing'
|
||||
h2.appendChild(anchor)
|
||||
article.appendChild(h2)
|
||||
document.body.appendChild(article)
|
||||
|
||||
const { result } = renderHook(() => useDocToc(defaultOptions))
|
||||
|
||||
await act(async () => {
|
||||
vi.runAllTimers()
|
||||
})
|
||||
|
||||
const initialSection = result.current.activeSection
|
||||
|
||||
// Scroll fires but getElementById returns null — no crash, no change
|
||||
await act(async () => {
|
||||
scrollContainer.dispatchEvent(new Event('scroll'))
|
||||
})
|
||||
|
||||
expect(result.current.activeSection).toBe(initialSection)
|
||||
|
||||
document.body.removeChild(scrollContainer)
|
||||
document.body.removeChild(article)
|
||||
vi.useRealTimers()
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -1,12 +1,13 @@
|
||||
'use client'
|
||||
import { RiCloseLine, RiListUnordered } from '@remixicon/react'
|
||||
import { useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import type { ComponentType } from 'react'
|
||||
import type { App, AppSSO } from '@/types/app'
|
||||
import { useMemo } from 'react'
|
||||
import { useLocale } from '@/context/i18n'
|
||||
import useTheme from '@/hooks/use-theme'
|
||||
import { LanguagesSupported } from '@/i18n-config/language'
|
||||
import { getDocLanguage } from '@/i18n-config/language'
|
||||
import { AppModeEnum, Theme } from '@/types/app'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useDocToc } from './hooks/use-doc-toc'
|
||||
import TemplateEn from './template/template.en.mdx'
|
||||
import TemplateJa from './template/template.ja.mdx'
|
||||
import TemplateZh from './template/template.zh.mdx'
|
||||
@@ -19,225 +20,75 @@ import TemplateChatZh from './template/template_chat.zh.mdx'
|
||||
import TemplateWorkflowEn from './template/template_workflow.en.mdx'
|
||||
import TemplateWorkflowJa from './template/template_workflow.ja.mdx'
|
||||
import TemplateWorkflowZh from './template/template_workflow.zh.mdx'
|
||||
import TocPanel from './toc-panel'
|
||||
|
||||
type AppDetail = App & Partial<AppSSO>
|
||||
type PromptVariable = { key: string, name: string }
|
||||
|
||||
type IDocProps = {
|
||||
appDetail: any
|
||||
appDetail: AppDetail
|
||||
}
|
||||
|
||||
// Shared props shape for all MDX template components
|
||||
type TemplateProps = {
|
||||
appDetail: AppDetail
|
||||
variables: PromptVariable[]
|
||||
inputs: Record<string, string>
|
||||
}
|
||||
|
||||
// Lookup table: [appMode][docLanguage] → template component
|
||||
// MDX components accept arbitrary props at runtime but expose a narrow static type,
|
||||
// so we assert the map type to allow passing TemplateProps when rendering.
|
||||
const TEMPLATE_MAP = {
|
||||
[AppModeEnum.CHAT]: { zh: TemplateChatZh, ja: TemplateChatJa, en: TemplateChatEn },
|
||||
[AppModeEnum.AGENT_CHAT]: { zh: TemplateChatZh, ja: TemplateChatJa, en: TemplateChatEn },
|
||||
[AppModeEnum.ADVANCED_CHAT]: { zh: TemplateAdvancedChatZh, ja: TemplateAdvancedChatJa, en: TemplateAdvancedChatEn },
|
||||
[AppModeEnum.WORKFLOW]: { zh: TemplateWorkflowZh, ja: TemplateWorkflowJa, en: TemplateWorkflowEn },
|
||||
[AppModeEnum.COMPLETION]: { zh: TemplateZh, ja: TemplateJa, en: TemplateEn },
|
||||
} as Record<string, Record<string, ComponentType<TemplateProps>>>
|
||||
|
||||
const resolveTemplate = (mode: string | undefined, locale: string): ComponentType<TemplateProps> | null => {
|
||||
if (!mode)
|
||||
return null
|
||||
const langTemplates = TEMPLATE_MAP[mode]
|
||||
if (!langTemplates)
|
||||
return null
|
||||
const docLang = getDocLanguage(locale)
|
||||
return langTemplates[docLang] ?? langTemplates.en ?? null
|
||||
}
|
||||
|
||||
const Doc = ({ appDetail }: IDocProps) => {
|
||||
const locale = useLocale()
|
||||
const { t } = useTranslation()
|
||||
const [toc, setToc] = useState<Array<{ href: string, text: string }>>([])
|
||||
const [isTocExpanded, setIsTocExpanded] = useState(false)
|
||||
const [activeSection, setActiveSection] = useState<string>('')
|
||||
const { theme } = useTheme()
|
||||
const { toc, isTocExpanded, setIsTocExpanded, activeSection, handleTocClick } = useDocToc({ appDetail, locale })
|
||||
|
||||
const variables = appDetail?.model_config?.configs?.prompt_variables || []
|
||||
const inputs = variables.reduce((res: any, variable: any) => {
|
||||
// model_config.configs.prompt_variables exists in the raw API response but is not modeled in ModelConfig type
|
||||
const variables: PromptVariable[] = (
|
||||
appDetail?.model_config as unknown as Record<string, Record<string, PromptVariable[]>> | undefined
|
||||
)?.configs?.prompt_variables ?? []
|
||||
const inputs = variables.reduce<Record<string, string>>((res, variable) => {
|
||||
res[variable.key] = variable.name || ''
|
||||
return res
|
||||
}, {})
|
||||
|
||||
useEffect(() => {
|
||||
const mediaQuery = window.matchMedia('(min-width: 1280px)')
|
||||
setIsTocExpanded(mediaQuery.matches)
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
const extractTOC = () => {
|
||||
const article = document.querySelector('article')
|
||||
if (article) {
|
||||
const headings = article.querySelectorAll('h2')
|
||||
const tocItems = Array.from(headings).map((heading) => {
|
||||
const anchor = heading.querySelector('a')
|
||||
if (anchor) {
|
||||
return {
|
||||
href: anchor.getAttribute('href') || '',
|
||||
text: anchor.textContent || '',
|
||||
}
|
||||
}
|
||||
return null
|
||||
}).filter((item): item is { href: string, text: string } => item !== null)
|
||||
setToc(tocItems)
|
||||
if (tocItems.length > 0)
|
||||
setActiveSection(tocItems[0].href.replace('#', ''))
|
||||
}
|
||||
}
|
||||
|
||||
setTimeout(extractTOC, 0)
|
||||
}, [appDetail, locale])
|
||||
|
||||
useEffect(() => {
|
||||
const handleScroll = () => {
|
||||
const scrollContainer = document.querySelector('.overflow-auto')
|
||||
if (!scrollContainer || toc.length === 0)
|
||||
return
|
||||
|
||||
let currentSection = ''
|
||||
toc.forEach((item) => {
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const rect = element.getBoundingClientRect()
|
||||
if (rect.top <= window.innerHeight / 2)
|
||||
currentSection = targetId
|
||||
}
|
||||
})
|
||||
|
||||
if (currentSection && currentSection !== activeSection)
|
||||
setActiveSection(currentSection)
|
||||
}
|
||||
|
||||
const scrollContainer = document.querySelector('.overflow-auto')
|
||||
if (scrollContainer) {
|
||||
scrollContainer.addEventListener('scroll', handleScroll)
|
||||
handleScroll()
|
||||
return () => scrollContainer.removeEventListener('scroll', handleScroll)
|
||||
}
|
||||
}, [toc, activeSection])
|
||||
|
||||
const handleTocClick = (e: React.MouseEvent<HTMLAnchorElement>, item: { href: string, text: string }) => {
|
||||
e.preventDefault()
|
||||
const targetId = item.href.replace('#', '')
|
||||
const element = document.getElementById(targetId)
|
||||
if (element) {
|
||||
const scrollContainer = document.querySelector('.overflow-auto')
|
||||
if (scrollContainer) {
|
||||
const headerOffset = 80
|
||||
const elementTop = element.offsetTop - headerOffset
|
||||
scrollContainer.scrollTo({
|
||||
top: elementTop,
|
||||
behavior: 'smooth',
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const Template = useMemo(() => {
|
||||
if (appDetail?.mode === AppModeEnum.CHAT || appDetail?.mode === AppModeEnum.AGENT_CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.ADVANCED_CHAT) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateAdvancedChatZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateAdvancedChatJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateAdvancedChatEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.WORKFLOW) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateWorkflowZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateWorkflowJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateWorkflowEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
if (appDetail?.mode === AppModeEnum.COMPLETION) {
|
||||
switch (locale) {
|
||||
case LanguagesSupported[1]:
|
||||
return <TemplateZh appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
case LanguagesSupported[7]:
|
||||
return <TemplateJa appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
default:
|
||||
return <TemplateEn appDetail={appDetail} variables={variables} inputs={inputs} />
|
||||
}
|
||||
}
|
||||
return null
|
||||
}, [appDetail, locale, variables, inputs])
|
||||
const TemplateComponent = useMemo(
|
||||
() => resolveTemplate(appDetail?.mode, locale),
|
||||
[appDetail?.mode, locale],
|
||||
)
|
||||
|
||||
return (
|
||||
<div className="flex">
|
||||
<div className={`fixed right-20 top-32 z-10 transition-all duration-150 ease-out ${isTocExpanded ? 'w-[280px]' : 'w-11'}`}>
|
||||
{isTocExpanded
|
||||
? (
|
||||
<nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
|
||||
<div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
|
||||
<span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
|
||||
{t('develop.toc', { ns: 'appApi' })}
|
||||
</span>
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setIsTocExpanded(false)}
|
||||
className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
|
||||
aria-label="Close"
|
||||
>
|
||||
<RiCloseLine className="h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
</div>
|
||||
|
||||
<div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
|
||||
<div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>
|
||||
|
||||
<div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
|
||||
{toc.length === 0
|
||||
? (
|
||||
<div className="px-2 py-8 text-center text-xs text-text-quaternary">
|
||||
{t('develop.noContent', { ns: 'appApi' })}
|
||||
</div>
|
||||
)
|
||||
: (
|
||||
<ul className="space-y-0.5">
|
||||
{toc.map((item, index) => {
|
||||
const isActive = activeSection === item.href.replace('#', '')
|
||||
return (
|
||||
<li key={index}>
|
||||
<a
|
||||
href={item.href}
|
||||
onClick={e => handleTocClick(e, item)}
|
||||
className={cn(
|
||||
'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
|
||||
isActive
|
||||
? 'bg-state-base-hover font-medium text-text-primary'
|
||||
: 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
|
||||
)}
|
||||
>
|
||||
<span
|
||||
className={cn(
|
||||
'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
|
||||
isActive
|
||||
? 'scale-100 bg-text-accent'
|
||||
: 'scale-75 bg-components-panel-border',
|
||||
)}
|
||||
/>
|
||||
<span className="flex-1 truncate">
|
||||
{item.text}
|
||||
</span>
|
||||
</a>
|
||||
</li>
|
||||
)
|
||||
})}
|
||||
</ul>
|
||||
)}
|
||||
</div>
|
||||
|
||||
<div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
|
||||
</nav>
|
||||
)
|
||||
: (
|
||||
<button
|
||||
type="button"
|
||||
onClick={() => setIsTocExpanded(true)}
|
||||
className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
|
||||
aria-label="Open table of contents"
|
||||
>
|
||||
<RiListUnordered className="h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
|
||||
</button>
|
||||
)}
|
||||
<TocPanel
|
||||
toc={toc}
|
||||
activeSection={activeSection}
|
||||
isTocExpanded={isTocExpanded}
|
||||
onToggle={setIsTocExpanded}
|
||||
onItemClick={handleTocClick}
|
||||
/>
|
||||
</div>
|
||||
<article className={cn('prose-xl prose', theme === Theme.dark && 'prose-invert')}>
|
||||
{Template}
|
||||
{TemplateComponent && <TemplateComponent appDetail={appDetail} variables={variables} inputs={inputs} />}
|
||||
</article>
|
||||
</div>
|
||||
)
|
||||
|
||||
web/app/components/develop/hooks/use-doc-toc.ts (new file, 115 lines)
@@ -0,0 +1,115 @@
import { useCallback, useEffect, useState } from 'react'

export type TocItem = {
  href: string
  text: string
}

type UseDocTocOptions = {
  appDetail: Record<string, unknown> | null
  locale: string
}

const HEADER_OFFSET = 80
const SCROLL_CONTAINER_SELECTOR = '.overflow-auto'

const getTargetId = (href: string) => href.replace('#', '')

/**
 * Extract heading anchors from the rendered <article> as TOC items.
 */
const extractTocFromArticle = (): TocItem[] => {
  const article = document.querySelector('article')
  if (!article)
    return []

  return Array.from(article.querySelectorAll('h2'))
    .map((heading) => {
      const anchor = heading.querySelector('a')
      if (!anchor)
        return null
      return {
        href: anchor.getAttribute('href') || '',
        text: anchor.textContent || '',
      }
    })
    .filter((item): item is TocItem => item !== null)
}

/**
 * Custom hook that manages table-of-contents state:
 * - Extracts TOC items from rendered headings
 * - Tracks the active section on scroll
 * - Auto-expands the panel on wide viewports
 */
export const useDocToc = ({ appDetail, locale }: UseDocTocOptions) => {
  const [toc, setToc] = useState<TocItem[]>([])
  const [isTocExpanded, setIsTocExpanded] = useState(() => {
    if (typeof window === 'undefined')
      return false
    return window.matchMedia('(min-width: 1280px)').matches
  })
  const [activeSection, setActiveSection] = useState<string>('')

  // Re-extract TOC items whenever the doc content changes
  useEffect(() => {
    const timer = setTimeout(() => {
      const tocItems = extractTocFromArticle()
      setToc(tocItems)
      if (tocItems.length > 0)
        setActiveSection(getTargetId(tocItems[0].href))
    }, 0)
    return () => clearTimeout(timer)
  }, [appDetail, locale])

  // Track active section based on scroll position
  useEffect(() => {
    const scrollContainer = document.querySelector(SCROLL_CONTAINER_SELECTOR)
    if (!scrollContainer || toc.length === 0)
      return

    const handleScroll = () => {
      let currentSection = ''
      for (const item of toc) {
        const targetId = getTargetId(item.href)
        const element = document.getElementById(targetId)
        if (element) {
          const rect = element.getBoundingClientRect()
          if (rect.top <= window.innerHeight / 2)
            currentSection = targetId
        }
      }

      if (currentSection && currentSection !== activeSection)
        setActiveSection(currentSection)
    }

    scrollContainer.addEventListener('scroll', handleScroll)
    return () => scrollContainer.removeEventListener('scroll', handleScroll)
  }, [toc, activeSection])

  // Smooth-scroll to a TOC target on click
  const handleTocClick = useCallback((e: React.MouseEvent<HTMLAnchorElement>, item: TocItem) => {
    e.preventDefault()
    const targetId = getTargetId(item.href)
    const element = document.getElementById(targetId)
    if (!element)
      return

    const scrollContainer = document.querySelector(SCROLL_CONTAINER_SELECTOR)
    if (scrollContainer) {
      scrollContainer.scrollTo({
        top: element.offsetTop - HEADER_OFFSET,
        behavior: 'smooth',
      })
    }
  }, [])

  return {
    toc,
    isTocExpanded,
    setIsTocExpanded,
    activeSection,
    handleTocClick,
  }
}
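The Doc component in the diff above consumes this hook roughly as follows (a minimal sketch assembled from that diff, not an additional file in this commit):

  const { toc, isTocExpanded, setIsTocExpanded, activeSection, handleTocClick } = useDocToc({ appDetail, locale })
  // ...
  <TocPanel
    toc={toc}
    activeSection={activeSection}
    isTocExpanded={isTocExpanded}
    onToggle={setIsTocExpanded}
    onItemClick={handleTocClick}
  />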
web/app/components/develop/toc-panel.tsx (new file, 96 lines)
@@ -0,0 +1,96 @@
'use client'
import type { TocItem } from './hooks/use-doc-toc'
import { useTranslation } from 'react-i18next'
import { cn } from '@/utils/classnames'

type TocPanelProps = {
  toc: TocItem[]
  activeSection: string
  isTocExpanded: boolean
  onToggle: (expanded: boolean) => void
  onItemClick: (e: React.MouseEvent<HTMLAnchorElement>, item: TocItem) => void
}

const TocPanel = ({ toc, activeSection, isTocExpanded, onToggle, onItemClick }: TocPanelProps) => {
  const { t } = useTranslation()

  if (!isTocExpanded) {
    return (
      <button
        type="button"
        onClick={() => onToggle(true)}
        className="group flex h-11 w-11 items-center justify-center rounded-full border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg transition-all duration-150 hover:bg-background-default-hover hover:shadow-xl"
        aria-label="Open table of contents"
      >
        <span className="i-ri-list-unordered h-5 w-5 text-text-tertiary transition-colors group-hover:text-text-secondary" />
      </button>
    )
  }

  return (
    <nav className="toc flex max-h-[calc(100vh-150px)] w-full flex-col overflow-hidden rounded-xl border-[0.5px] border-components-panel-border bg-background-default-hover shadow-xl">
      <div className="relative z-10 flex items-center justify-between border-b border-components-panel-border-subtle bg-background-default-hover px-4 py-2.5">
        <span className="text-xs font-medium uppercase tracking-wide text-text-tertiary">
          {t('develop.toc', { ns: 'appApi' })}
        </span>
        <button
          type="button"
          onClick={() => onToggle(false)}
          className="group flex h-6 w-6 items-center justify-center rounded-md transition-colors hover:bg-state-base-hover"
          aria-label="Close"
        >
          <span className="i-ri-close-line h-3 w-3 text-text-quaternary transition-colors group-hover:text-text-secondary" />
        </button>
      </div>

      <div className="from-components-panel-border-subtle/20 pointer-events-none absolute left-0 right-0 top-[41px] z-10 h-2 bg-gradient-to-b to-transparent"></div>
      <div className="pointer-events-none absolute left-0 right-0 top-[43px] z-10 h-3 bg-gradient-to-b from-background-default-hover to-transparent"></div>

      <div className="relative flex-1 overflow-y-auto px-3 py-3 pt-1">
        {toc.length === 0
          ? (
            <div className="px-2 py-8 text-center text-xs text-text-quaternary">
              {t('develop.noContent', { ns: 'appApi' })}
            </div>
          )
          : (
            <ul className="space-y-0.5">
              {toc.map((item) => {
                const isActive = activeSection === item.href.replace('#', '')
                return (
                  <li key={item.href}>
                    <a
                      href={item.href}
                      onClick={e => onItemClick(e, item)}
                      className={cn(
                        'group relative flex items-center rounded-md px-3 py-2 text-[13px] transition-all duration-200',
                        isActive
                          ? 'bg-state-base-hover font-medium text-text-primary'
                          : 'text-text-tertiary hover:bg-state-base-hover hover:text-text-secondary',
                      )}
                    >
                      <span
                        className={cn(
                          'mr-2 h-1.5 w-1.5 rounded-full transition-all duration-200',
                          isActive
                            ? 'scale-100 bg-text-accent'
                            : 'scale-75 bg-components-panel-border',
                        )}
                      />
                      <span className="flex-1 truncate">
                        {item.text}
                      </span>
                    </a>
                  </li>
                )
              })}
            </ul>
          )}
      </div>

      <div className="pointer-events-none absolute bottom-0 left-0 right-0 z-10 h-4 rounded-b-xl bg-gradient-to-t from-background-default-hover to-transparent"></div>
    </nav>
  )
}

export default TocPanel
@@ -61,8 +61,8 @@ vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: () => ({}),
|
||||
}))
|
||||
|
||||
// Mock pluginInstallLimit
|
||||
vi.mock('../../../hooks/use-install-plugin-limit', () => ({
|
||||
// Mock pluginInstallLimit (imported by the useInstallMultiState hook via @/ path)
|
||||
vi.mock('@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit', () => ({
|
||||
pluginInstallLimit: () => ({ canInstall: true }),
|
||||
}))
|
||||
|
||||
|
||||
@@ -0,0 +1,568 @@
|
||||
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '@/app/components/plugins/types'
|
||||
import { act, renderHook, waitFor } from '@testing-library/react'
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { PluginCategoryEnum } from '@/app/components/plugins/types'
|
||||
import { getPluginKey, useInstallMultiState } from '../use-install-multi-state'
|
||||
|
||||
let mockMarketplaceData: ReturnType<typeof createMarketplaceApiData> | null = null
|
||||
let mockMarketplaceError: Error | null = null
|
||||
let mockInstalledInfo: Record<string, VersionInfo> = {}
|
||||
let mockCanInstall = true
|
||||
|
||||
vi.mock('@/service/use-plugins', () => ({
|
||||
useFetchPluginsInMarketPlaceByInfo: () => ({
|
||||
isLoading: false,
|
||||
data: mockMarketplaceData,
|
||||
error: mockMarketplaceError,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/install-plugin/hooks/use-check-installed', () => ({
|
||||
default: () => ({
|
||||
installedInfo: mockInstalledInfo,
|
||||
}),
|
||||
}))
|
||||
|
||||
vi.mock('@/context/global-public-context', () => ({
|
||||
useGlobalPublicStore: () => ({}),
|
||||
}))
|
||||
|
||||
vi.mock('@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit', () => ({
|
||||
pluginInstallLimit: () => ({ canInstall: mockCanInstall }),
|
||||
}))
|
||||
|
||||
const createMockPlugin = (overrides: Partial<Plugin> = {}): Plugin => ({
|
||||
type: 'plugin',
|
||||
org: 'test-org',
|
||||
name: 'Test Plugin',
|
||||
plugin_id: 'test-plugin-id',
|
||||
version: '1.0.0',
|
||||
latest_version: '1.0.0',
|
||||
latest_package_identifier: 'test-pkg-id',
|
||||
icon: 'icon.png',
|
||||
verified: true,
|
||||
label: { 'en-US': 'Test Plugin' },
|
||||
brief: { 'en-US': 'Brief' },
|
||||
description: { 'en-US': 'Description' },
|
||||
introduction: 'Intro',
|
||||
repository: 'https://github.com/test/plugin',
|
||||
category: PluginCategoryEnum.tool,
|
||||
install_count: 100,
|
||||
endpoint: { settings: [] },
|
||||
tags: [],
|
||||
badges: [],
|
||||
verification: { authorized_category: 'community' },
|
||||
from: 'marketplace',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createPackageDependency = (index: number) => ({
|
||||
type: 'package',
|
||||
value: {
|
||||
unique_identifier: `package-plugin-${index}-uid`,
|
||||
manifest: {
|
||||
plugin_unique_identifier: `package-plugin-${index}-uid`,
|
||||
version: '1.0.0',
|
||||
author: 'test-author',
|
||||
icon: 'icon.png',
|
||||
name: `Package Plugin ${index}`,
|
||||
category: PluginCategoryEnum.tool,
|
||||
label: { 'en-US': `Package Plugin ${index}` },
|
||||
description: { 'en-US': 'Test package plugin' },
|
||||
created_at: '2024-01-01',
|
||||
resource: {},
|
||||
plugins: [],
|
||||
verified: true,
|
||||
endpoint: { settings: [], endpoints: [] },
|
||||
model: null,
|
||||
tags: [],
|
||||
agent_strategy: null,
|
||||
meta: { version: '1.0.0' },
|
||||
trigger: {},
|
||||
},
|
||||
},
|
||||
} as unknown as PackageDependency)
|
||||
|
||||
const createMarketplaceDependency = (index: number): GitHubItemAndMarketPlaceDependency => ({
|
||||
type: 'marketplace',
|
||||
value: {
|
||||
marketplace_plugin_unique_identifier: `test-org/plugin-${index}:1.0.0`,
|
||||
plugin_unique_identifier: `plugin-${index}`,
|
||||
version: '1.0.0',
|
||||
},
|
||||
})
|
||||
|
||||
const createGitHubDependency = (index: number): GitHubItemAndMarketPlaceDependency => ({
|
||||
type: 'github',
|
||||
value: {
|
||||
repo: `test-org/plugin-${index}`,
|
||||
version: 'v1.0.0',
|
||||
package: `plugin-${index}.zip`,
|
||||
},
|
||||
})
|
||||
|
||||
const createMarketplaceApiData = (indexes: number[]) => ({
|
||||
data: {
|
||||
list: indexes.map(i => ({
|
||||
plugin: {
|
||||
plugin_id: `test-org/plugin-${i}`,
|
||||
org: 'test-org',
|
||||
name: `Test Plugin ${i}`,
|
||||
version: '1.0.0',
|
||||
latest_version: '1.0.0',
|
||||
},
|
||||
version: {
|
||||
unique_identifier: `plugin-${i}-uid`,
|
||||
},
|
||||
})),
|
||||
},
|
||||
})
|
||||
|
||||
const createDefaultParams = (overrides = {}) => ({
|
||||
allPlugins: [createPackageDependency(0)] as Dependency[],
|
||||
selectedPlugins: [] as Plugin[],
|
||||
onSelect: vi.fn(),
|
||||
onLoadedAllPlugin: vi.fn(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
// ==================== getPluginKey Tests ====================
|
||||
|
||||
describe('getPluginKey', () => {
|
||||
it('should return org/name when org is available', () => {
|
||||
const plugin = createMockPlugin({ org: 'my-org', name: 'my-plugin' })
|
||||
|
||||
expect(getPluginKey(plugin)).toBe('my-org/my-plugin')
|
||||
})
|
||||
|
||||
it('should fall back to author when org is not available', () => {
|
||||
const plugin = createMockPlugin({ org: undefined, author: 'my-author', name: 'my-plugin' })
|
||||
|
||||
expect(getPluginKey(plugin)).toBe('my-author/my-plugin')
|
||||
})
|
||||
|
||||
it('should prefer org over author when both exist', () => {
|
||||
const plugin = createMockPlugin({ org: 'my-org', author: 'my-author', name: 'my-plugin' })
|
||||
|
||||
expect(getPluginKey(plugin)).toBe('my-org/my-plugin')
|
||||
})
|
||||
|
||||
it('should handle undefined plugin', () => {
|
||||
expect(getPluginKey(undefined)).toBe('undefined/undefined')
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== useInstallMultiState Tests ====================
|
||||
|
||||
describe('useInstallMultiState', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockMarketplaceData = null
|
||||
mockMarketplaceError = null
|
||||
mockInstalledInfo = {}
|
||||
mockCanInstall = true
|
||||
})
|
||||
|
||||
// ==================== Initial State ====================
|
||||
describe('Initial State', () => {
|
||||
it('should initialize plugins from package dependencies', () => {
|
||||
const params = createDefaultParams()
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(result.current.plugins).toHaveLength(1)
|
||||
expect(result.current.plugins[0]).toBeDefined()
|
||||
expect(result.current.plugins[0]?.plugin_id).toBe('package-plugin-0-uid')
|
||||
})
|
||||
|
||||
it('should have slots for all dependencies even when no packages exist', () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createGitHubDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
// Array has slots for all dependencies, but unresolved ones are undefined
|
||||
expect(result.current.plugins).toHaveLength(1)
|
||||
expect(result.current.plugins[0]).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should return undefined for non-package items in mixed dependencies', () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createGitHubDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(result.current.plugins).toHaveLength(2)
|
||||
expect(result.current.plugins[0]).toBeDefined()
|
||||
expect(result.current.plugins[1]).toBeUndefined()
|
||||
})
|
||||
|
||||
it('should start with empty errorIndexes', () => {
|
||||
const params = createDefaultParams()
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(result.current.errorIndexes).toEqual([])
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== Marketplace Data Sync ====================
|
||||
describe('Marketplace Data Sync', () => {
|
||||
it('should update plugins when marketplace data loads by ID', async () => {
|
||||
mockMarketplaceData = createMarketplaceApiData([0])
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.plugins[0]).toBeDefined()
|
||||
expect(result.current.plugins[0]?.version).toBe('1.0.0')
|
||||
})
|
||||
})
|
||||
|
||||
it('should update plugins when marketplace data loads by meta', async () => {
|
||||
mockMarketplaceData = createMarketplaceApiData([0])
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
// The "by meta" effect sets plugin_id from version.unique_identifier
|
||||
await waitFor(() => {
|
||||
expect(result.current.plugins[0]).toBeDefined()
|
||||
})
|
||||
})
|
||||
|
||||
it('should add to errorIndexes when marketplace item not found in response', async () => {
|
||||
mockMarketplaceData = { data: { list: [] } }
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.errorIndexes).toContain(0)
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle multiple marketplace plugins', async () => {
|
||||
mockMarketplaceData = createMarketplaceApiData([0, 1])
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createMarketplaceDependency(0),
|
||||
createMarketplaceDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.plugins[0]).toBeDefined()
|
||||
expect(result.current.plugins[1]).toBeDefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== Error Handling ====================
|
||||
describe('Error Handling', () => {
|
||||
it('should mark all marketplace indexes as errors on fetch failure', async () => {
|
||||
mockMarketplaceError = new Error('Fetch failed')
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createMarketplaceDependency(0),
|
||||
createMarketplaceDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.errorIndexes).toContain(0)
|
||||
expect(result.current.errorIndexes).toContain(1)
|
||||
})
|
||||
})
|
||||
|
||||
it('should not affect non-marketplace indexes on marketplace fetch error', async () => {
|
||||
mockMarketplaceError = new Error('Fetch failed')
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createMarketplaceDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(result.current.errorIndexes).toContain(1)
|
||||
expect(result.current.errorIndexes).not.toContain(0)
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== Loaded All Data Notification ====================
|
||||
describe('Loaded All Data Notification', () => {
|
||||
it('should call onLoadedAllPlugin when all data loaded', async () => {
|
||||
const params = createDefaultParams()
|
||||
renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await waitFor(() => {
|
||||
expect(params.onLoadedAllPlugin).toHaveBeenCalledWith(mockInstalledInfo)
|
||||
})
|
||||
})
|
||||
|
||||
it('should not call onLoadedAllPlugin when not all plugins resolved', () => {
|
||||
// GitHub plugin not fetched yet → isLoadedAllData = false
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createGitHubDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(params.onLoadedAllPlugin).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should call onLoadedAllPlugin after all errors are counted', async () => {
|
||||
mockMarketplaceError = new Error('Fetch failed')
|
||||
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createMarketplaceDependency(0)] as Dependency[],
|
||||
})
|
||||
renderHook(() => useInstallMultiState(params))
|
||||
|
||||
// Error fills errorIndexes → isLoadedAllData becomes true
|
||||
await waitFor(() => {
|
||||
expect(params.onLoadedAllPlugin).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== handleGitHubPluginFetched ====================
|
||||
describe('handleGitHubPluginFetched', () => {
|
||||
it('should update plugin at the specified index', async () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createGitHubDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
const mockPlugin = createMockPlugin({ plugin_id: 'github-plugin-0' })
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleGitHubPluginFetched(0)(mockPlugin)
|
||||
})
|
||||
|
||||
expect(result.current.plugins[0]).toEqual(mockPlugin)
|
||||
})
|
||||
|
||||
it('should not affect other plugin slots', async () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createGitHubDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
const originalPlugin0 = result.current.plugins[0]
|
||||
const mockPlugin = createMockPlugin({ plugin_id: 'github-plugin-1' })
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleGitHubPluginFetched(1)(mockPlugin)
|
||||
})
|
||||
|
||||
expect(result.current.plugins[0]).toEqual(originalPlugin0)
|
||||
expect(result.current.plugins[1]).toEqual(mockPlugin)
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== handleGitHubPluginFetchError ====================
|
||||
describe('handleGitHubPluginFetchError', () => {
|
||||
it('should add index to errorIndexes', async () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createGitHubDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleGitHubPluginFetchError(0)()
|
||||
})
|
||||
|
||||
expect(result.current.errorIndexes).toContain(0)
|
||||
})
|
||||
|
||||
it('should accumulate multiple error indexes without stale closure', async () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createGitHubDependency(0),
|
||||
createGitHubDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleGitHubPluginFetchError(0)()
|
||||
})
|
||||
await act(async () => {
|
||||
result.current.handleGitHubPluginFetchError(1)()
|
||||
})
|
||||
|
||||
expect(result.current.errorIndexes).toContain(0)
|
||||
expect(result.current.errorIndexes).toContain(1)
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== getVersionInfo ====================
|
||||
describe('getVersionInfo', () => {
|
||||
it('should return hasInstalled false when plugin not installed', () => {
|
||||
const params = createDefaultParams()
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
const info = result.current.getVersionInfo('unknown/plugin')
|
||||
|
||||
expect(info.hasInstalled).toBe(false)
|
||||
expect(info.installedVersion).toBeUndefined()
|
||||
expect(info.toInstallVersion).toBe('')
|
||||
})
|
||||
|
||||
it('should return hasInstalled true with version when installed', () => {
|
||||
mockInstalledInfo = {
|
||||
'test-author/Package Plugin 0': {
|
||||
installedId: 'installed-1',
|
||||
installedVersion: '0.9.0',
|
||||
uniqueIdentifier: 'uid-1',
|
||||
},
|
||||
}
|
||||
const params = createDefaultParams()
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
const info = result.current.getVersionInfo('test-author/Package Plugin 0')
|
||||
|
||||
expect(info.hasInstalled).toBe(true)
|
||||
expect(info.installedVersion).toBe('0.9.0')
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== handleSelect ====================
|
||||
describe('handleSelect', () => {
|
||||
it('should call onSelect with plugin, index, and installable count', async () => {
|
||||
const params = createDefaultParams()
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleSelect(0)()
|
||||
})
|
||||
|
||||
expect(params.onSelect).toHaveBeenCalledWith(
|
||||
result.current.plugins[0],
|
||||
0,
|
||||
expect.any(Number),
|
||||
)
|
||||
})
|
||||
|
||||
it('should filter installable plugins using pluginInstallLimit', async () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createPackageDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
await act(async () => {
|
||||
result.current.handleSelect(0)()
|
||||
})
|
||||
|
||||
// mockCanInstall is true, so all 2 plugins are installable
|
||||
expect(params.onSelect).toHaveBeenCalledWith(
|
||||
expect.anything(),
|
||||
0,
|
||||
2,
|
||||
)
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== isPluginSelected ====================
|
||||
describe('isPluginSelected', () => {
|
||||
it('should return true when plugin is in selectedPlugins', () => {
|
||||
const selectedPlugin = createMockPlugin({ plugin_id: 'package-plugin-0-uid' })
|
||||
const params = createDefaultParams({
|
||||
selectedPlugins: [selectedPlugin],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(result.current.isPluginSelected(0)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when plugin is not in selectedPlugins', () => {
|
||||
const params = createDefaultParams({ selectedPlugins: [] })
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
expect(result.current.isPluginSelected(0)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when plugin at index is undefined', () => {
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createGitHubDependency(0)] as Dependency[],
|
||||
selectedPlugins: [createMockPlugin()],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
// plugins[0] is undefined (GitHub not yet fetched)
|
||||
expect(result.current.isPluginSelected(0)).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
// ==================== getInstallablePlugins ====================
|
||||
describe('getInstallablePlugins', () => {
|
||||
it('should return all plugins when canInstall is true', () => {
|
||||
mockCanInstall = true
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createPackageDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
|
||||
|
||||
expect(installablePlugins).toHaveLength(2)
|
||||
expect(selectedIndexes).toEqual([0, 1])
|
||||
})
|
||||
|
||||
it('should return empty arrays when canInstall is false', () => {
|
||||
mockCanInstall = false
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [createPackageDependency(0)] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
|
||||
|
||||
expect(installablePlugins).toHaveLength(0)
|
||||
expect(selectedIndexes).toEqual([])
|
||||
})
|
||||
|
||||
it('should skip unloaded (undefined) plugins', () => {
|
||||
mockCanInstall = true
|
||||
const params = createDefaultParams({
|
||||
allPlugins: [
|
||||
createPackageDependency(0),
|
||||
createGitHubDependency(1),
|
||||
] as Dependency[],
|
||||
})
|
||||
const { result } = renderHook(() => useInstallMultiState(params))
|
||||
|
||||
const { installablePlugins, selectedIndexes } = result.current.getInstallablePlugins()
|
||||
|
||||
// Only package plugin is loaded; GitHub not yet fetched
|
||||
expect(installablePlugins).toHaveLength(1)
|
||||
expect(selectedIndexes).toEqual([0])
|
||||
})
|
||||
})
|
||||
})
|
||||
@@ -0,0 +1,230 @@
'use client'

import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '@/app/components/plugins/types'
import { useCallback, useEffect, useMemo, useState } from 'react'
import useCheckInstalled from '@/app/components/plugins/install-plugin/hooks/use-check-installed'
import { pluginInstallLimit } from '@/app/components/plugins/install-plugin/hooks/use-install-plugin-limit'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useFetchPluginsInMarketPlaceByInfo } from '@/service/use-plugins'

type UseInstallMultiStateParams = {
  allPlugins: Dependency[]
  selectedPlugins: Plugin[]
  onSelect: (plugin: Plugin, selectedIndex: number, allCanInstallPluginsLength: number) => void
  onLoadedAllPlugin: (installedInfo: Record<string, VersionInfo>) => void
}

export function getPluginKey(plugin: Plugin | undefined): string {
  return `${plugin?.org || plugin?.author}/${plugin?.name}`
}

function parseMarketplaceIdentifier(identifier: string) {
  const [orgPart, nameAndVersionPart] = identifier.split('@')[0].split('/')
  const [name, version] = nameAndVersionPart.split(':')
  return { organization: orgPart, plugin: name, version }
}
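For reference, the key and identifier formats these two helpers work with (values taken from the spec above; illustrative only):

  getPluginKey({ org: 'my-org', name: 'my-plugin' } as Plugin)
  // => 'my-org/my-plugin' (falls back to author when org is missing)

  parseMarketplaceIdentifier('test-org/plugin-0:1.0.0')
  // => { organization: 'test-org', plugin: 'plugin-0', version: '1.0.0' }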
|
||||
function initPluginsFromDependencies(allPlugins: Dependency[]): (Plugin | undefined)[] {
|
||||
if (!allPlugins.some(d => d.type === 'package'))
|
||||
return []
|
||||
|
||||
return allPlugins.map((d) => {
|
||||
if (d.type !== 'package')
|
||||
return undefined
|
||||
const { manifest, unique_identifier } = (d as PackageDependency).value
|
||||
return {
|
||||
...manifest,
|
||||
plugin_id: unique_identifier,
|
||||
} as unknown as Plugin
|
||||
})
|
||||
}
|
||||
|
||||
export function useInstallMultiState({
|
||||
allPlugins,
|
||||
selectedPlugins,
|
||||
onSelect,
|
||||
onLoadedAllPlugin,
|
||||
}: UseInstallMultiStateParams) {
|
||||
const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
|
||||
|
||||
// Marketplace plugins filtering and index mapping
|
||||
const marketplacePlugins = useMemo(
|
||||
() => allPlugins.filter((d): d is GitHubItemAndMarketPlaceDependency => d.type === 'marketplace'),
|
||||
[allPlugins],
|
||||
)
|
||||
|
||||
const marketPlaceInDSLIndex = useMemo(() => {
|
||||
return allPlugins.reduce<number[]>((acc, d, index) => {
|
||||
if (d.type === 'marketplace')
|
||||
acc.push(index)
|
||||
return acc
|
||||
}, [])
|
||||
}, [allPlugins])
|
||||
|
||||
// Marketplace data fetching: by unique identifier and by meta info
|
  const {
    isLoading: isFetchingById,
    data: infoGetById,
    error: infoByIdError,
  } = useFetchPluginsInMarketPlaceByInfo(
    marketplacePlugins.map(d => parseMarketplaceIdentifier(d.value.marketplace_plugin_unique_identifier!)),
  )

  const {
    isLoading: isFetchingByMeta,
    data: infoByMeta,
    error: infoByMetaError,
  } = useFetchPluginsInMarketPlaceByInfo(
    marketplacePlugins.map(d => d.value!),
  )

  // Derive marketplace plugin data and errors from API responses
  const { marketplacePluginMap, marketplaceErrorIndexes } = useMemo(() => {
    const pluginMap = new Map<number, Plugin>()
    const errorSet = new Set<number>()

    // Process "by ID" response
    if (!isFetchingById && infoGetById?.data.list) {
      const sortedList = marketplacePlugins.map((d) => {
        const id = d.value.marketplace_plugin_unique_identifier?.split(':')[0]
        const retPluginInfo = infoGetById.data.list.find(item => item.plugin.plugin_id === id)?.plugin
        return { ...retPluginInfo, from: d.type } as Plugin
      })
      marketPlaceInDSLIndex.forEach((index, i) => {
        if (sortedList[i]) {
          pluginMap.set(index, {
            ...sortedList[i],
            version: sortedList[i]!.version || sortedList[i]!.latest_version,
          })
        }
        else { errorSet.add(index) }
      })
    }

    // Process "by meta" response (may overwrite "by ID" results)
    if (!isFetchingByMeta && infoByMeta?.data.list) {
      const payloads = infoByMeta.data.list
      marketPlaceInDSLIndex.forEach((index, i) => {
        if (payloads[i]) {
          const item = payloads[i]
          pluginMap.set(index, {
            ...item.plugin,
            plugin_id: item.version.unique_identifier,
          } as Plugin)
        }
        else { errorSet.add(index) }
      })
    }

    // Mark all marketplace indexes as errors on fetch failure
    if (infoByMetaError || infoByIdError)
      marketPlaceInDSLIndex.forEach(index => errorSet.add(index))

    return { marketplacePluginMap: pluginMap, marketplaceErrorIndexes: errorSet }
  }, [isFetchingById, isFetchingByMeta, infoGetById, infoByMeta, infoByMetaError, infoByIdError, marketPlaceInDSLIndex, marketplacePlugins])

  // GitHub-fetched plugins and errors (imperative state from child callbacks)
  const [githubPluginMap, setGithubPluginMap] = useState<Map<number, Plugin>>(() => new Map())
  const [githubErrorIndexes, setGithubErrorIndexes] = useState<number[]>([])

  // Merge all plugin sources into a single array
  const plugins = useMemo(() => {
    const initial = initPluginsFromDependencies(allPlugins)
    const result: (Plugin | undefined)[] = allPlugins.map((_, i) => initial[i])
    marketplacePluginMap.forEach((plugin, index) => {
      result[index] = plugin
    })
    githubPluginMap.forEach((plugin, index) => {
      result[index] = plugin
    })
    return result
  }, [allPlugins, marketplacePluginMap, githubPluginMap])

  // Merge all error sources
  const errorIndexes = useMemo(() => {
    return [...marketplaceErrorIndexes, ...githubErrorIndexes]
  }, [marketplaceErrorIndexes, githubErrorIndexes])

  // Check installed status after all data is loaded
  const isLoadedAllData = (plugins.filter(Boolean).length + errorIndexes.length) === allPlugins.length

  const { installedInfo } = useCheckInstalled({
    pluginIds: plugins.filter(Boolean).map(d => getPluginKey(d)) || [],
    enabled: isLoadedAllData,
  })

  // Notify parent when all plugin data and install info is ready
  useEffect(() => {
    if (isLoadedAllData && installedInfo)
      onLoadedAllPlugin(installedInfo!)
    // eslint-disable-next-line react-hooks/exhaustive-deps
  }, [isLoadedAllData, installedInfo])

  // Callback: handle GitHub plugin fetch success
  const handleGitHubPluginFetched = useCallback((index: number) => {
    return (p: Plugin) => {
      setGithubPluginMap(prev => new Map(prev).set(index, p))
    }
  }, [])

  // Callback: handle GitHub plugin fetch error
  const handleGitHubPluginFetchError = useCallback((index: number) => {
    return () => {
      setGithubErrorIndexes(prev => [...prev, index])
    }
  }, [])

  // Callback: get version info for a plugin by its key
  const getVersionInfo = useCallback((pluginId: string) => {
    const pluginDetail = installedInfo?.[pluginId]
    return {
      hasInstalled: !!pluginDetail,
      installedVersion: pluginDetail?.installedVersion,
      toInstallVersion: '',
    }
  }, [installedInfo])

  // Callback: handle plugin selection
  const handleSelect = useCallback((index: number) => {
    return () => {
      const canSelectPlugins = plugins.filter((p) => {
        const { canInstall } = pluginInstallLimit(p!, systemFeatures)
        return canInstall
      })
      onSelect(plugins[index]!, index, canSelectPlugins.length)
    }
  }, [onSelect, plugins, systemFeatures])

  // Callback: check if a plugin at given index is selected
  const isPluginSelected = useCallback((index: number) => {
    return !!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)
  }, [selectedPlugins, plugins])

  // Callback: get all installable plugins with their indexes
  const getInstallablePlugins = useCallback(() => {
    const selectedIndexes: number[] = []
    const installablePlugins: Plugin[] = []
    allPlugins.forEach((_d, index) => {
      const p = plugins[index]
      if (!p)
        return
      const { canInstall } = pluginInstallLimit(p, systemFeatures)
      if (canInstall) {
        selectedIndexes.push(index)
        installablePlugins.push(p)
      }
    })
    return { selectedIndexes, installablePlugins }
  }, [allPlugins, plugins, systemFeatures])

  return {
    plugins,
    errorIndexes,
    handleGitHubPluginFetched,
    handleGitHubPluginFetchError,
    getVersionInfo,
    handleSelect,
    isPluginSelected,
    getInstallablePlugins,
  }
}
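The helpers referenced above (getPluginKey, parseMarketplaceIdentifier) are defined earlier in this file and are not repeated in this excerpt. Judging from the inline expressions they replace in the old component shown in the next hunk, they plausibly look like the sketch below; treat the exact bodies as an inference, not the actual source.

// Sketch only, inferred from the old inline logic in the component diff that follows.
// getPluginKey mirrors the old `${plugin?.org || plugin?.author}/${plugin?.name}` key format.
export const getPluginKey = (plugin?: Plugin) =>
  `${plugin?.org || plugin?.author}/${plugin?.name}`

// parseMarketplaceIdentifier mirrors the old split of `org/name:version@checksum`
// into the query shape expected by useFetchPluginsInMarketPlaceByInfo.
const parseMarketplaceIdentifier = (identifier: string) => {
  const [orgPart, nameAndVersionPart] = identifier.split('@')[0].split('/')
  const [name, version] = nameAndVersionPart.split(':')
  return { organization: orgPart, plugin: name, version }
}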
@@ -1,16 +1,12 @@
'use client'
import type { Dependency, GitHubItemAndMarketPlaceDependency, PackageDependency, Plugin, VersionInfo } from '../../../types'
import { produce } from 'immer'
import * as React from 'react'
import { useCallback, useEffect, useImperativeHandle, useMemo, useState } from 'react'
import useCheckInstalled from '@/app/components/plugins/install-plugin/hooks/use-check-installed'
import { useGlobalPublicStore } from '@/context/global-public-context'
import { useFetchPluginsInMarketPlaceByInfo } from '@/service/use-plugins'
import { useImperativeHandle } from 'react'
import LoadingError from '../../base/loading-error'
import { pluginInstallLimit } from '../../hooks/use-install-plugin-limit'
import GithubItem from '../item/github-item'
import MarketplaceItem from '../item/marketplace-item'
import PackageItem from '../item/package-item'
import { getPluginKey, useInstallMultiState } from './hooks/use-install-multi-state'

type Props = {
  allPlugins: Dependency[]
@@ -38,206 +34,50 @@ const InstallByDSLList = ({
  isFromMarketPlace,
  ref,
}: Props) => {
  const systemFeatures = useGlobalPublicStore(s => s.systemFeatures)
  // DSL has id, to get plugin info to show more info
  const { isLoading: isFetchingMarketplaceDataById, data: infoGetById, error: infoByIdError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map((d) => {
    const dependecy = (d as GitHubItemAndMarketPlaceDependency).value
    // split org, name, version by / and :
    // and remove @ and its suffix
    const [orgPart, nameAndVersionPart] = dependecy.marketplace_plugin_unique_identifier!.split('@')[0].split('/')
    const [name, version] = nameAndVersionPart.split(':')
    return {
      organization: orgPart,
      plugin: name,
      version,
    }
  }))
  // has meta(org,name,version), to get id
  const { isLoading: isFetchingDataByMeta, data: infoByMeta, error: infoByMetaError } = useFetchPluginsInMarketPlaceByInfo(allPlugins.filter(d => d.type === 'marketplace').map(d => (d as GitHubItemAndMarketPlaceDependency).value!))

  const [plugins, doSetPlugins] = useState<(Plugin | undefined)[]>((() => {
    const hasLocalPackage = allPlugins.some(d => d.type === 'package')
    if (!hasLocalPackage)
      return []

    const _plugins = allPlugins.map((d) => {
      if (d.type === 'package') {
        return {
          ...(d as any).value.manifest,
          plugin_id: (d as any).value.unique_identifier,
        }
      }

      return undefined
    })
    return _plugins
  })())

  const pluginsRef = React.useRef<(Plugin | undefined)[]>(plugins)

  const setPlugins = useCallback((p: (Plugin | undefined)[]) => {
    doSetPlugins(p)
    pluginsRef.current = p
  }, [])

  const [errorIndexes, setErrorIndexes] = useState<number[]>([])

  const handleGitHubPluginFetched = useCallback((index: number) => {
    return (p: Plugin) => {
      const nextPlugins = produce(pluginsRef.current, (draft) => {
        draft[index] = p
      })
      setPlugins(nextPlugins)
    }
  }, [setPlugins])

  const handleGitHubPluginFetchError = useCallback((index: number) => {
    return () => {
      setErrorIndexes([...errorIndexes, index])
    }
  }, [errorIndexes])

  const marketPlaceInDSLIndex = useMemo(() => {
    const res: number[] = []
    allPlugins.forEach((d, index) => {
      if (d.type === 'marketplace')
        res.push(index)
    })
    return res
  }, [allPlugins])

  useEffect(() => {
    if (!isFetchingMarketplaceDataById && infoGetById?.data.list) {
      const sortedList = allPlugins.filter(d => d.type === 'marketplace').map((d) => {
        const p = d as GitHubItemAndMarketPlaceDependency
        const id = p.value.marketplace_plugin_unique_identifier?.split(':')[0]
        const retPluginInfo = infoGetById.data.list.find(item => item.plugin.plugin_id === id)?.plugin
        return { ...retPluginInfo, from: d.type } as Plugin
      })
      const payloads = sortedList
      const failedIndex: number[] = []
      const nextPlugins = produce(pluginsRef.current, (draft) => {
        marketPlaceInDSLIndex.forEach((index, i) => {
          if (payloads[i]) {
            draft[index] = {
              ...payloads[i],
              version: payloads[i]!.version || payloads[i]!.latest_version,
            }
          }
          else { failedIndex.push(index) }
        })
      })
      setPlugins(nextPlugins)

      if (failedIndex.length > 0)
        setErrorIndexes([...errorIndexes, ...failedIndex])
    }
  }, [isFetchingMarketplaceDataById])

  useEffect(() => {
    if (!isFetchingDataByMeta && infoByMeta?.data.list) {
      const payloads = infoByMeta?.data.list
      const failedIndex: number[] = []
      const nextPlugins = produce(pluginsRef.current, (draft) => {
        marketPlaceInDSLIndex.forEach((index, i) => {
          if (payloads[i]) {
            const item = payloads[i]
            draft[index] = {
              ...item.plugin,
              plugin_id: item.version.unique_identifier,
            }
          }
          else {
            failedIndex.push(index)
          }
        })
      })
      setPlugins(nextPlugins)
      if (failedIndex.length > 0)
        setErrorIndexes([...errorIndexes, ...failedIndex])
    }
  }, [isFetchingDataByMeta])

  useEffect(() => {
    // get info all failed
    if (infoByMetaError || infoByIdError)
      setErrorIndexes([...errorIndexes, ...marketPlaceInDSLIndex])
  }, [infoByMetaError, infoByIdError])

  const isLoadedAllData = (plugins.filter(p => !!p).length + errorIndexes.length) === allPlugins.length

  const { installedInfo } = useCheckInstalled({
    pluginIds: plugins?.filter(p => !!p).map((d) => {
      return `${d?.org || d?.author}/${d?.name}`
    }) || [],
    enabled: isLoadedAllData,
  const {
    plugins,
    errorIndexes,
    handleGitHubPluginFetched,
    handleGitHubPluginFetchError,
    getVersionInfo,
    handleSelect,
    isPluginSelected,
    getInstallablePlugins,
  } = useInstallMultiState({
    allPlugins,
    selectedPlugins,
    onSelect,
    onLoadedAllPlugin,
  })

  const getVersionInfo = useCallback((pluginId: string) => {
    const pluginDetail = installedInfo?.[pluginId]
    const hasInstalled = !!pluginDetail
    return {
      hasInstalled,
      installedVersion: pluginDetail?.installedVersion,
      toInstallVersion: '',
    }
  }, [installedInfo])

  useEffect(() => {
    if (isLoadedAllData && installedInfo)
      onLoadedAllPlugin(installedInfo!)
  }, [isLoadedAllData, installedInfo])

  const handleSelect = useCallback((index: number) => {
    return () => {
      const canSelectPlugins = plugins.filter((p) => {
        const { canInstall } = pluginInstallLimit(p!, systemFeatures)
        return canInstall
      })
      onSelect(plugins[index]!, index, canSelectPlugins.length)
    }
  }, [onSelect, plugins, systemFeatures])

  useImperativeHandle(ref, () => ({
    selectAllPlugins: () => {
      const selectedIndexes: number[] = []
      const selectedPlugins: Plugin[] = []
      allPlugins.forEach((d, index) => {
        const p = plugins[index]
        if (!p)
          return
        const { canInstall } = pluginInstallLimit(p, systemFeatures)
        if (canInstall) {
          selectedIndexes.push(index)
          selectedPlugins.push(p)
        }
      })
      onSelectAll(selectedPlugins, selectedIndexes)
    },
    deSelectAllPlugins: () => {
      onDeSelectAll()
      const { installablePlugins, selectedIndexes } = getInstallablePlugins()
      onSelectAll(installablePlugins, selectedIndexes)
    },
    deSelectAllPlugins: onDeSelectAll,
  }))

  return (
    <>
      {allPlugins.map((d, index) => {
        if (errorIndexes.includes(index)) {
          return (
            <LoadingError key={index} />
          )
        }
        if (errorIndexes.includes(index))
          return <LoadingError key={index} />

        const plugin = plugins[index]
        const checked = isPluginSelected(index)
        const versionInfo = getVersionInfo(getPluginKey(plugin))

        if (d.type === 'github') {
          return (
            <GithubItem
              key={index}
              checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
              checked={checked}
              onCheckedChange={handleSelect(index)}
              dependency={d as GitHubItemAndMarketPlaceDependency}
              onFetchedPayload={handleGitHubPluginFetched(index)}
              onFetchError={handleGitHubPluginFetchError(index)}
              versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
              versionInfo={versionInfo}
            />
          )
        }
@@ -246,24 +86,23 @@ const InstallByDSLList = ({
          return (
            <MarketplaceItem
              key={index}
              checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
              checked={checked}
              onCheckedChange={handleSelect(index)}
              payload={{ ...plugin, from: d.type } as Plugin}
              version={(d as GitHubItemAndMarketPlaceDependency).value.version! || plugin?.version || ''}
              versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
              versionInfo={versionInfo}
            />
          )
        }

        // Local package
        return (
          <PackageItem
            key={index}
            checked={!!selectedPlugins.find(p => p.plugin_id === plugins[index]?.plugin_id)}
            checked={checked}
            onCheckedChange={handleSelect(index)}
            payload={d as PackageDependency}
            isFromMarketPlace={isFromMarketPlace}
            versionInfo={getVersionInfo(`${plugin?.org || plugin?.author}/${plugin?.name}`)}
            versionInfo={versionInfo}
          />
        )
      })}
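For context, a parent component is expected to drive bulk selection through the ref handle exposed above (selectAllPlugins / deSelectAllPlugins). A minimal hedged sketch of that wiring follows; the parent component and its handler names are assumptions, only the prop names and the two ref methods come from the component itself.

// Sketch only: illustrative parent wiring, not part of the diff.
const InstallMultiPanel = (props: Omit<Props, 'ref'>) => {
  // Holds the imperative handle exposed by InstallByDSLList via useImperativeHandle.
  const listRef = React.useRef<{ selectAllPlugins: () => void, deSelectAllPlugins: () => void }>(null)

  return (
    <>
      <button onClick={() => listRef.current?.selectAllPlugins()}>Select all</button>
      <button onClick={() => listRef.current?.deSelectAllPlugins()}>Deselect all</button>
      <InstallByDSLList ref={listRef} {...props} />
    </>
  )
}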
@@ -30,19 +30,21 @@ vi.mock('@/context/app-context', () => ({
|
||||
}))
|
||||
|
||||
// Mock API services - only mock external services
|
||||
const mockFetchWorkflowToolDetailByAppID = vi.fn()
|
||||
const mockCreateWorkflowToolProvider = vi.fn()
|
||||
const mockSaveWorkflowToolProvider = vi.fn()
|
||||
vi.mock('@/service/tools', () => ({
|
||||
fetchWorkflowToolDetailByAppID: (...args: unknown[]) => mockFetchWorkflowToolDetailByAppID(...args),
|
||||
createWorkflowToolProvider: (...args: unknown[]) => mockCreateWorkflowToolProvider(...args),
|
||||
saveWorkflowToolProvider: (...args: unknown[]) => mockSaveWorkflowToolProvider(...args),
|
||||
}))
|
||||
|
||||
// Mock invalidate workflow tools hook
|
||||
// Mock service hooks
|
||||
const mockInvalidateAllWorkflowTools = vi.fn()
|
||||
const mockInvalidateWorkflowToolDetailByAppID = vi.fn()
|
||||
const mockUseWorkflowToolDetailByAppID = vi.fn()
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useInvalidateAllWorkflowTools: () => mockInvalidateAllWorkflowTools,
|
||||
useInvalidateWorkflowToolDetailByAppID: () => mockInvalidateWorkflowToolDetailByAppID,
|
||||
useWorkflowToolDetailByAppID: (...args: unknown[]) => mockUseWorkflowToolDetailByAppID(...args),
|
||||
}))
|
||||
|
||||
// Mock Toast - need to verify notification calls
|
||||
@@ -242,7 +244,10 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpenState = false
|
||||
mockIsCurrentWorkspaceManager.mockReturnValue(true)
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(createMockWorkflowToolDetail())
|
||||
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
|
||||
data: enabled ? createMockWorkflowToolDetail() : undefined,
|
||||
isLoading: false,
|
||||
}))
|
||||
})
|
||||
|
||||
// Rendering Tests (REQUIRED)
|
||||
@@ -307,19 +312,17 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
expect(screen.getByText('Please save the workflow first')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render loading state when published and fetching details', async () => {
|
||||
it('should render loading state when published and fetching details', () => {
|
||||
// Arrange
|
||||
mockFetchWorkflowToolDetailByAppID.mockImplementation(() => new Promise(() => { })) // Never resolves
|
||||
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: true })
|
||||
const props = createDefaultConfigureButtonProps({ published: true })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
const loadingElement = document.querySelector('.pt-2')
|
||||
expect(loadingElement).toBeInTheDocument()
|
||||
})
|
||||
const loadingElement = document.querySelector('.pt-2')
|
||||
expect(loadingElement).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should render configure and manage buttons when published', async () => {
|
||||
@@ -381,76 +384,10 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
// Act & Assert
|
||||
expect(() => render(<WorkflowToolConfigureButton {...props} />)).not.toThrow()
|
||||
})
|
||||
|
||||
it('should call handlePublish when updating workflow tool', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
const handlePublish = vi.fn().mockResolvedValue(undefined)
|
||||
mockSaveWorkflowToolProvider.mockResolvedValue({})
|
||||
const props = createDefaultConfigureButtonProps({ published: true, handlePublish })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
|
||||
})
|
||||
await user.click(screen.getByText('workflow.common.configure'))
|
||||
|
||||
// Fill required fields and save
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
const saveButton = screen.getByText('common.operation.save')
|
||||
await user.click(saveButton)
|
||||
|
||||
// Confirm in modal
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('tools.createTool.confirmTitle')).toBeInTheDocument()
|
||||
})
|
||||
await user.click(screen.getByText('common.operation.confirm'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(handlePublish).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// State Management Tests
|
||||
describe('State Management', () => {
|
||||
it('should fetch detail when published and mount', async () => {
|
||||
// Arrange
|
||||
const props = createDefaultConfigureButtonProps({ published: true })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledWith('workflow-app-123')
|
||||
})
|
||||
})
|
||||
|
||||
it('should refetch detail when detailNeedUpdate changes to true', async () => {
|
||||
// Arrange
|
||||
const props = createDefaultConfigureButtonProps({ published: true, detailNeedUpdate: false })
|
||||
|
||||
// Act
|
||||
const { rerender } = render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledTimes(1)
|
||||
})
|
||||
|
||||
// Rerender with detailNeedUpdate true
|
||||
rerender(<WorkflowToolConfigureButton {...props} detailNeedUpdate={true} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalledTimes(2)
|
||||
})
|
||||
})
|
||||
|
||||
// Modal behavior tests
|
||||
describe('Modal Behavior', () => {
|
||||
it('should toggle modal visibility', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
@@ -513,85 +450,6 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
})
|
||||
})
|
||||
|
||||
// Memoization Tests
|
||||
describe('Memoization - outdated detection', () => {
|
||||
it('should detect outdated when parameter count differs', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({
|
||||
published: true,
|
||||
inputs: [
|
||||
createMockInputVar({ variable: 'test_var' }),
|
||||
createMockInputVar({ variable: 'extra_var' }),
|
||||
],
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert - should show outdated warning
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should detect outdated when parameter not found', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({
|
||||
published: true,
|
||||
inputs: [createMockInputVar({ variable: 'different_var' })],
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should detect outdated when required property differs', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({
|
||||
published: true,
|
||||
inputs: [createMockInputVar({ variable: 'test_var', required: false })], // Detail has required: true
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.workflowAsToolTip')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should not show outdated when parameters match', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({
|
||||
published: true,
|
||||
inputs: [createMockInputVar({ variable: 'test_var', required: true, type: InputVarType.textInput })],
|
||||
})
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
|
||||
})
|
||||
expect(screen.queryByText('workflow.common.workflowAsToolTip')).not.toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
// User Interactions Tests
|
||||
describe('User Interactions', () => {
|
||||
it('should navigate to tools page when manage button clicked', async () => {
|
||||
@@ -611,174 +469,10 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
// Assert
|
||||
expect(mockPush).toHaveBeenCalledWith('/tools?category=workflow')
|
||||
})
|
||||
|
||||
it('should create workflow tool provider on first publish', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
mockCreateWorkflowToolProvider.mockResolvedValue({})
|
||||
const props = createDefaultConfigureButtonProps()
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Open modal
|
||||
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
|
||||
await user.click(triggerArea!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Fill in required name field
|
||||
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
|
||||
await user.type(nameInput, 'my_tool')
|
||||
|
||||
// Click save
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should show success toast after creating workflow tool', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
mockCreateWorkflowToolProvider.mockResolvedValue({})
|
||||
const props = createDefaultConfigureButtonProps()
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
|
||||
await user.click(triggerArea!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
|
||||
await user.type(nameInput, 'my_tool')
|
||||
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'success',
|
||||
message: 'common.api.actionSuccess',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should show error toast when create fails', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
mockCreateWorkflowToolProvider.mockRejectedValue(new Error('Create failed'))
|
||||
const props = createDefaultConfigureButtonProps()
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
|
||||
await user.click(triggerArea!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
|
||||
await user.type(nameInput, 'my_tool')
|
||||
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockToastNotify).toHaveBeenCalledWith({
|
||||
type: 'error',
|
||||
message: 'Create failed',
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
it('should call onRefreshData after successful create', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
const onRefreshData = vi.fn()
|
||||
mockCreateWorkflowToolProvider.mockResolvedValue({})
|
||||
const props = createDefaultConfigureButtonProps({ onRefreshData })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
|
||||
await user.click(triggerArea!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
|
||||
await user.type(nameInput, 'my_tool')
|
||||
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(onRefreshData).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
|
||||
it('should invalidate all workflow tools after successful create', async () => {
|
||||
// Arrange
|
||||
const user = userEvent.setup()
|
||||
mockCreateWorkflowToolProvider.mockResolvedValue({})
|
||||
const props = createDefaultConfigureButtonProps()
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
const triggerArea = screen.getByText('workflow.common.workflowAsTool').closest('.flex')
|
||||
await user.click(triggerArea!)
|
||||
|
||||
await waitFor(() => {
|
||||
expect(screen.getByTestId('drawer')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
const nameInput = screen.getByPlaceholderText('tools.createTool.nameForToolCallPlaceHolder')
|
||||
await user.type(nameInput, 'my_tool')
|
||||
|
||||
await user.click(screen.getByText('common.operation.save'))
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
// Edge Cases (REQUIRED)
|
||||
describe('Edge Cases', () => {
|
||||
it('should handle API returning undefined', async () => {
|
||||
// Arrange - API returns undefined (simulating empty response or handled error)
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(undefined)
|
||||
const props = createDefaultConfigureButtonProps({ published: true })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert - should not crash and wait for API call
|
||||
await waitFor(() => {
|
||||
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
// Component should still render without crashing
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.workflowAsTool')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle rapid publish/unpublish state changes', async () => {
|
||||
// Arrange
|
||||
const props = createDefaultConfigureButtonProps({ published: false })
|
||||
@@ -798,35 +492,7 @@ describe('WorkflowToolConfigureButton', () => {
|
||||
})
|
||||
|
||||
// Assert - should not crash
|
||||
expect(mockFetchWorkflowToolDetailByAppID).toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should handle detail with empty parameters', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
detail.tool.parameters = []
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({ published: true, inputs: [] })
|
||||
|
||||
// Act
|
||||
render(<WorkflowToolConfigureButton {...props} />)
|
||||
|
||||
// Assert
|
||||
await waitFor(() => {
|
||||
expect(screen.getByText('workflow.common.configure')).toBeInTheDocument()
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle detail with undefined output_schema', async () => {
|
||||
// Arrange
|
||||
const detail = createMockWorkflowToolDetail()
|
||||
// @ts-expect-error - testing undefined case
|
||||
detail.tool.output_schema = undefined
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(detail)
|
||||
const props = createDefaultConfigureButtonProps({ published: true })
|
||||
|
||||
// Act & Assert
|
||||
expect(() => render(<WorkflowToolConfigureButton {...props} />)).not.toThrow()
|
||||
expect(screen.getByText('workflow.common.workflowAsTool')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should handle paragraph type input conversion', async () => {
|
||||
@@ -1853,7 +1519,10 @@ describe('Integration Tests', () => {
|
||||
vi.clearAllMocks()
|
||||
mockPortalOpenState = false
|
||||
mockIsCurrentWorkspaceManager.mockReturnValue(true)
|
||||
mockFetchWorkflowToolDetailByAppID.mockResolvedValue(createMockWorkflowToolDetail())
|
||||
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
|
||||
data: enabled ? createMockWorkflowToolDetail() : undefined,
|
||||
isLoading: false,
|
||||
}))
|
||||
})
|
||||
|
||||
// Complete workflow: open modal -> fill form -> save
|
||||
|
||||
@@ -1,22 +1,16 @@
|
||||
'use client'
|
||||
import type { Emoji, WorkflowToolProviderOutputParameter, WorkflowToolProviderParameter, WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
|
||||
import type { Emoji } from '@/app/components/tools/types'
|
||||
import type { InputVar, Variable } from '@/app/components/workflow/types'
|
||||
import type { PublishWorkflowParams } from '@/types/workflow'
|
||||
import { RiArrowRightUpLine, RiHammerLine } from '@remixicon/react'
|
||||
import { useRouter } from 'next/navigation'
|
||||
import * as React from 'react'
|
||||
import { useCallback, useEffect, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Button from '@/app/components/base/button'
|
||||
import Loading from '@/app/components/base/loading'
|
||||
import Toast from '@/app/components/base/toast'
|
||||
import Indicator from '@/app/components/header/indicator'
|
||||
import WorkflowToolModal from '@/app/components/tools/workflow-tool'
|
||||
import { useAppContext } from '@/context/app-context'
|
||||
import { createWorkflowToolProvider, fetchWorkflowToolDetailByAppID, saveWorkflowToolProvider } from '@/service/tools'
|
||||
import { useInvalidateAllWorkflowTools } from '@/service/use-tools'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import Divider from '../../base/divider'
|
||||
import { useConfigureButton } from './hooks/use-configure-button'
|
||||
|
||||
type Props = {
|
||||
disabled: boolean
|
||||
@@ -48,153 +42,29 @@ const WorkflowToolConfigureButton = ({
|
||||
disabledReason,
|
||||
}: Props) => {
|
||||
const { t } = useTranslation()
|
||||
const router = useRouter()
|
||||
const [showModal, setShowModal] = useState(false)
|
||||
const [isLoading, setIsLoading] = useState(false)
|
||||
const [detail, setDetail] = useState<WorkflowToolProviderResponse>()
|
||||
const { isCurrentWorkspaceManager } = useAppContext()
|
||||
const invalidateAllWorkflowTools = useInvalidateAllWorkflowTools()
|
||||
|
||||
const outdated = useMemo(() => {
|
||||
if (!detail)
|
||||
return false
|
||||
if (detail.tool.parameters.length !== inputs?.length) {
|
||||
return true
|
||||
}
|
||||
else {
|
||||
for (const item of inputs || []) {
|
||||
const param = detail.tool.parameters.find(toolParam => toolParam.name === item.variable)
|
||||
if (!param) {
|
||||
return true
|
||||
}
|
||||
else if (param.required !== item.required) {
|
||||
return true
|
||||
}
|
||||
else {
|
||||
if (item.type === 'paragraph' && param.type !== 'string')
|
||||
return true
|
||||
if (item.type === 'text-input' && param.type !== 'string')
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}, [detail, inputs])
|
||||
|
||||
const payload = useMemo(() => {
|
||||
let parameters: WorkflowToolProviderParameter[] = []
|
||||
let outputParameters: WorkflowToolProviderOutputParameter[] = []
|
||||
|
||||
if (!published) {
|
||||
parameters = (inputs || []).map((item) => {
|
||||
return {
|
||||
name: item.variable,
|
||||
description: '',
|
||||
form: 'llm',
|
||||
required: item.required,
|
||||
type: item.type,
|
||||
}
|
||||
})
|
||||
outputParameters = (outputs || []).map((item) => {
|
||||
return {
|
||||
name: item.variable,
|
||||
description: '',
|
||||
type: item.value_type,
|
||||
}
|
||||
})
|
||||
}
|
||||
else if (detail && detail.tool) {
|
||||
parameters = (inputs || []).map((item) => {
|
||||
return {
|
||||
name: item.variable,
|
||||
required: item.required,
|
||||
type: item.type === 'paragraph' ? 'string' : item.type,
|
||||
description: detail.tool.parameters.find(param => param.name === item.variable)?.llm_description || '',
|
||||
form: detail.tool.parameters.find(param => param.name === item.variable)?.form || 'llm',
|
||||
}
|
||||
})
|
||||
outputParameters = (outputs || []).map((item) => {
|
||||
const found = detail.tool.output_schema?.properties?.[item.variable]
|
||||
return {
|
||||
name: item.variable,
|
||||
description: found ? found.description : '',
|
||||
type: item.value_type,
|
||||
}
|
||||
})
|
||||
}
|
||||
return {
|
||||
icon: detail?.icon || icon,
|
||||
label: detail?.label || name,
|
||||
name: detail?.name || '',
|
||||
description: detail?.description || description,
|
||||
parameters,
|
||||
outputParameters,
|
||||
labels: detail?.tool?.labels || [],
|
||||
privacy_policy: detail?.privacy_policy || '',
|
||||
...(published
|
||||
? {
|
||||
workflow_tool_id: detail?.workflow_tool_id,
|
||||
}
|
||||
: {
|
||||
workflow_app_id: workflowAppId,
|
||||
}),
|
||||
}
|
||||
}, [detail, published, workflowAppId, icon, name, description, inputs])
|
||||
|
||||
const getDetail = useCallback(async (workflowAppId: string) => {
|
||||
setIsLoading(true)
|
||||
const res = await fetchWorkflowToolDetailByAppID(workflowAppId)
|
||||
setDetail(res)
|
||||
setIsLoading(false)
|
||||
}, [])
|
||||
|
||||
useEffect(() => {
|
||||
if (published)
|
||||
getDetail(workflowAppId)
|
||||
}, [getDetail, published, workflowAppId])
|
||||
|
||||
useEffect(() => {
|
||||
if (detailNeedUpdate)
|
||||
getDetail(workflowAppId)
|
||||
}, [detailNeedUpdate, getDetail, workflowAppId])
|
||||
|
||||
const createHandle = async (data: WorkflowToolProviderRequest & { workflow_app_id: string }) => {
|
||||
try {
|
||||
await createWorkflowToolProvider(data)
|
||||
invalidateAllWorkflowTools()
|
||||
onRefreshData?.()
|
||||
getDetail(workflowAppId)
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('api.actionSuccess', { ns: 'common' }),
|
||||
})
|
||||
setShowModal(false)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: (e as Error).message })
|
||||
}
|
||||
}
|
||||
|
||||
const updateWorkflowToolProvider = async (data: WorkflowToolProviderRequest & Partial<{
|
||||
workflow_app_id: string
|
||||
workflow_tool_id: string
|
||||
}>) => {
|
||||
try {
|
||||
await handlePublish()
|
||||
await saveWorkflowToolProvider(data)
|
||||
onRefreshData?.()
|
||||
invalidateAllWorkflowTools()
|
||||
getDetail(workflowAppId)
|
||||
Toast.notify({
|
||||
type: 'success',
|
||||
message: t('api.actionSuccess', { ns: 'common' }),
|
||||
})
|
||||
setShowModal(false)
|
||||
}
|
||||
catch (e) {
|
||||
Toast.notify({ type: 'error', message: (e as Error).message })
|
||||
}
|
||||
}
|
||||
const {
|
||||
showModal,
|
||||
isLoading,
|
||||
outdated,
|
||||
payload,
|
||||
isCurrentWorkspaceManager,
|
||||
openModal,
|
||||
closeModal,
|
||||
handleCreate,
|
||||
handleUpdate,
|
||||
navigateToTools,
|
||||
} = useConfigureButton({
|
||||
published,
|
||||
detailNeedUpdate,
|
||||
workflowAppId,
|
||||
icon,
|
||||
name,
|
||||
description,
|
||||
inputs,
|
||||
outputs,
|
||||
handlePublish,
|
||||
onRefreshData,
|
||||
})
|
||||
|
||||
return (
|
||||
<>
|
||||
@@ -210,17 +80,17 @@ const WorkflowToolConfigureButton = ({
|
||||
? (
|
||||
<div
|
||||
className="flex items-center justify-start gap-2 p-2 pl-2.5"
|
||||
onClick={() => !disabled && !published && setShowModal(true)}
|
||||
onClick={() => !disabled && !published && openModal()}
|
||||
>
|
||||
<RiHammerLine className={cn('relative h-4 w-4 text-text-secondary', !disabled && !published && 'group-hover:text-text-accent')} />
|
||||
<div
|
||||
title={t('common.workflowAsTool', { ns: 'workflow' }) || ''}
|
||||
className={cn('system-sm-medium shrink grow basis-0 truncate text-text-secondary', !disabled && !published && 'group-hover:text-text-accent')}
|
||||
className={cn('shrink grow basis-0 truncate text-text-secondary system-sm-medium', !disabled && !published && 'group-hover:text-text-accent')}
|
||||
>
|
||||
{t('common.workflowAsTool', { ns: 'workflow' })}
|
||||
</div>
|
||||
{!published && (
|
||||
<span className="system-2xs-medium-uppercase shrink-0 rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-1 py-0.5 text-text-tertiary">
|
||||
<span className="shrink-0 rounded-[5px] border border-divider-deep bg-components-badge-bg-dimm px-1 py-0.5 text-text-tertiary system-2xs-medium-uppercase">
|
||||
{t('common.configureRequired', { ns: 'workflow' })}
|
||||
</span>
|
||||
)}
|
||||
@@ -233,7 +103,7 @@ const WorkflowToolConfigureButton = ({
|
||||
<RiHammerLine className="h-4 w-4 text-text-tertiary" />
|
||||
<div
|
||||
title={t('common.workflowAsTool', { ns: 'workflow' }) || ''}
|
||||
className="system-sm-medium shrink grow basis-0 truncate text-text-tertiary"
|
||||
className="shrink grow basis-0 truncate text-text-tertiary system-sm-medium"
|
||||
>
|
||||
{t('common.workflowAsTool', { ns: 'workflow' })}
|
||||
</div>
|
||||
@@ -250,7 +120,7 @@ const WorkflowToolConfigureButton = ({
|
||||
<Button
|
||||
size="small"
|
||||
className="w-[140px]"
|
||||
onClick={() => setShowModal(true)}
|
||||
onClick={openModal}
|
||||
disabled={!isCurrentWorkspaceManager || disabled}
|
||||
>
|
||||
{t('common.configure', { ns: 'workflow' })}
|
||||
@@ -259,7 +129,7 @@ const WorkflowToolConfigureButton = ({
|
||||
<Button
|
||||
size="small"
|
||||
className="w-[140px]"
|
||||
onClick={() => router.push('/tools?category=workflow')}
|
||||
onClick={navigateToTools}
|
||||
disabled={disabled}
|
||||
>
|
||||
{t('common.manageInTools', { ns: 'workflow' })}
|
||||
@@ -280,9 +150,9 @@ const WorkflowToolConfigureButton = ({
|
||||
<WorkflowToolModal
|
||||
isAdd={!published}
|
||||
payload={payload}
|
||||
onHide={() => setShowModal(false)}
|
||||
onCreate={createHandle}
|
||||
onSave={updateWorkflowToolProvider}
|
||||
onHide={closeModal}
|
||||
onCreate={handleCreate}
|
||||
onSave={handleUpdate}
|
||||
/>
|
||||
)}
|
||||
</>
|
||||
|
||||
@@ -0,0 +1,541 @@
|
||||
import type { WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
|
||||
import type { InputVar, Variable } from '@/app/components/workflow/types'
|
||||
import { act, renderHook } from '@testing-library/react'
|
||||
import { InputVarType } from '@/app/components/workflow/types'
|
||||
import { isParametersOutdated, useConfigureButton } from '../use-configure-button'
|
||||
|
||||
const mockPush = vi.fn()
|
||||
vi.mock('next/navigation', () => ({
|
||||
useRouter: () => ({ push: mockPush }),
|
||||
}))
|
||||
|
||||
const mockIsCurrentWorkspaceManager = vi.fn(() => true)
|
||||
vi.mock('@/context/app-context', () => ({
|
||||
useAppContext: () => ({
|
||||
isCurrentWorkspaceManager: mockIsCurrentWorkspaceManager(),
|
||||
}),
|
||||
}))
|
||||
|
||||
const mockCreateWorkflowToolProvider = vi.fn()
|
||||
const mockSaveWorkflowToolProvider = vi.fn()
|
||||
vi.mock('@/service/tools', () => ({
|
||||
createWorkflowToolProvider: (...args: unknown[]) => mockCreateWorkflowToolProvider(...args),
|
||||
saveWorkflowToolProvider: (...args: unknown[]) => mockSaveWorkflowToolProvider(...args),
|
||||
}))
|
||||
|
||||
const mockInvalidateAllWorkflowTools = vi.fn()
|
||||
const mockInvalidateWorkflowToolDetailByAppID = vi.fn()
|
||||
const mockUseWorkflowToolDetailByAppID = vi.fn()
|
||||
vi.mock('@/service/use-tools', () => ({
|
||||
useInvalidateAllWorkflowTools: () => mockInvalidateAllWorkflowTools,
|
||||
useInvalidateWorkflowToolDetailByAppID: () => mockInvalidateWorkflowToolDetailByAppID,
|
||||
useWorkflowToolDetailByAppID: (...args: unknown[]) => mockUseWorkflowToolDetailByAppID(...args),
|
||||
}))
|
||||
|
||||
const mockToastNotify = vi.fn()
|
||||
vi.mock('@/app/components/base/toast', () => ({
|
||||
default: {
|
||||
notify: (options: { type: string, message: string }) => mockToastNotify(options),
|
||||
},
|
||||
}))
|
||||
|
||||
const createMockEmoji = () => ({ content: '🔧', background: '#ffffff' })
|
||||
|
||||
const createMockInputVar = (overrides: Partial<InputVar> = {}): InputVar => ({
|
||||
variable: 'test_var',
|
||||
label: 'Test Variable',
|
||||
type: InputVarType.textInput,
|
||||
required: true,
|
||||
max_length: 100,
|
||||
options: [],
|
||||
...overrides,
|
||||
} as InputVar)
|
||||
|
||||
const createMockVariable = (overrides: Partial<Variable> = {}): Variable => ({
|
||||
variable: 'output_var',
|
||||
value_type: 'string',
|
||||
...overrides,
|
||||
} as Variable)
|
||||
|
||||
const createMockDetail = (overrides: Partial<WorkflowToolProviderResponse> = {}): WorkflowToolProviderResponse => ({
|
||||
workflow_app_id: 'app-123',
|
||||
workflow_tool_id: 'tool-456',
|
||||
label: 'Test Tool',
|
||||
name: 'test_tool',
|
||||
icon: createMockEmoji(),
|
||||
description: 'A test workflow tool',
|
||||
synced: true,
|
||||
tool: {
|
||||
author: 'test-author',
|
||||
name: 'test_tool',
|
||||
label: { en_US: 'Test Tool', zh_Hans: '测试工具' },
|
||||
description: { en_US: 'Test description', zh_Hans: '测试描述' },
|
||||
labels: ['label1'],
|
||||
parameters: [
|
||||
{
|
||||
name: 'test_var',
|
||||
label: { en_US: 'Test Variable', zh_Hans: '测试变量' },
|
||||
human_description: { en_US: 'A test variable', zh_Hans: '测试变量' },
|
||||
type: 'string',
|
||||
form: 'llm',
|
||||
llm_description: 'Test variable description',
|
||||
required: true,
|
||||
default: '',
|
||||
},
|
||||
],
|
||||
output_schema: {
|
||||
type: 'object',
|
||||
properties: {
|
||||
output_var: { type: 'string', description: 'Output description' },
|
||||
},
|
||||
},
|
||||
},
|
||||
privacy_policy: 'https://example.com/privacy',
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createDefaultOptions = (overrides = {}) => ({
|
||||
published: false,
|
||||
detailNeedUpdate: false,
|
||||
workflowAppId: 'app-123',
|
||||
icon: createMockEmoji(),
|
||||
name: 'Test Workflow',
|
||||
description: 'Test workflow description',
|
||||
inputs: [createMockInputVar()],
|
||||
outputs: [createMockVariable()],
|
||||
handlePublish: vi.fn().mockResolvedValue(undefined),
|
||||
onRefreshData: vi.fn(),
|
||||
...overrides,
|
||||
})
|
||||
|
||||
const createMockRequest = (extra: Record<string, string> = {}): WorkflowToolProviderRequest & Record<string, unknown> => ({
|
||||
name: 'test_tool',
|
||||
description: 'desc',
|
||||
icon: createMockEmoji(),
|
||||
label: 'Test Tool',
|
||||
parameters: [{ name: 'test_var', description: '', form: 'llm' }],
|
||||
labels: [],
|
||||
privacy_policy: '',
|
||||
...extra,
|
||||
})
|
||||
|
||||
describe('isParametersOutdated', () => {
|
||||
it('should return false when detail is undefined', () => {
|
||||
expect(isParametersOutdated(undefined, [createMockInputVar()])).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when parameter count differs', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [
|
||||
createMockInputVar({ variable: 'test_var' }),
|
||||
createMockInputVar({ variable: 'extra_var' }),
|
||||
]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when parameter is not found in detail', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [createMockInputVar({ variable: 'unknown_var' })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when required property differs', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', required: false })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when paragraph type does not match string', () => {
|
||||
const detail = createMockDetail()
|
||||
detail.tool.parameters[0].type = 'number'
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return true when text-input type does not match string', () => {
|
||||
const detail = createMockDetail()
|
||||
detail.tool.parameters[0].type = 'number'
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.textInput })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(true)
|
||||
})
|
||||
|
||||
it('should return false when paragraph type matches string', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when text-input type matches string', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', type: InputVarType.textInput })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return false when all parameters match', () => {
|
||||
const detail = createMockDetail()
|
||||
const inputs = [createMockInputVar({ variable: 'test_var', required: true })]
|
||||
expect(isParametersOutdated(detail, inputs)).toBe(false)
|
||||
})
|
||||
|
||||
it('should handle undefined inputs with empty detail parameters', () => {
|
||||
const detail = createMockDetail()
|
||||
detail.tool.parameters = []
|
||||
expect(isParametersOutdated(detail, undefined)).toBe(false)
|
||||
})
|
||||
|
||||
it('should return true when inputs undefined but detail has parameters', () => {
|
||||
const detail = createMockDetail()
|
||||
expect(isParametersOutdated(detail, undefined)).toBe(true)
|
||||
})
|
||||
|
||||
it('should handle empty inputs and empty detail parameters', () => {
|
||||
const detail = createMockDetail()
|
||||
detail.tool.parameters = []
|
||||
expect(isParametersOutdated(detail, [])).toBe(false)
|
||||
})
|
||||
})
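The helper exercised above is exported from use-configure-button, but its body is not included in this excerpt. Based on the inline `outdated` memo it replaces in the old component and on the cases asserted here, a plausible sketch is shown below; treat it as an inference, not the actual implementation.

// Sketch only, inferred from the removed inline logic and the tests above.
export const isParametersOutdated = (detail: WorkflowToolProviderResponse | undefined, inputs?: InputVar[]): boolean => {
  if (!detail)
    return false
  // A changed parameter count always counts as outdated.
  if (detail.tool.parameters.length !== (inputs?.length || 0))
    return true
  for (const item of inputs || []) {
    const param = detail.tool.parameters.find(toolParam => toolParam.name === item.variable)
    // Missing parameter or changed `required` flag marks the tool as outdated.
    if (!param || param.required !== item.required)
      return true
    // paragraph and text-input inputs are both expected to map to string tool parameters.
    if ((item.type === 'paragraph' || item.type === 'text-input') && param.type !== 'string')
      return true
  }
  return false
}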
|
||||
|
||||
describe('useConfigureButton', () => {
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
mockIsCurrentWorkspaceManager.mockReturnValue(true)
|
||||
mockUseWorkflowToolDetailByAppID.mockImplementation((_appId: string, enabled: boolean) => ({
|
||||
data: enabled ? createMockDetail() : undefined,
|
||||
isLoading: false,
|
||||
}))
|
||||
})
|
||||
|
||||
afterEach(() => {
|
||||
vi.restoreAllMocks()
|
||||
})
|
||||
|
||||
describe('Initialization', () => {
|
||||
it('should return showModal as false by default', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
|
||||
expect(result.current.showModal).toBe(false)
|
||||
})
|
||||
|
||||
it('should forward isCurrentWorkspaceManager from context', () => {
|
||||
mockIsCurrentWorkspaceManager.mockReturnValue(false)
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
|
||||
expect(result.current.isCurrentWorkspaceManager).toBe(false)
|
||||
})
|
||||
|
||||
it('should forward isLoading from query hook', () => {
|
||||
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: true })
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
expect(result.current.isLoading).toBe(true)
|
||||
})
|
||||
|
||||
it('should call query hook with enabled=true when published', () => {
|
||||
renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
expect(mockUseWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123', true)
|
||||
})
|
||||
|
||||
it('should call query hook with enabled=false when not published', () => {
|
||||
renderHook(() => useConfigureButton(createDefaultOptions({ published: false })))
|
||||
expect(mockUseWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123', false)
|
||||
})
|
||||
})
|
||||
|
||||
// Computed values
|
||||
describe('Computed - outdated', () => {
|
||||
it('should be false when not published (no detail)', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
|
||||
expect(result.current.outdated).toBe(false)
|
||||
})
|
||||
|
||||
it('should be true when parameters differ', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
|
||||
published: true,
|
||||
inputs: [
|
||||
createMockInputVar({ variable: 'test_var' }),
|
||||
createMockInputVar({ variable: 'extra_var' }),
|
||||
],
|
||||
})))
|
||||
expect(result.current.outdated).toBe(true)
|
||||
})
|
||||
|
||||
it('should be false when parameters match', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
|
||||
published: true,
|
||||
inputs: [createMockInputVar({ variable: 'test_var', required: true })],
|
||||
})))
|
||||
expect(result.current.outdated).toBe(false)
|
||||
})
|
||||
})
|
||||
|
||||
describe('Computed - payload', () => {
|
||||
it('should use prop values when not published', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
|
||||
|
||||
expect(result.current.payload).toMatchObject({
|
||||
icon: createMockEmoji(),
|
||||
label: 'Test Workflow',
|
||||
name: '',
|
||||
description: 'Test workflow description',
|
||||
workflow_app_id: 'app-123',
|
||||
})
|
||||
expect(result.current.payload.parameters).toHaveLength(1)
|
||||
expect(result.current.payload.parameters[0]).toMatchObject({
|
||||
name: 'test_var',
|
||||
form: 'llm',
|
||||
description: '',
|
||||
})
|
||||
})
|
||||
|
||||
it('should use detail values when published with detail', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
|
||||
expect(result.current.payload).toMatchObject({
|
||||
icon: createMockEmoji(),
|
||||
label: 'Test Tool',
|
||||
name: 'test_tool',
|
||||
description: 'A test workflow tool',
|
||||
workflow_tool_id: 'tool-456',
|
||||
privacy_policy: 'https://example.com/privacy',
|
||||
labels: ['label1'],
|
||||
})
|
||||
expect(result.current.payload.parameters[0]).toMatchObject({
|
||||
name: 'test_var',
|
||||
description: 'Test variable description',
|
||||
form: 'llm',
|
||||
})
|
||||
})
|
||||
|
||||
it('should return empty parameters when published without detail', () => {
|
||||
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: false })
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
|
||||
expect(result.current.payload.parameters).toHaveLength(0)
|
||||
expect(result.current.payload.outputParameters).toHaveLength(0)
|
||||
})
|
||||
|
||||
it('should build output parameters from detail output_schema', () => {
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
|
||||
expect(result.current.payload.outputParameters).toHaveLength(1)
|
||||
expect(result.current.payload.outputParameters[0]).toMatchObject({
|
||||
name: 'output_var',
|
||||
description: 'Output description',
|
||||
})
|
||||
})
|
||||
|
||||
it('should handle undefined output_schema in detail', () => {
|
||||
const detail = createMockDetail()
|
||||
// @ts-expect-error - testing undefined case
|
||||
detail.tool.output_schema = undefined
|
||||
mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: detail, isLoading: false })
|
||||
|
||||
const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))
|
||||
|
||||
      expect(result.current.payload.outputParameters[0]).toMatchObject({
        name: 'output_var',
        description: '',
      })
    })

    it('should convert paragraph type to string in existing parameters', () => {
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        published: true,
        inputs: [createMockInputVar({ variable: 'test_var', type: InputVarType.paragraph })],
      })))

      expect(result.current.payload.parameters[0].type).toBe('string')
    })
  })

  // Modal controls
  describe('Modal Controls', () => {
    it('should open modal via openModal', () => {
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
      act(() => {
        result.current.openModal()
      })
      expect(result.current.showModal).toBe(true)
    })

    it('should close modal via closeModal', () => {
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
      act(() => {
        result.current.openModal()
      })
      act(() => {
        result.current.closeModal()
      })
      expect(result.current.showModal).toBe(false)
    })

    it('should navigate to tools page', () => {
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))
      act(() => {
        result.current.navigateToTools()
      })
      expect(mockPush).toHaveBeenCalledWith('/tools?category=workflow')
    })
  })

  // Mutation handlers
  describe('handleCreate', () => {
    it('should create provider, invalidate caches, refresh, and close modal', async () => {
      mockCreateWorkflowToolProvider.mockResolvedValue({})
      const onRefreshData = vi.fn()
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ onRefreshData })))

      act(() => {
        result.current.openModal()
      })

      await act(async () => {
        await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
      })

      expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
      expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
      expect(onRefreshData).toHaveBeenCalled()
      expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
      expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: expect.any(String) })
      expect(result.current.showModal).toBe(false)
    })

    it('should show error toast on failure', async () => {
      mockCreateWorkflowToolProvider.mockRejectedValue(new Error('Create failed'))
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions()))

      await act(async () => {
        await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
      })

      expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Create failed' })
    })
  })

  describe('handleUpdate', () => {
    it('should publish, save, invalidate caches, and close modal', async () => {
      mockSaveWorkflowToolProvider.mockResolvedValue({})
      const handlePublish = vi.fn().mockResolvedValue(undefined)
      const onRefreshData = vi.fn()
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        published: true,
        handlePublish,
        onRefreshData,
      })))

      act(() => {
        result.current.openModal()
      })

      await act(async () => {
        await result.current.handleUpdate(createMockRequest({ workflow_tool_id: 'tool-456' }) as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
      })

      expect(handlePublish).toHaveBeenCalled()
      expect(mockSaveWorkflowToolProvider).toHaveBeenCalled()
      expect(onRefreshData).toHaveBeenCalled()
      expect(mockInvalidateAllWorkflowTools).toHaveBeenCalled()
      expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
      expect(mockToastNotify).toHaveBeenCalledWith({ type: 'success', message: expect.any(String) })
      expect(result.current.showModal).toBe(false)
    })

    it('should show error toast when publish fails', async () => {
      const handlePublish = vi.fn().mockRejectedValue(new Error('Publish failed'))
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        published: true,
        handlePublish,
      })))

      await act(async () => {
        await result.current.handleUpdate(createMockRequest() as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
      })

      expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Publish failed' })
    })

    it('should show error toast when save fails', async () => {
      mockSaveWorkflowToolProvider.mockRejectedValue(new Error('Save failed'))
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))

      await act(async () => {
        await result.current.handleUpdate(createMockRequest() as WorkflowToolProviderRequest & Partial<{ workflow_app_id: string, workflow_tool_id: string }>)
      })

      expect(mockToastNotify).toHaveBeenCalledWith({ type: 'error', message: 'Save failed' })
    })
  })

  // Effects
  describe('Effects', () => {
    it('should invalidate detail when detailNeedUpdate becomes true', () => {
      const options = createDefaultOptions({ published: true, detailNeedUpdate: false })
      const { rerender } = renderHook(
        (props: ReturnType<typeof createDefaultOptions>) => useConfigureButton(props),
        { initialProps: options },
      )

      rerender({ ...options, detailNeedUpdate: true })

      expect(mockInvalidateWorkflowToolDetailByAppID).toHaveBeenCalledWith('app-123')
    })

    it('should not invalidate when detailNeedUpdate stays false', () => {
      const options = createDefaultOptions({ published: true, detailNeedUpdate: false })
      const { rerender } = renderHook(
        (props: ReturnType<typeof createDefaultOptions>) => useConfigureButton(props),
        { initialProps: options },
      )

      rerender({ ...options })

      expect(mockInvalidateWorkflowToolDetailByAppID).not.toHaveBeenCalled()
    })
  })

  // Edge cases
  describe('Edge Cases', () => {
    it('should handle undefined detail from query gracefully', () => {
      mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: undefined, isLoading: false })
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({ published: true })))

      expect(result.current.outdated).toBe(false)
      expect(result.current.payload.parameters).toHaveLength(0)
    })

    it('should handle detail with empty parameters', () => {
      const detail = createMockDetail()
      detail.tool.parameters = []
      mockUseWorkflowToolDetailByAppID.mockReturnValue({ data: detail, isLoading: false })

      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        published: true,
        inputs: [],
      })))

      expect(result.current.outdated).toBe(false)
    })

    it('should handle undefined inputs and outputs', () => {
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        inputs: undefined,
        outputs: undefined,
      })))

      expect(result.current.payload.parameters).toHaveLength(0)
      expect(result.current.payload.outputParameters).toHaveLength(0)
    })

    it('should handle missing onRefreshData callback in create', async () => {
      mockCreateWorkflowToolProvider.mockResolvedValue({})
      const { result } = renderHook(() => useConfigureButton(createDefaultOptions({
        onRefreshData: undefined,
      })))

      // Should not throw
      await act(async () => {
        await result.current.handleCreate(createMockRequest({ workflow_app_id: 'app-123' }) as WorkflowToolProviderRequest & { workflow_app_id: string })
      })

      expect(mockCreateWorkflowToolProvider).toHaveBeenCalled()
    })
  })
})
@@ -0,0 +1,235 @@
import type { Emoji, WorkflowToolProviderOutputParameter, WorkflowToolProviderParameter, WorkflowToolProviderRequest, WorkflowToolProviderResponse } from '@/app/components/tools/types'
import type { InputVar, Variable } from '@/app/components/workflow/types'
import type { PublishWorkflowParams } from '@/types/workflow'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Toast from '@/app/components/base/toast'
import { useAppContext } from '@/context/app-context'
import { createWorkflowToolProvider, saveWorkflowToolProvider } from '@/service/tools'
import { useInvalidateAllWorkflowTools, useInvalidateWorkflowToolDetailByAppID, useWorkflowToolDetailByAppID } from '@/service/use-tools'

// region Pure helpers

/**
 * Check if workflow tool parameters are outdated compared to current inputs.
 * Uses flat early-return style to reduce cyclomatic complexity.
 */
export function isParametersOutdated(
  detail: WorkflowToolProviderResponse | undefined,
  inputs: InputVar[] | undefined,
): boolean {
  if (!detail)
    return false
  if (detail.tool.parameters.length !== (inputs?.length ?? 0))
    return true

  for (const item of inputs || []) {
    const param = detail.tool.parameters.find(p => p.name === item.variable)
    if (!param)
      return true
    if (param.required !== item.required)
      return true
    const needsStringType = item.type === 'paragraph' || item.type === 'text-input'
    if (needsStringType && param.type !== 'string')
      return true
  }

  return false
}

function buildNewParameters(inputs?: InputVar[]): WorkflowToolProviderParameter[] {
  return (inputs || []).map(item => ({
    name: item.variable,
    description: '',
    form: 'llm',
    required: item.required,
    type: item.type,
  }))
}

function buildExistingParameters(
  inputs: InputVar[] | undefined,
  detail: WorkflowToolProviderResponse,
): WorkflowToolProviderParameter[] {
  return (inputs || []).map((item) => {
    const matched = detail.tool.parameters.find(p => p.name === item.variable)
    return {
      name: item.variable,
      required: item.required,
      type: item.type === 'paragraph' ? 'string' : item.type,
      description: matched?.llm_description || '',
      form: matched?.form || 'llm',
    }
  })
}

function buildNewOutputParameters(outputs?: Variable[]): WorkflowToolProviderOutputParameter[] {
  return (outputs || []).map(item => ({
    name: item.variable,
    description: '',
    type: item.value_type,
  }))
}

function buildExistingOutputParameters(
  outputs: Variable[] | undefined,
  detail: WorkflowToolProviderResponse,
): WorkflowToolProviderOutputParameter[] {
  return (outputs || []).map((item) => {
    const found = detail.tool.output_schema?.properties?.[item.variable]
    return {
      name: item.variable,
      description: found ? found.description : '',
      type: item.value_type,
    }
  })
}

// endregion

type UseConfigureButtonOptions = {
  published: boolean
  detailNeedUpdate: boolean
  workflowAppId: string
  icon: Emoji
  name: string
  description: string
  inputs?: InputVar[]
  outputs?: Variable[]
  handlePublish: (params?: PublishWorkflowParams) => Promise<void>
  onRefreshData?: () => void
}

export function useConfigureButton(options: UseConfigureButtonOptions) {
  const {
    published,
    detailNeedUpdate,
    workflowAppId,
    icon,
    name,
    description,
    inputs,
    outputs,
    handlePublish,
    onRefreshData,
  } = options

  const { t } = useTranslation()
  const router = useRouter()
  const { isCurrentWorkspaceManager } = useAppContext()

  const [showModal, setShowModal] = useState(false)

  // Data fetching via React Query
  const { data: detail, isLoading } = useWorkflowToolDetailByAppID(workflowAppId, published)

  // Invalidation functions (store in ref for stable effect dependency)
  const invalidateDetail = useInvalidateWorkflowToolDetailByAppID()
  const invalidateAllWorkflowTools = useInvalidateAllWorkflowTools()

  const invalidateDetailRef = useRef(invalidateDetail)
  invalidateDetailRef.current = invalidateDetail

  // Refetch when detailNeedUpdate becomes true
  useEffect(() => {
    if (detailNeedUpdate)
      invalidateDetailRef.current(workflowAppId)
  }, [detailNeedUpdate, workflowAppId])

  // Computed values
  const outdated = useMemo(
    () => isParametersOutdated(detail, inputs),
    [detail, inputs],
  )

  const payload = useMemo(() => {
    const hasPublishedDetail = published && detail?.tool

    const parameters = !published
      ? buildNewParameters(inputs)
      : hasPublishedDetail
        ? buildExistingParameters(inputs, detail)
        : []

    const outputParameters = !published
      ? buildNewOutputParameters(outputs)
      : hasPublishedDetail
        ? buildExistingOutputParameters(outputs, detail)
        : []

    return {
      icon: detail?.icon || icon,
      label: detail?.label || name,
      name: detail?.name || '',
      description: detail?.description || description,
      parameters,
      outputParameters,
      labels: detail?.tool?.labels || [],
      privacy_policy: detail?.privacy_policy || '',
      ...(published
        ? { workflow_tool_id: detail?.workflow_tool_id }
        : { workflow_app_id: workflowAppId }),
    }
  }, [detail, published, workflowAppId, icon, name, description, inputs, outputs])

  // Modal controls (stable callbacks)
  const openModal = useCallback(() => setShowModal(true), [])
  const closeModal = useCallback(() => setShowModal(false), [])
  const navigateToTools = useCallback(
    () => router.push('/tools?category=workflow'),
    [router],
  )

  // Mutation handlers (not memoized — only used in conditionally-rendered modal)
  const handleCreate = async (data: WorkflowToolProviderRequest & { workflow_app_id: string }) => {
    try {
      await createWorkflowToolProvider(data)
      invalidateAllWorkflowTools()
      onRefreshData?.()
      invalidateDetail(workflowAppId)
      Toast.notify({
        type: 'success',
        message: t('api.actionSuccess', { ns: 'common' }),
      })
      setShowModal(false)
    }
    catch (e) {
      Toast.notify({ type: 'error', message: (e as Error).message })
    }
  }

  const handleUpdate = async (data: WorkflowToolProviderRequest & Partial<{
    workflow_app_id: string
    workflow_tool_id: string
  }>) => {
    try {
      await handlePublish()
      await saveWorkflowToolProvider(data)
      onRefreshData?.()
      invalidateAllWorkflowTools()
      invalidateDetail(workflowAppId)
      Toast.notify({
        type: 'success',
        message: t('api.actionSuccess', { ns: 'common' }),
      })
      setShowModal(false)
    }
    catch (e) {
      Toast.notify({ type: 'error', message: (e as Error).message })
    }
  }

  return {
    showModal,
    isLoading,
    outdated,
    payload,
    isCurrentWorkspaceManager,
    openModal,
    closeModal,
    handleCreate,
    handleUpdate,
    navigateToTools,
  }
}
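For context, a minimal sketch of how a component might consume the hook above (the component name, button labels, placeholder dialog, and import path are illustrative assumptions, not part of this change set; only the hook's options shape and returned fields come from the code in the diff):

// Illustrative sketch only — not part of the committed files.
import { useConfigureButton } from './use-configure-button' // assumed file location

type Props = Parameters<typeof useConfigureButton>[0]

export const ConfigureButtonSketch = (props: Props) => {
  const {
    showModal,
    isLoading,
    outdated,
    payload,
    isCurrentWorkspaceManager,
    openModal,
    closeModal,
    handleCreate,
    handleUpdate,
    navigateToTools,
  } = useConfigureButton(props)

  if (isLoading)
    return null

  return (
    <>
      {/* Only workspace managers may configure; `outdated` flags drift between
          the published tool parameters and the current workflow inputs. */}
      <button disabled={!isCurrentWorkspaceManager} onClick={openModal}>
        {outdated ? 'Update workflow tool' : 'Configure workflow tool'}
      </button>
      <button onClick={navigateToTools}>Manage in Tools</button>
      {showModal && (
        // The real component would render its configuration modal here, seeded
        // with `payload` and wired to handleCreate (first publish) or
        // handleUpdate (later saves), and closed via closeModal.
        <div role="dialog" onBlur={closeModal} />
      )}
    </>
  )
}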
@@ -784,11 +784,6 @@
      "count": 1
    }
  },
  "app/components/app/configuration/dataset-config/card-item/index.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 1
    }
  },
  "app/components/app/configuration/dataset-config/card-item/index.tsx": {
    "tailwindcss/enforce-consistent-class-order": {
      "count": 1
@@ -1231,11 +1226,6 @@
      "count": 1
    }
  },
  "app/components/apps/app-card.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 22
    }
  },
  "app/components/apps/app-card.tsx": {
    "react-hooks-extra/no-direct-set-state-in-use-effect": {
      "count": 1
@@ -1260,21 +1250,11 @@
      "count": 1
    }
  },
  "app/components/apps/list.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 5
    }
  },
  "app/components/apps/list.tsx": {
    "unused-imports/no-unused-vars": {
      "count": 1
    }
  },
  "app/components/apps/new-app-card.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 4
    }
  },
  "app/components/apps/new-app-card.tsx": {
    "ts/no-explicit-any": {
      "count": 1
@@ -3042,11 +3022,6 @@
      "count": 1
    }
  },
  "app/components/custom/custom-web-app-brand/index.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 7
    }
  },
  "app/components/custom/custom-web-app-brand/index.tsx": {
    "tailwindcss/enforce-consistent-class-order": {
      "count": 12
@@ -4073,14 +4048,6 @@
      "count": 9
    }
  },
  "app/components/develop/doc.tsx": {
    "react-hooks-extra/no-direct-set-state-in-use-effect": {
      "count": 2
    },
    "ts/no-explicit-any": {
      "count": 3
    }
  },
  "app/components/develop/md.tsx": {
    "ts/no-empty-object-type": {
      "count": 1
@@ -4735,14 +4702,6 @@
      "count": 1
    }
  },
  "app/components/plugins/install-plugin/install-bundle/steps/install-multi.tsx": {
    "react-hooks-extra/no-direct-set-state-in-use-effect": {
      "count": 5
    },
    "ts/no-explicit-any": {
      "count": 2
    }
  },
  "app/components/plugins/install-plugin/install-bundle/steps/install.tsx": {
    "tailwindcss/enforce-consistent-class-order": {
      "count": 2
@@ -5766,11 +5725,6 @@
      "count": 4
    }
  },
  "app/components/tools/workflow-tool/configure-button.tsx": {
    "tailwindcss/enforce-consistent-class-order": {
      "count": 3
    }
  },
  "app/components/tools/workflow-tool/index.tsx": {
    "tailwindcss/enforce-consistent-class-order": {
      "count": 7
@@ -5807,11 +5761,6 @@
      "count": 2
    }
  },
  "app/components/workflow-app/components/workflow-onboarding-modal/start-node-selection-panel.spec.tsx": {
    "ts/no-explicit-any": {
      "count": 1
    }
  },
  "app/components/workflow-app/hooks/use-DSL.ts": {
    "ts/no-explicit-any": {
      "count": 1

@@ -3,6 +3,7 @@ import type {
  Collection,
  MCPServerDetail,
  Tool,
  WorkflowToolProviderResponse,
} from '@/app/components/tools/types'
import type { RAGRecommendedPlugins, ToolWithProvider } from '@/app/components/workflow/types'
import type { AppIconType } from '@/types/app'
@@ -402,3 +403,22 @@ export const useUpdateTriggerStatus = () => {
    },
  })
}

const workflowToolDetailByAppIDKey = (appId: string) => [NAME_SPACE, 'workflowToolDetailByAppID', appId]

export const useWorkflowToolDetailByAppID = (appId: string, enabled = true) => {
  return useQuery<WorkflowToolProviderResponse>({
    queryKey: workflowToolDetailByAppIDKey(appId),
    queryFn: () => get<WorkflowToolProviderResponse>(`/workspaces/current/tool-provider/workflow/get?workflow_app_id=${appId}`),
    enabled: enabled && !!appId,
  })
}

export const useInvalidateWorkflowToolDetailByAppID = () => {
  const queryClient = useQueryClient()
  return (appId: string) => {
    queryClient.invalidateQueries({
      queryKey: workflowToolDetailByAppIDKey(appId),
    })
  }
}
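A rough sketch of how these two new hooks might be used together outside of useConfigureButton: fetch the workflow tool detail for an app, then invalidate only that app's cached entry after a mutation so the next render refetches it. The wrapper hook name and the injected save callback below are assumptions for illustration only.

// Illustrative sketch only — not part of the committed files.
import { useInvalidateWorkflowToolDetailByAppID, useWorkflowToolDetailByAppID } from '@/service/use-tools'

export function useToolDetailExample(appId: string, published: boolean) {
  // The query runs only when the tool has been published and an app id exists,
  // mirroring the `enabled: enabled && !!appId` guard above.
  const { data: detail, isLoading } = useWorkflowToolDetailByAppID(appId, published)
  const invalidateDetail = useInvalidateWorkflowToolDetailByAppID()

  const refreshAfterSave = async (save: () => Promise<void>) => {
    await save()
    // Invalidation targets the per-app query key, so only this app's cached
    // detail is refetched rather than every workflow tool query.
    invalidateDetail(appId)
  }

  return { detail, isLoading, refreshAfterSave }
}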