Compare commits

...

30 Commits

Author SHA1 Message Date
yyh
aec6bcd92c fix test 2026-01-02 21:54:54 +08:00
yyh
89b29bd836 Merge remote-tracking branch 'origin/main' into feat/vibe-wf 2026-01-02 19:59:54 +08:00
yyh
32ebf2a1c6 make lint 2026-01-01 19:13:29 +08:00
qiuqiua
7d14b27447 feat: add vibe workflow (#30258)
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
2025-12-31 10:23:20 +08:00
GuanMu
d3223c6b59 fix: modal style (#30308)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-30 09:21:33 +08:00
GuanMu
df1715a152 fix: picker style error (#30302)
Co-authored-by: crazywoola <427733928@qq.com>
Co-authored-by: WTW0313 <twwu@dify.ai>
Co-authored-by: yyh <yuanyouhuilyz@gmail.com>
Co-authored-by: yyh <92089059+lyzno1@users.noreply.github.com>
2025-12-29 13:22:43 +08:00
Wu Tianwei
28cb1792cf fix: fix generated workflow preview and version control (#30300) 2025-12-29 12:59:20 +08:00
yyh
33e2f96647 refactor: Extract Go to Anything search logic into a new useSearch hook. 2025-12-29 10:42:25 +08:00
yyh
2ac47114f7 feat: improve search UX by adding keepPreviousData to useQuery and clearing results for empty queries. 2025-12-29 10:42:24 +08:00
yyh
1367497deb refactor goto-anything scopes and registry 2025-12-29 10:42:24 +08:00
yyh
e5a98b10a9 refactor: Replace goto-anything actions with a new scope registry system, simplifying command management and registration. 2025-12-29 10:42:24 +08:00
yyh
b7b263d54e chore: remove duplicate tests 2025-12-29 10:42:24 +08:00
yyh
0f40afafe2 fix: sync vibe apply immediately 2025-12-29 10:42:24 +08:00
yyh
de7d39d388 chore: test for vibe-panel 2025-12-29 10:42:24 +08:00
yyh
aa37e8fa4c fix: prevent stale vibe preview on invalid flow 2025-12-29 10:42:23 +08:00
yyh
cfb54a0e7d fix: enhance version management and validation in workflow hooks
- Updated `useVibeFlowData` to prevent adding empty graphs and ensure the current version is correctly derived from available versions.
- Improved error handling in `applyFlowchartToWorkflow` to notify users when the current flow graph is invalid.
- Added checks to only add valid workflow graphs to the versions list, enhancing data integrity.
2025-12-29 10:42:23 +08:00
yyh
75b7d269e1 refactor: centralize action keys and internationalization mappings
- Introduced a new `constants.ts` file to centralize action keys and i18n mappings for slash commands and scope actions.
- Updated `command-selector` and various action files to utilize the new constants for improved maintainability and readability.
- Removed hardcoded strings in favor of the new mappings, ensuring consistency across the application.
2025-12-29 10:42:23 +08:00
yyh
68c220d25e feat: implement banana command with search and registration functionality
- Added `bananaCommand` to handle vibe-related actions, including search and command registration.
- Updated `command-selector` to utilize new internationalization maps for slash commands and scope actions.
- Removed deprecated banana action from the actions index and adjusted related filtering logic.
- Added unit tests for the banana command to ensure correct behavior in various scenarios.
- Deleted obsolete banana action tests and files.
2025-12-29 10:42:23 +08:00
yyh
a4efb3acbf fix: correct header assignment in SSRF proxy request handling 2025-12-29 10:42:23 +08:00
yyh
8294759200 feat: add highPriority prop to VibePanel for improved rendering 2025-12-29 10:42:23 +08:00
WTW0313
99e5669a66 fix: ensure storageKey defaults to an empty string in Vibe panel and workflow hooks 2025-12-29 10:42:23 +08:00
WTW0313
4879795cb9 feat: update Vibe panel to use new event handling and versioning for flowcharts 2025-12-29 10:42:22 +08:00
yyh
bcef6e8216 chore: some tests 2025-12-29 10:42:22 +08:00
crazywoola
f2363fc458 feat: use new styles 2025-12-29 10:42:22 +08:00
WTW0313
0f69e2f6ab feat: implement Vibe panel for workflow with regeneration and acceptance features 2025-12-29 10:42:22 +08:00
crazywoola
b7be5c8c82 feat: add MCP tools 2025-12-29 10:42:22 +08:00
crazywoola
9d496ed3dc feat: add MCP tools 2025-12-29 10:42:22 +08:00
crazywoola
336769deb1 feat: use @banana 2025-12-29 10:42:21 +08:00
crazywoola
5705fa898f feat: v1 2025-12-29 10:42:21 +08:00
crazywoola
a5c6c8638e feat: v1 2025-12-29 10:42:21 +08:00
88 changed files with 11168 additions and 1131 deletions

View File

@@ -1,9 +1,13 @@
import logging
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
logger = logging.getLogger(__name__)
from controllers.console import console_ns
from controllers.console.app.error import (
CompletionRequestError,
@@ -18,6 +22,7 @@ from core.helper.code_executor.javascript.javascript_code_provider import Javasc
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
@@ -55,6 +60,30 @@ class InstructionTemplatePayload(BaseModel):
type: str = Field(..., description="Instruction template type")
class PreviousWorkflow(BaseModel):
"""Previous workflow attempt for regeneration context."""
nodes: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated nodes")
edges: list[dict[str, Any]] = Field(default_factory=list, description="Previously generated edges")
warnings: list[str] = Field(default_factory=list, description="Warnings from previous generation")
class FlowchartGeneratePayload(BaseModel):
instruction: str = Field(..., description="Workflow flowchart generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
available_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Available node types")
existing_nodes: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow nodes")
existing_edges: list[dict[str, Any]] = Field(default_factory=list, description="Existing workflow edges")
available_tools: list[dict[str, Any]] = Field(default_factory=list, description="Available tools")
selected_node_ids: list[str] = Field(default_factory=list, description="IDs of selected nodes for context")
previous_workflow: PreviousWorkflow | None = Field(default=None, description="Previous workflow for regeneration")
regenerate_mode: bool = Field(default=False, description="Whether this is a regeneration request")
# Language preference for generated content (node titles, descriptions)
language: str | None = Field(default=None, description="Preferred language for generated content")
# Available models that user has configured (for LLM/question-classifier nodes)
available_models: list[dict[str, Any]] = Field(default_factory=list, description="User's configured models")
def reg(cls: type[BaseModel]):
console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
@@ -64,6 +93,7 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(FlowchartGeneratePayload)
@console_ns.route("/rule-generate")
@@ -255,6 +285,52 @@ class InstructionGenerateApi(Resource):
raise CompletionRequestError(e.description)
@console_ns.route("/flowchart-generate")
class FlowchartGenerateApi(Resource):
    """Console endpoint that generates a workflow flowchart from a natural-language instruction."""

    @console_ns.doc("generate_workflow_flowchart")
    @console_ns.doc(description="Generate workflow flowchart using LLM with intent classification")
    @console_ns.expect(console_ns.models[FlowchartGeneratePayload.__name__])
    @console_ns.response(200, "Flowchart generated successfully")
    @console_ns.response(400, "Invalid request parameters")
    @console_ns.response(402, "Provider quota exceeded")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        """Validate the payload and delegate to WorkflowGenerator.

        Provider/model failures are translated into the console API's error
        types so clients receive consistent error codes.

        Returns:
            The generator result, serialized by flask-restx.
        """
        args = FlowchartGeneratePayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()
        try:
            # PreviousWorkflow is a pydantic model; the generator expects a plain dict.
            previous_workflow_dict = args.previous_workflow.model_dump() if args.previous_workflow else None
            result = WorkflowGenerator.generate_workflow_flowchart(
                tenant_id=current_tenant_id,
                instruction=args.instruction,
                model_config=args.model_config_data,
                available_nodes=args.available_nodes,
                existing_nodes=args.existing_nodes,
                existing_edges=args.existing_edges,
                available_tools=args.available_tools,
                selected_node_ids=args.selected_node_ids,
                previous_workflow=previous_workflow_dict,
                regenerate_mode=args.regenerate_mode,
                preferred_language=args.language,
                available_models=args.available_models,
            )
        except ProviderTokenNotInitError as ex:
            # Chain with `from` so the underlying provider error stays in the traceback.
            raise ProviderNotInitializeError(ex.description) from ex
        except QuotaExceededError as ex:
            raise ProviderQuotaExceededError() from ex
        except ModelCurrentlyNotSupportError as ex:
            raise ProviderModelCurrentlyNotSupportError() from ex
        except InvokeError as e:
            raise CompletionRequestError(e.description) from e
        return result
@console_ns.route("/instruction-generate/template")
class InstructionGenerationTemplateApi(Resource):
@console_ns.doc("get_instruction_template")

View File

@@ -106,6 +106,9 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
verify_option = kwargs.pop("ssl_verify", dify_config.HTTP_REQUEST_NODE_SSL_VERIFY)
client = _get_ssrf_client(verify_option)
# Extract follow_redirects for client.send() - it's not a build_request parameter
follow_redirects = kwargs.pop("follow_redirects", True)
# Preserve user-provided Host header
# When using a forward proxy, httpx may override the Host header based on the URL.
# We extract and preserve any explicitly set Host header to support virtual hosting.
@@ -120,9 +123,10 @@ def make_request(method, url, max_retries=SSRF_DEFAULT_MAX_RETRIES, **kwargs):
# the request API to explicitly set headers before sending
headers = {k: v for k, v in headers.items() if k.lower() != "host"}
if user_provided_host is not None:
headers["host"] = user_provided_host
kwargs["headers"] = headers
response = client.request(method=method, url=url, **kwargs)
headers["Host"] = user_provided_host
request = client.build_request(method, url, headers=headers, **kwargs)
response = client.send(request, follow_redirects=follow_redirects)
# Check for SSRF protection by Squid proxy
if response.status_code in (401, 403):

View File

@@ -1,6 +1,5 @@
import json
import logging
import re
from collections.abc import Sequence
from typing import Protocol, cast
@@ -12,8 +11,6 @@ from core.llm_generator.prompts import (
CONVERSATION_TITLE_PROMPT,
GENERATOR_QA_PROMPT,
JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE,
LLM_MODIFY_CODE_SYSTEM,
LLM_MODIFY_PROMPT_SYSTEM,
PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE,
SUGGESTED_QUESTIONS_MAX_TOKENS,
SUGGESTED_QUESTIONS_TEMPERATURE,
@@ -30,6 +27,7 @@ from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from core.workflow.generator import WorkflowGenerator
from extensions.ext_database import db
from extensions.ext_storage import storage
from models import App, Message, WorkflowNodeExecutionModel
@@ -285,6 +283,35 @@ class LLMGenerator:
return rule_config
@classmethod
def generate_workflow_flowchart(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
available_nodes: Sequence[dict[str, object]] | None = None,
existing_nodes: Sequence[dict[str, object]] | None = None,
available_tools: Sequence[dict[str, object]] | None = None,
selected_node_ids: Sequence[str] | None = None,
previous_workflow: dict[str, object] | None = None,
regenerate_mode: bool = False,
preferred_language: str | None = None,
available_models: Sequence[dict[str, object]] | None = None,
):
return WorkflowGenerator.generate_workflow_flowchart(
tenant_id=tenant_id,
instruction=instruction,
model_config=model_config,
available_nodes=available_nodes,
existing_nodes=existing_nodes,
available_tools=available_tools,
selected_node_ids=selected_node_ids,
previous_workflow=previous_workflow,
regenerate_mode=regenerate_mode,
preferred_language=preferred_language,
available_models=available_models,
)
@classmethod
def generate_code(cls, tenant_id: str, instruction: str, model_config: dict, code_language: str = "javascript"):
if code_language == "python":

View File

@@ -143,6 +143,50 @@ Based on task description, please create a well-structured prompt template that
Please generate the full prompt template with at least 300 words and output only the prompt template.
""" # noqa: E501
# Prompt for Mermaid-flowchart generation in the vibe workflow feature.
# {{PREFERRED_LANGUAGE}}, {{TASK_DESCRIPTION}}, {{AVAILABLE_NODES}},
# {{EXISTING_NODES}} and {{AVAILABLE_TOOLS}} are substituted before the
# prompt is sent to the model.
# NOTE(review): indentation of the bullet lists inside this template may have
# been lost in this rendering — confirm against the original file.
WORKFLOW_FLOWCHART_PROMPT_TEMPLATE = """
You are an expert workflow designer. Generate a Mermaid flowchart based on the user's request.
Constraints:
- Detect the language of the user's request. Generate all node titles in the same language as the user's input.
- If the input language cannot be determined, use {{PREFERRED_LANGUAGE}} as the fallback language.
- Use only node types listed in <available_nodes>.
- Use only tools listed in <available_tools>. When using a tool node, set type=tool and tool=<tool_key>.
- Tools may include MCP providers (provider_type=mcp). Tool selection still uses tool_key.
- Prefer reusing node titles from <existing_nodes> when possible.
- Output must be valid Mermaid flowchart syntax, no markdown, no extra text.
- First line must be: flowchart LR
- Every node must be declared on its own line using:
<id>["type=<type>|title=<title>|tool=<tool_key>"]
- type is required and must match a type in <available_nodes>.
- title is required for non-tool nodes.
- tool is required only when type=tool, otherwise omit tool.
- Declare all node lines before any edges.
- Edges must use:
<id> --> <id>
<id> -->|true| <id>
<id> -->|false| <id>
- Keep node ids unique and simple (N1, N2, ...).
- For complex orchestration:
- Break the request into stages (ingest, transform, decision, action, output).
- Use IfElse for branching and label edges true/false only.
- Fan-in branches by connecting multiple nodes into a shared downstream node.
- Avoid cycles unless explicitly requested.
- Keep each branch complete with a clear downstream target.
<user_request>
{{TASK_DESCRIPTION}}
</user_request>
<available_nodes>
{{AVAILABLE_NODES}}
</available_nodes>
<existing_nodes>
{{EXISTING_NODES}}
</existing_nodes>
<available_tools>
{{AVAILABLE_TOOLS}}
</available_tools>
"""
RULE_CONFIG_PROMPT_GENERATE_TEMPLATE = """
Here is a task description for which I would like you to create a high-quality prompt template for:
<task_description>

View File

@@ -0,0 +1 @@
from .runner import WorkflowGenerator

View File

@@ -0,0 +1,29 @@
"""
Vibe Workflow Generator Configuration Module.
This module centralizes configuration for the Vibe workflow generation feature,
including node schemas, fallback rules, and response templates.
"""
from core.workflow.generator.config.node_schemas import (
BUILTIN_NODE_SCHEMAS,
FALLBACK_RULES,
FIELD_NAME_CORRECTIONS,
NODE_TYPE_ALIASES,
get_builtin_node_schemas,
get_corrected_field_name,
validate_node_schemas,
)
from core.workflow.generator.config.responses import DEFAULT_SUGGESTIONS, OFF_TOPIC_RESPONSES
__all__ = [
"BUILTIN_NODE_SCHEMAS",
"DEFAULT_SUGGESTIONS",
"FALLBACK_RULES",
"FIELD_NAME_CORRECTIONS",
"NODE_TYPE_ALIASES",
"OFF_TOPIC_RESPONSES",
"get_builtin_node_schemas",
"get_corrected_field_name",
"validate_node_schemas",
]

View File

@@ -0,0 +1,501 @@
"""
Unified Node Configuration for Vibe Workflow Generation.
This module centralizes all node-related configuration:
- Node schemas (parameter definitions)
- Fallback rules (keyword-based node type inference)
- Node type aliases (natural language to canonical type mapping)
- Field name corrections (LLM output normalization)
- Validation utilities
Note: These definitions are the single source of truth.
Frontend has a mirrored copy at web/app/components/workflow/hooks/use-workflow-vibe-config.ts
"""
from typing import Any
# =============================================================================
# NODE SCHEMAS
# =============================================================================
# Built-in node schemas with parameter definitions
# These help the model understand what config each node type requires
_HARDCODED_SCHEMAS: dict[str, dict[str, Any]] = {
"http-request": {
"description": "Send HTTP requests to external APIs or fetch web content",
"required": ["url", "method"],
"parameters": {
"url": {
"type": "string",
"description": "Full URL including protocol (https://...)",
"example": "{{#start.url#}} or https://api.example.com/data",
},
"method": {
"type": "enum",
"options": ["GET", "POST", "PUT", "DELETE", "PATCH", "HEAD"],
"description": "HTTP method",
},
"headers": {
"type": "string",
"description": "HTTP headers as newline-separated 'Key: Value' pairs",
"example": "Content-Type: application/json\nAuthorization: Bearer {{#start.api_key#}}",
},
"params": {
"type": "string",
"description": "URL query parameters as newline-separated 'key: value' pairs",
},
"body": {
"type": "object",
"description": "Request body with type field required",
"example": {"type": "none", "data": []},
},
"authorization": {
"type": "object",
"description": "Authorization config",
"example": {"type": "no-auth"},
},
"timeout": {
"type": "number",
"description": "Request timeout in seconds",
"default": 60,
},
},
"outputs": ["body (response content)", "status_code", "headers"],
},
"code": {
"description": "Execute Python or JavaScript code for custom logic",
"required": ["code", "language"],
"parameters": {
"code": {
"type": "string",
"description": "Code to execute. Must define a main() function that returns a dict.",
},
"language": {
"type": "enum",
"options": ["python3", "javascript"],
},
"variables": {
"type": "array",
"description": "Input variables passed to the code",
"item_schema": {"variable": "string", "value_selector": "array"},
},
"outputs": {
"type": "object",
"description": "Output variable definitions",
},
},
"outputs": ["Variables defined in outputs schema"],
},
"llm": {
"description": "Call a large language model for text generation/processing",
"required": ["prompt_template"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"prompt_template": {
"type": "array",
"description": "Messages for the LLM",
"item_schema": {
"role": "enum: system, user, assistant",
"text": "string - message content, can include {{#node_id.field#}} references",
},
},
"context": {
"type": "object",
"description": "Optional context settings",
},
"memory": {
"type": "object",
"description": "Optional memory/conversation settings",
},
},
"outputs": ["text (generated response)"],
},
"if-else": {
"description": "Conditional branching based on conditions",
"required": ["cases"],
"parameters": {
"cases": {
"type": "array",
"description": "List of condition cases. Each case defines when 'true' branch is taken.",
"item_schema": {
"case_id": "string - unique case identifier (e.g., 'case_1')",
"logical_operator": "enum: and, or - how multiple conditions combine",
"conditions": {
"type": "array",
"item_schema": {
"variable_selector": "array of strings - path to variable, e.g. ['node_id', 'field']",
"comparison_operator": (
"enum: =, ≠, >, <, ≥, ≤, contains, not contains, is, is not, empty, not empty"
),
"value": "string or number - value to compare against",
},
},
},
},
},
"outputs": ["Branches: true (first case conditions met), false (else/no case matched)"],
},
"knowledge-retrieval": {
"description": "Query knowledge base for relevant content",
"required": ["query_variable_selector", "dataset_ids"],
"parameters": {
"query_variable_selector": {
"type": "array",
"description": "Path to query variable, e.g. ['start', 'query']",
},
"dataset_ids": {
"type": "array",
"description": "List of knowledge base IDs to search",
},
"retrieval_mode": {
"type": "enum",
"options": ["single", "multiple"],
},
},
"outputs": ["result (retrieved documents)"],
},
"template-transform": {
"description": "Transform data using Jinja2 templates",
"required": ["template", "variables"],
"parameters": {
"template": {
"type": "string",
"description": "Jinja2 template string. Use {{ variable_name }} to reference variables.",
},
"variables": {
"type": "array",
"description": "Input variables defined for the template",
"item_schema": {
"variable": "string - variable name to use in template",
"value_selector": "array - path to source value, e.g. ['start', 'user_input']",
},
},
},
"outputs": ["output (transformed string)"],
},
"variable-aggregator": {
"description": "Aggregate variables from multiple branches",
"required": ["variables"],
"parameters": {
"variables": {
"type": "array",
"description": "List of variable selectors to aggregate",
"item_schema": "array of strings - path to source variable, e.g. ['node_id', 'field']",
},
},
"outputs": ["output (aggregated value)"],
},
"iteration": {
"description": "Loop over array items",
"required": ["iterator_selector"],
"parameters": {
"iterator_selector": {
"type": "array",
"description": "Path to array variable to iterate",
},
},
"outputs": ["item (current iteration item)", "index (current index)"],
},
"parameter-extractor": {
"description": "Extract structured parameters from user input using LLM",
"required": ["query", "parameters"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to extract parameters from, e.g. ['start', 'user_input']",
},
"parameters": {
"type": "array",
"description": "Parameters to extract from the input",
"item_schema": {
"name": "string - parameter name (required)",
"type": (
"enum: string, number, boolean, array[string], array[number], array[object], array[boolean]"
),
"description": "string - description of what to extract (required)",
"required": "boolean - whether this parameter is required (MUST be specified)",
"options": "array of strings (optional) - for enum-like selection",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for extraction",
},
"reasoning_mode": {
"type": "enum",
"options": ["function_call", "prompt"],
"description": "How to perform extraction (defaults to function_call)",
},
},
"outputs": ["Extracted parameters as defined in parameters array", "__is_success", "__reason"],
},
"question-classifier": {
"description": "Classify user input into predefined categories using LLM",
"required": ["query", "classes"],
"parameters": {
"model": {
"type": "object",
"description": "Model configuration (provider, name, mode)",
},
"query": {
"type": "array",
"description": "Path to input text to classify, e.g. ['start', 'user_input']",
},
"classes": {
"type": "array",
"description": "Classification categories",
"item_schema": {
"id": "string - unique class identifier",
"name": "string - class name/label",
},
},
"instruction": {
"type": "string",
"description": "Additional instructions for classification",
},
},
"outputs": ["class_name (selected class)"],
},
}
def _get_dynamic_schemas() -> dict[str, dict[str, Any]]:
    """Collect config schemas published by the registered node classes.

    The node-mapping module is imported inside the function body so this
    config module can be imported without triggering a circular import.
    """
    from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

    collected: dict[str, dict[str, Any]] = {}
    for kind, versions in NODE_TYPE_CLASSES_MAPPING.items():
        # Only the latest version of each node class contributes a schema.
        node_class = versions.get(LATEST_VERSION)
        if node_class is None:
            continue
        config_schema = node_class.get_default_config_schema()
        if config_schema:
            collected[kind.value] = config_schema
    return collected
# Lazily-built merge of hardcoded and dynamically discovered schemas.
_builtin_schemas_cache: dict[str, dict[str, Any]] | None = None


def get_builtin_node_schemas() -> dict[str, dict[str, Any]]:
    """Return every built-in node schema, computing the merged set once.

    Hardcoded schemas are overlaid with dynamically loaded ones (dynamic
    entries win on key collisions); the merged dict is memoized in
    ``_builtin_schemas_cache`` and returned unchanged on later calls.
    """
    global _builtin_schemas_cache
    if _builtin_schemas_cache is None:
        merged = dict(_HARDCODED_SCHEMAS)
        merged.update(_get_dynamic_schemas())
        _builtin_schemas_cache = merged
    return _builtin_schemas_cache


# Backward-compatible alias; prefer get_builtin_node_schemas() so that
# dynamically loaded schemas are included.
BUILTIN_NODE_SCHEMAS: dict[str, dict[str, Any]] = dict(_HARDCODED_SCHEMAS)
# =============================================================================
# FALLBACK RULES
# =============================================================================
# Keyword rules for smart fallback detection
# Maps node type to keywords that suggest using that node type as a fallback
FALLBACK_RULES: dict[str, list[str]] = {
"http-request": [
"http",
"url",
"web",
"scrape",
"scraper",
"fetch",
"api",
"request",
"download",
"upload",
"webhook",
"endpoint",
"rest",
"get",
"post",
],
"code": [
"code",
"script",
"calculate",
"compute",
"process",
"transform",
"parse",
"convert",
"format",
"filter",
"sort",
"math",
"logic",
],
"llm": [
"analyze",
"summarize",
"summary",
"extract",
"classify",
"translate",
"generate",
"write",
"rewrite",
"explain",
"answer",
"chat",
],
}
# =============================================================================
# NODE TYPE ALIASES
# =============================================================================
# Node type aliases for inference from natural language
# Maps common terms to canonical node type names
NODE_TYPE_ALIASES: dict[str, str] = {
    # Start node aliases
    "start": "start",
    "begin": "start",
    "input": "start",
    # End node aliases
    "end": "end",
    "finish": "end",
    "output": "end",
    # LLM node aliases
    "llm": "llm",
    "ai": "llm",
    "gpt": "llm",
    "model": "llm",
    "chat": "llm",
    # Code node aliases
    "code": "code",
    "script": "code",
    "python": "code",
    "javascript": "code",
    # HTTP request node aliases
    "http-request": "http-request",
    "http": "http-request",
    "request": "http-request",
    "api": "http-request",
    "fetch": "http-request",
    "webhook": "http-request",
    # Conditional node aliases
    "if-else": "if-else",
    "condition": "if-else",
    "branch": "if-else",
    "switch": "if-else",
    # Loop node aliases
    "iteration": "iteration",
    # NOTE(review): "loop" maps to the literal "loop" type, which
    # _INTERNAL_NODE_TYPES marks as internal ("Uses iteration internally") and
    # excludes from generation schemas — confirm whether this alias should be
    # "iteration", matching "foreach" below.
    "loop": "loop",
    "foreach": "iteration",
    # Tool node alias
    "tool": "tool",
}
# =============================================================================
# FIELD NAME CORRECTIONS
# =============================================================================
# Field name corrections for LLM-generated node configs
# Maps incorrect field names to correct ones for specific node types
FIELD_NAME_CORRECTIONS: dict[str, dict[str, str]] = {
    "http-request": {
        "text": "body",  # LLM might use "text" instead of "body"
        "content": "body",
        "response": "body",
    },
    "code": {
        "text": "result",  # LLM might use "text" instead of "result"
        "output": "result",
    },
    "llm": {
        "response": "text",
        "answer": "text",
    },
}


def get_corrected_field_name(node_type: str, field: str) -> str:
    """Normalize an LLM-emitted field name for the given node type.

    Args:
        node_type: Node type the field belongs to (e.g. "http-request").
        field: Field name as produced by the model.

    Returns:
        The canonical field name, or ``field`` unchanged when no correction
        is registered for this node type / field pair.
    """
    per_type = FIELD_NAME_CORRECTIONS.get(node_type)
    if per_type is None:
        return field
    return per_type.get(field, field)
# =============================================================================
# VALIDATION UTILITIES
# =============================================================================
# Node types that are internal and don't need schemas for LLM generation
_INTERNAL_NODE_TYPES: set[str] = {
# Internal workflow nodes
"answer", # Internal to chatflow
"loop", # Uses iteration internally
"assigner", # Variable assignment utility
"variable-assigner", # Variable assignment utility
"agent", # Agent node (complex, handled separately)
"document-extractor", # Internal document processing
"list-operator", # Internal list operations
# Iteration internal nodes
"iteration-start", # Internal to iteration loop
"loop-start", # Internal to loop
"loop-end", # Internal to loop
# Trigger nodes (not user-creatable via LLM)
"trigger-plugin", # Plugin trigger
"trigger-schedule", # Scheduled trigger
"trigger-webhook", # Webhook trigger
# Other internal nodes
"datasource", # Data source configuration
"human-input", # Human-in-the-loop node
"knowledge-index", # Knowledge indexing node
}
def validate_node_schemas() -> list[str]:
    """Report registered node types that lack a generation schema.

    Every type in NODE_TYPE_CLASSES_MAPPING is expected to appear in the
    built-in schema set unless it is listed in ``_INTERNAL_NODE_TYPES``.

    Returns:
        One warning string per uncovered node type; empty when all are covered.
    """
    # Lazy import to avoid a circular dependency at module load time.
    from core.workflow.nodes.node_mapping import NODE_TYPE_CLASSES_MAPPING

    known = get_builtin_node_schemas()
    return [
        f"Missing schema for node type: {node_type.value}"
        for node_type in NODE_TYPE_CLASSES_MAPPING
        if node_type.value not in _INTERNAL_NODE_TYPES and node_type.value not in known
    ]

View File

@@ -0,0 +1,72 @@
"""
Response Templates for Vibe Workflow Generation.
This module defines templates for off-topic responses and default suggestions
to guide users back to workflow-related requests.
"""
# Off-topic response templates for different categories
# Each category has messages in multiple languages
# NOTE(review): several zh strings below read as though fullwidth punctuation
# was stripped in transit (e.g. the "weather" reply lacks a comma after
# "助手") — confirm against the original source before relying on them verbatim.
OFF_TOPIC_RESPONSES: dict[str, dict[str, str]] = {
    "weather": {
        "en": (
            "I'm the workflow design assistant - I can't check the weather, "
            "but I can help you build AI workflows! For example, I could help you "
            "create a workflow that fetches weather data from an API."
        ),
        "zh": "我是工作流设计助手无法查询天气。但我可以帮你创建一个从API获取天气数据的工作流",
    },
    "math": {
        "en": (
            "I focus on workflow design rather than calculations. However, "
            "if you need calculations in a workflow, I can help you add a Code node "
            "that handles math operations!"
        ),
        "zh": "我专注于工作流设计而非计算。但如果您需要在工作流中进行计算,我可以帮您添加一个处理数学运算的代码节点!",
    },
    "joke": {
        "en": (
            "While I'd love to share a laugh, I'm specialized in workflow design. "
            "How about we create something fun instead - like a workflow that generates jokes using AI?"
        ),
        "zh": "虽然我很想讲笑话但我专门从事工作流设计。不如我们创建一个有趣的东西——比如使用AI生成笑话的工作流",
    },
    "translation": {
        "en": (
            "I can't translate directly, but I can help you build a translation workflow! "
            "Would you like to create one using an LLM node?"
        ),
        "zh": "我不能直接翻译但我可以帮你构建一个翻译工作流要创建一个使用LLM节点的翻译流程吗",
    },
    "general_coding": {
        "en": (
            "I'm specialized in Dify workflow design rather than general coding help. "
            "But if you want to add code logic to your workflow, I can help you configure a Code node!"
        ),
        "zh": (
            "我专注于Dify工作流设计而非通用编程帮助。但如果您想在工作流中添加代码逻辑我可以帮您配置一个代码节点"
        ),
    },
    "default": {
        "en": (
            "I'm the Dify workflow design assistant. I help create AI automation workflows, "
            "but I can't help with general questions. Would you like to create a workflow instead?"
        ),
        "zh": "我是Dify工作流设计助手。我帮助创建AI自动化工作流但无法回答一般性问题。您想创建一个工作流吗",
    },
}
# Default suggestions for off-topic requests
# These help guide users towards valid workflow requests
DEFAULT_SUGGESTIONS: dict[str, list[str]] = {
    "en": [
        "Create a chatbot workflow",
        "Build a document summarization pipeline",
        "Add email notification to workflow",
    ],
    "zh": [
        "创建一个聊天机器人工作流",
        "构建文档摘要处理流程",
        "添加邮件通知到工作流",
    ],
}

View File

@@ -0,0 +1,457 @@
# System prompt for the Builder stage of the Planner-Builder pipeline.
# Instructs the LLM to turn the Architect's plan into a runnable Dify workflow
# JSON (nodes + edges). Filled via str.format() with: plan_context,
# tool_schemas, builtin_node_specs, available_models, preferred_language,
# existing_nodes_context, existing_edges_context, selected_nodes_context.
# Literal braces inside the template (JSON examples and {{#node.field#}}
# variable syntax) are escaped as {{ / }} so .format() leaves them intact.
BUILDER_SYSTEM_PROMPT = """<role>
You are a Workflow Configuration Engineer.
Your goal is to implement the Architect's plan by generating a precise, runnable Dify Workflow JSON configuration.
</role>
<language_rules>
- Detect the language of the user's request automatically (e.g., English, Chinese, Japanese, etc.).
- Generate ALL node titles, descriptions, and user-facing text in the SAME language as the user's input.
- If the input language is ambiguous or cannot be determined (e.g. code-only input),
use {preferred_language} as the target language.
</language_rules>
<inputs>
<plan>
{plan_context}
</plan>
<tool_schemas>
{tool_schemas}
</tool_schemas>
<node_specs>
{builtin_node_specs}
</node_specs>
<available_models>
{available_models}
</available_models>
<workflow_context>
<existing_nodes>
{existing_nodes_context}
</existing_nodes>
<existing_edges>
{existing_edges_context}
</existing_edges>
<selected_nodes>
{selected_nodes_context}
</selected_nodes>
</workflow_context>
</inputs>
<rules>
1. **Configuration**:
- You MUST fill ALL required parameters for every node.
- Use `{{{{#node_id.field#}}}}` syntax to reference outputs from previous nodes in text fields.
- For 'start' node, define all necessary user inputs.
2. **Variable References**:
- For text fields (like prompts, queries): use string format `{{{{#node_id.field#}}}}`
- For 'end' node outputs: use `value_selector` array format `["node_id", "field"]`
- Example: to reference 'llm' node's 'text' output in end node, use `["llm", "text"]`
3. **Tools**:
- ONLY use the tools listed in `<tool_schemas>`.
- If a planned tool is missing from schemas, fallback to `http-request` or `code`.
4. **Model Selection** (CRITICAL):
- For LLM, question-classifier, and parameter-extractor nodes, you MUST include a "model" config.
- You MUST use ONLY models from the `<available_models>` section above.
- Copy the EXACT provider and name values from available_models.
- NEVER use openai/gpt-4o, gpt-3.5-turbo, gpt-4, or any other models unless they appear in available_models.
- If available_models is empty or shows "No models configured", omit the model config entirely.
5. **Node Specifics**:
- For `if-else` comparison_operator, use literal symbols: `≥`, `≤`, `=`, `≠` (NOT `>=` or `==`).
6. **Modification Mode**:
- If `<existing_nodes>` contains nodes, you are MODIFYING an existing workflow.
- Keep nodes that are NOT mentioned in the user's instruction UNCHANGED.
- Only modify/add/remove nodes that the user explicitly requested.
- Preserve node IDs for unchanged nodes to maintain connections.
- If user says "add X", append new nodes to existing workflow.
- If user says "change Y to Z", only modify that specific node.
- If user says "remove X", exclude that node from output.
**Edge Modification**:
- Use `<existing_edges>` to understand current node connections.
- If user mentions "fix edge", "connect", "link", or "add connection",
review existing_edges and correct missing/wrong connections.
- For multi-branch nodes (if-else, question-classifier),
ensure EACH branch has proper sourceHandle (e.g., "true"/"false") and target.
- Common edge issues to fix:
* Missing edge: Two nodes should connect but don't - add the edge
* Wrong target: Edge points to wrong node - update the target
* Missing sourceHandle: if-else/classifier branches lack sourceHandle - add "true"/"false"
* Disconnected nodes: Node has no incoming or outgoing edges - connect it properly
- When modifying edges, ensure logical flow makes sense (start → middle → end).
- ALWAYS output complete edges array, even if only modifying one edge.
**Validation Feedback** (Automatic Retry):
- If `<validation_feedback>` is present, you are RETRYING after validation errors.
- Focus ONLY on fixing the specific validation issues mentioned.
- Keep everything else from the previous attempt UNCHANGED (preserve node IDs, edges, etc).
- Common validation issues and fixes:
* "Missing required connection" → Add the missing edge
* "Invalid node configuration" → Fix the specific node's config section
* "Type mismatch in variable reference" → Correct the variable selector path
* "Unknown variable" → Update variable reference to existing output
- When fixing, make MINIMAL changes to address each specific error.
7. **Output**:
- Return ONLY the JSON object with `nodes` and `edges`.
- Do NOT generate Mermaid diagrams.
- Do NOT generate explanations.
</rules>
<edge_rules priority="critical">
**EDGES ARE CRITICAL** - Every node except 'end' MUST have at least one outgoing edge.
1. **Linear Flow**: Simple source -> target connection
```
{{"source": "node_a", "target": "node_b"}}
```
2. **question-classifier Branching**: Each class MUST have a separate edge with `sourceHandle` = class `id`
- If you define classes: [{{"id": "cls_refund", "name": "Refund"}}, {{"id": "cls_inquiry", "name": "Inquiry"}}]
- You MUST create edges:
- {{"source": "classifier", "sourceHandle": "cls_refund", "target": "refund_handler"}}
- {{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "inquiry_handler"}}
3. **if-else Branching**: MUST have exactly TWO edges with sourceHandle "true" and "false"
- {{"source": "condition", "sourceHandle": "true", "target": "true_branch"}}
- {{"source": "condition", "sourceHandle": "false", "target": "false_branch"}}
4. **Branch Convergence**: Multiple branches can connect to same downstream node
- Both true_branch and false_branch can connect to the same 'end' node
5. **NEVER leave orphan nodes**: Every node must be connected in the graph
</edge_rules>
<examples>
<example name="simple_linear">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "query", "label": "Query", "type": "text-input"}}]
}}
}},
{{
"id": "llm",
"type": "llm",
"title": "Generate Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer: {{{{#start.query#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [
{{"variable": "result", "value_selector": ["llm", "text"]}}
]
}}
}}
],
"edges": [
{{"source": "start", "target": "llm"}},
{{"source": "llm", "target": "end"}}
]
}}
```
</example>
<example name="question_classifier_branching" description="Customer service with intent classification">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "user_input", "label": "User Message", "type": "text-input", "required": true}}]
}}
}},
{{
"id": "classifier",
"type": "question-classifier",
"title": "Classify Intent",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query_variable_selector": ["start", "user_input"],
"classes": [
{{"id": "cls_refund", "name": "Refund Request"}},
{{"id": "cls_inquiry", "name": "Product Inquiry"}},
{{"id": "cls_complaint", "name": "Complaint"}},
{{"id": "cls_other", "name": "Other"}}
],
"instruction": "Classify the user's intent"
}}
}},
{{
"id": "handle_refund",
"type": "llm",
"title": "Handle Refund",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Extract order number and respond: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_inquiry",
"type": "llm",
"title": "Handle Inquiry",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Answer product question: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_complaint",
"type": "llm",
"title": "Handle Complaint",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Respond with empathy: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "handle_other",
"type": "llm",
"title": "Handle Other",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Provide general response: {{{{#start.user_input#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "response", "value_selector": ["handle_refund", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "classifier"}},
{{"source": "classifier", "sourceHandle": "cls_refund", "target": "handle_refund"}},
{{"source": "classifier", "sourceHandle": "cls_inquiry", "target": "handle_inquiry"}},
{{"source": "classifier", "sourceHandle": "cls_complaint", "target": "handle_complaint"}},
{{"source": "classifier", "sourceHandle": "cls_other", "target": "handle_other"}},
{{"source": "handle_refund", "target": "end"}},
{{"source": "handle_inquiry", "target": "end"}},
{{"source": "handle_complaint", "target": "end"}},
{{"source": "handle_other", "target": "end"}}
]
}}
```
CRITICAL: Notice that each class id (cls_refund, cls_inquiry, etc.) becomes a sourceHandle in the edges!
</example>
<example name="if_else_branching" description="Conditional logic with if-else">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "years", "label": "Years of Experience", "type": "number", "required": true}}]
}}
}},
{{
"id": "check_experience",
"type": "if-else",
"title": "Check Experience",
"config": {{
"cases": [
{{
"case_id": "case_1",
"logical_operator": "and",
"conditions": [
{{
"variable_selector": ["start", "years"],
"comparison_operator": "",
"value": "3"
}}
]
}}
]
}}
}},
{{
"id": "qualified",
"type": "llm",
"title": "Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate qualified candidate response"}}]
}}
}},
{{
"id": "not_qualified",
"type": "llm",
"title": "Not Qualified Response",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Generate rejection response"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["qualified", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "check_experience"}},
{{"source": "check_experience", "sourceHandle": "true", "target": "qualified"}},
{{"source": "check_experience", "sourceHandle": "false", "target": "not_qualified"}},
{{"source": "qualified", "target": "end"}},
{{"source": "not_qualified", "target": "end"}}
]
}}
```
CRITICAL: if-else MUST have exactly two edges with sourceHandle "true" and "false"!
</example>
<example name="parameter_extractor" description="Extract structured data from text">
```json
{{
"nodes": [
{{
"id": "start",
"type": "start",
"title": "Start",
"config": {{
"variables": [{{"variable": "resume", "label": "Resume Text", "type": "paragraph", "required": true}}]
}}
}},
{{
"id": "extract",
"type": "parameter-extractor",
"title": "Extract Info",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"query": ["start", "resume"],
"parameters": [
{{"name": "name", "type": "string", "description": "Candidate name", "required": true}},
{{"name": "years", "type": "number", "description": "Years of experience", "required": true}},
{{"name": "skills", "type": "array[string]", "description": "List of skills", "required": true}}
],
"instruction": "Extract candidate information from resume"
}}
}},
{{
"id": "process",
"type": "llm",
"title": "Process Data",
"config": {{
"model": {{"provider": "openai", "name": "gpt-4o", "mode": "chat"}},
"prompt_template": [{{"role": "user", "text": "Name: {{{{#extract.name#}}}}, Years: {{{{#extract.years#}}}}"}}]
}}
}},
{{
"id": "end",
"type": "end",
"title": "End",
"config": {{
"outputs": [{{"variable": "result", "value_selector": ["process", "text"]}}]
}}
}}
],
"edges": [
{{"source": "start", "target": "extract"}},
{{"source": "extract", "target": "process"}},
{{"source": "process", "target": "end"}}
]
}}
```
</example>
</examples>
<edge_checklist>
Before finalizing, verify:
1. [ ] Every node (except 'end') has at least one outgoing edge
2. [ ] 'start' node has exactly one outgoing edge
3. [ ] 'question-classifier' has one edge per class, each with sourceHandle = class id
4. [ ] 'if-else' has exactly two edges: sourceHandle "true" and sourceHandle "false"
5. [ ] All branches eventually connect to 'end' (directly or through other nodes)
6. [ ] No orphan nodes exist (every node is reachable from 'start')
</edge_checklist>
"""
# User prompt for the Builder stage; {instruction} is the user's natural-language
# request. The trailing checklist re-emphasizes edge generation for branching
# nodes, which the builder rules above flag as the most error-prone area.
BUILDER_USER_PROMPT = """<instruction>
{instruction}
</instruction>
Generate the full workflow configuration now. Pay special attention to:
1. Creating edges for ALL branches of question-classifier and if-else nodes
2. Using correct sourceHandle values for branching nodes
3. Ensuring every node is connected in the graph
"""
def format_existing_nodes(nodes: list[dict] | None) -> str:
"""Format existing workflow nodes for context."""
if not nodes:
return "No existing nodes in workflow (creating from scratch)."
lines = []
for node in nodes:
node_id = node.get("id", "unknown")
node_type = node.get("type", "unknown")
title = node.get("title", "Untitled")
lines.append(f"- [{node_id}] {title} ({node_type})")
return "\n".join(lines)
def format_selected_nodes(
selected_ids: list[str] | None,
existing_nodes: list[dict] | None,
) -> str:
"""Format selected nodes for modification context."""
if not selected_ids:
return "No nodes selected (generating new workflow)."
node_map = {n.get("id"): n for n in (existing_nodes or [])}
lines = []
for node_id in selected_ids:
if node_id in node_map:
node = node_map[node_id]
lines.append(f"- [{node_id}] {node.get('title', 'Untitled')} ({node.get('type', 'unknown')})")
else:
lines.append(f"- [{node_id}] (not found in current workflow)")
return "\n".join(lines)
def format_existing_edges(edges: list[dict] | None) -> str:
"""Format existing workflow edges to show connections."""
if not edges:
return "No existing edges (creating new workflow)."
lines = []
for edge in edges:
source = edge.get("source", "unknown")
target = edge.get("target", "unknown")
source_handle = edge.get("sourceHandle", "")
if source_handle:
lines.append(f"- {source} ({source_handle}) -> {target}")
else:
lines.append(f"- {source} -> {target}")
return "\n".join(lines)

View File

@@ -0,0 +1,75 @@
# System prompt for the Planner stage: classifies the user's intent
# ("generate" vs "off_topic") and selects external tools for the plan.
# Filled via str.format() with {tools_summary}; literal JSON braces in the
# response examples are escaped as {{ / }}.
PLANNER_SYSTEM_PROMPT = """<role>
You are an expert Workflow Architect.
Your job is to analyze user requests and plan a high-level automation workflow.
</role>
<task>
1. **Classify Intent**:
- Is the user asking to create an automation/workflow? -> Intent: "generate"
- Is it general chat/weather/jokes? -> Intent: "off_topic"
2. **Plan Steps** (if intent is "generate"):
- Break down the user's goal into logical steps.
- For each step, identify if a specific capability/tool is needed.
- Select the MOST RELEVANT tools from the available_tools list.
- DO NOT configure parameters yet. Just identify the tool.
3. **Output Format**:
Return a JSON object.
</task>
<available_tools>
{tools_summary}
</available_tools>
<response_format>
If intent is "generate":
```json
{{
"intent": "generate",
"plan_thought": "Brief explanation of the plan...",
"steps": [
{{ "step": 1, "description": "Fetch data from URL", "tool": "http-request" }},
{{ "step": 2, "description": "Summarize content", "tool": "llm" }},
{{ "step": 3, "description": "Search for info", "tool": "google_search" }}
],
"required_tool_keys": ["google_search"]
}}
```
(Note: 'http-request', 'llm', 'code' are built-in, you don't need to list them in required_tool_keys,
only external tools)
If intent is "off_topic":
```json
{{
"intent": "off_topic",
"message": "I can only help you build workflows. Try asking me to 'Create a workflow that...'",
"suggestions": ["Scrape a website", "Summarize a PDF"]
}}
```
</response_format>
"""
# User prompt for the Planner stage; {instruction} is the raw user request.
PLANNER_USER_PROMPT = """<user_request>
{instruction}
</user_request>
"""
def format_tools_for_planner(tools: list[dict]) -> str:
    """Summarize tools for the planner prompt (lightweight: name + description).

    Produces one ``- [provider/key] Label: Description`` line per tool,
    omitting the ``provider/`` prefix when no provider is known. Returns a
    fixed sentence when the tool list is empty.
    """
    if not tools:
        return "No external tools available."
    summaries: list[str] = []
    for tool in tools:
        # Tool dicts come in two shapes; prefer the tool_* keys, fall back
        # to the plain ones.
        key = tool.get("tool_key") or tool.get("tool_name")
        provider = tool.get("provider_id") or tool.get("provider", "")
        description = tool.get("tool_description") or tool.get("description", "")
        label = tool.get("tool_label") or key
        qualified = f"{provider}/{key}" if provider else key
        summaries.append(f"- [{qualified}] {label}: {description}")
    return "\n".join(summaries)

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,287 @@
import json
import logging
import re
from collections.abc import Sequence
import json_repair
from core.model_manager import ModelManager
from core.model_runtime.entities.message_entities import SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.workflow.generator.prompts.builder_prompts import (
BUILDER_SYSTEM_PROMPT,
BUILDER_USER_PROMPT,
format_existing_edges,
format_existing_nodes,
format_selected_nodes,
)
from core.workflow.generator.prompts.planner_prompts import (
PLANNER_SYSTEM_PROMPT,
PLANNER_USER_PROMPT,
format_tools_for_planner,
)
from core.workflow.generator.prompts.vibe_prompts import (
format_available_models,
format_available_nodes,
format_available_tools,
parse_vibe_response,
)
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
logger = logging.getLogger(__name__)
class WorkflowGenerator:
    """
    Refactored Vibe Workflow Generator (Planner-Builder Architecture).
    Extracts Vibe logic from the monolithic LLMGenerator.
    """

    @classmethod
    def generate_workflow_flowchart(
        cls,
        tenant_id: str,
        instruction: str,
        model_config: dict,
        available_nodes: Sequence[dict[str, object]] | None = None,
        existing_nodes: Sequence[dict[str, object]] | None = None,
        existing_edges: Sequence[dict[str, object]] | None = None,
        available_tools: Sequence[dict[str, object]] | None = None,
        selected_node_ids: Sequence[str] | None = None,
        previous_workflow: dict[str, object] | None = None,
        regenerate_mode: bool = False,
        preferred_language: str | None = None,
        available_models: Sequence[dict[str, object]] | None = None,
    ):
        """
        Generate a Dify Workflow Flowchart from a natural-language instruction.

        Pipeline:
            1. Planner: analyze intent & select tools (skipped in modification mode).
            2. Context filter: keep only planner-selected tools to reduce tokens
               (creation mode only; modification mode keeps the full tool list).
            3. Builder: generate node/edge configurations (with retry loop).
            4. Renderer: deterministic Mermaid generation.
            5. Validators: config-level checks plus graph-structure checks;
               severe errors trigger one automatic retry with feedback.

        Args:
            tenant_id: Tenant whose model credentials are used.
            instruction: Natural-language description of the desired workflow.
            model_config: LLM selection ({"provider", "name", "completion_params"}).
            available_nodes: Built-in node specs injected into the builder prompt.
            existing_nodes: Current workflow nodes (non-empty => modification mode).
            existing_edges: Current workflow edges, for modification context.
            available_tools: External tool candidate dicts.
            selected_node_ids: Node IDs the user highlighted for modification.
            previous_workflow: Unused; kept for API compatibility.
            regenerate_mode: Unused; kept for API compatibility.
            preferred_language: Fallback output language (defaults to "English").
            available_models: Models the builder may reference in node configs.

        Returns:
            dict with intent "generate" (flowchart/nodes/edges/warnings/...),
            "off_topic" (message/suggestions), or "error" (error message).
        """
        model_manager = ModelManager()
        model_instance = model_manager.get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=model_config.get("provider", ""),
            model=model_config.get("name", ""),
        )
        model_parameters = model_config.get("completion_params", {})
        available_tools_list = list(available_tools) if available_tools else []

        # Modification mode: the user is refining an existing workflow.
        has_existing_nodes = existing_nodes and len(list(existing_nodes)) > 0

        # --- STEP 1: PLANNER + STEP 2: CONTEXT FILTERING ---
        if has_existing_nodes:
            # Skip the Planner in modification mode:
            # - User intent is clear: modify the existing workflow.
            # - Tools may already be in use by existing nodes, so the FULL
            #   tool list must remain available to the builder.
            plan_data = {"intent": "generate", "steps": [], "required_tool_keys": []}
            filtered_tools = available_tools_list
        else:
            # Creation mode: run the Planner to validate intent and select tools.
            planner_tools_context = format_tools_for_planner(available_tools_list)
            planner_system = PLANNER_SYSTEM_PROMPT.format(tools_summary=planner_tools_context)
            planner_user = PLANNER_USER_PROMPT.format(instruction=instruction)
            try:
                response = model_instance.invoke_llm(
                    prompt_messages=[
                        SystemPromptMessage(content=planner_system),
                        UserPromptMessage(content=planner_user),
                    ],
                    model_parameters=model_parameters,
                    stream=False,
                )
                # Reuse the shared lenient-JSON parsing logic.
                plan_data = parse_vibe_response(response.message.content)
            except Exception as e:
                logger.exception("Planner failed")
                return {"intent": "error", "error": f"Planning failed: {str(e)}"}
            if plan_data.get("intent") == "off_topic":
                return {
                    "intent": "off_topic",
                    "message": plan_data.get("message", "I can only help with workflow creation."),
                    "suggestions": plan_data.get("suggestions", []),
                }

            # Tool filtering (creation mode only).
            # BUGFIX: this used to run unconditionally after the if/else above;
            # since required_tool_keys is always empty in modification mode, it
            # clobbered the modification-mode tool list back to [].
            required_tools = plan_data.get("required_tool_keys", [])
            filtered_tools = []
            if required_tools:
                for tool in available_tools_list:
                    t_key = tool.get("tool_key") or tool.get("tool_name")
                    provider = tool.get("provider_id") or tool.get("provider")
                    full_key = f"{provider}/{t_key}" if provider else t_key
                    # Match either the short tool name or the provider-qualified key.
                    if t_key in required_tools or full_key in required_tools:
                        filtered_tools.append(tool)
            # else: logic-only plan, no external tools needed.

        # --- STEP 3: BUILDER (with retry loop) ---
        MAX_GLOBAL_RETRIES = 2  # Total attempts: 1 initial + 1 retry
        workflow_data = None
        mermaid_code = None
        all_fixes = []  # Legacy: auto-repair was removed, so this stays empty.
        retry_count = 0
        validation_hints = []
        for attempt in range(MAX_GLOBAL_RETRIES):
            retry_count = attempt
            logger.info("Generation attempt %s/%s", attempt + 1, MAX_GLOBAL_RETRIES)
            # Prepare prompt context sections.
            tool_schemas = format_available_tools(filtered_tools)
            node_specs = format_available_nodes(list(available_nodes) if available_nodes else [])
            existing_nodes_context = format_existing_nodes(list(existing_nodes) if existing_nodes else None)
            existing_edges_context = format_existing_edges(list(existing_edges) if existing_edges else None)
            selected_nodes_context = format_selected_nodes(
                list(selected_node_ids) if selected_node_ids else None,
                list(existing_nodes) if existing_nodes else None,
            )

            # On retry, feed the previous attempt's severe validation errors
            # back to the builder so it fixes only those issues.
            retry_context = ""
            if attempt > 0 and validation_hints:
                severe_issues = [h for h in validation_hints if h.severity == "error"]
                if severe_issues:
                    retry_context = "\n<validation_feedback>\n"
                    retry_context += "The previous generation had validation errors:\n"
                    for idx, hint in enumerate(severe_issues[:5], 1):
                        retry_context += f"{idx}. {hint.message}\n"
                    retry_context += "\nPlease fix these specific issues while keeping everything else UNCHANGED.\n"
                    retry_context += "</validation_feedback>\n"

            builder_system = BUILDER_SYSTEM_PROMPT.format(
                plan_context=json.dumps(plan_data.get("steps", []), indent=2),
                tool_schemas=tool_schemas,
                builtin_node_specs=node_specs,
                available_models=format_available_models(list(available_models or [])),
                preferred_language=preferred_language or "English",
                existing_nodes_context=existing_nodes_context,
                existing_edges_context=existing_edges_context,
                selected_nodes_context=selected_nodes_context,
            )
            builder_user = BUILDER_USER_PROMPT.format(instruction=instruction) + retry_context
            try:
                build_res = model_instance.invoke_llm(
                    prompt_messages=[
                        SystemPromptMessage(content=builder_system),
                        UserPromptMessage(content=builder_user),
                    ],
                    model_parameters=model_parameters,
                    stream=False,
                )
                # Builder output is raw JSON nodes/edges, possibly fenced.
                build_content = build_res.message.content
                match = re.search(r"```(?:json)?\s*([\s\S]+?)```", build_content)
                if match:
                    build_content = match.group(1)
                workflow_data = json_repair.loads(build_content)
                if "nodes" not in workflow_data:
                    workflow_data["nodes"] = []
                if "edges" not in workflow_data:
                    workflow_data["edges"] = []
            except Exception as e:
                logger.exception("Builder failed on attempt %d", attempt + 1)
                if attempt == MAX_GLOBAL_RETRIES - 1:
                    return {"intent": "error", "error": f"Building failed: {str(e)}"}
                continue  # Try again

            # NOTE: NodeRepair and EdgeRepair have been removed.
            # Validation detects structural issues and the LLM fixes them on
            # retry; the LLM understands workflow context better than
            # deterministic repair.

            # --- STEP 4: RENDERER (generate Mermaid early for validation) ---
            mermaid_code = generate_mermaid(workflow_data)

            # --- STEP 5: VALIDATOR (node/config-level checks) ---
            is_valid, validation_hints = WorkflowValidator.validate(workflow_data, available_tools_list)

            # --- STEP 6: GRAPH VALIDATION (structural checks) ---
            # Only worth running if a retry can still act on the findings.
            if attempt < MAX_GLOBAL_RETRIES - 1:
                try:
                    from core.workflow.generator.utils.graph_validator import GraphValidator

                    graph_result = GraphValidator.validate(workflow_data)
                    if not graph_result.success:
                        # Fold graph errors into the common hint list.
                        for graph_error in graph_result.errors:
                            validation_hints.append(
                                ValidationHint(
                                    node_id=graph_error.node_id,
                                    field="edges",
                                    message=f"[Graph] {graph_error.message}",
                                    severity="error",
                                )
                            )
                    # Dead-end style findings surface as warnings.
                    for graph_warning in graph_result.warnings:
                        validation_hints.append(
                            ValidationHint(
                                node_id=graph_warning.node_id,
                                field="edges",
                                message=f"[Graph] {graph_warning.message}",
                                severity="warning",
                            )
                        )
                except Exception as e:
                    logger.warning("Graph validation error: %s", e)

            # Retry only when severe errors remain and attempts are left.
            severe_issues = [h for h in validation_hints if h.severity == "error"]
            if not severe_issues or attempt == MAX_GLOBAL_RETRIES - 1:
                break

        # Collect all validation warnings from the final attempt.
        all_warnings = [h.message for h in validation_hints]
        # Always append a stability disclaimer in the preferred language.
        stability_warning = "The generated workflow may require debugging."
        if preferred_language and preferred_language.startswith("zh"):
            stability_warning = "生成的 Workflow 可能需要调试。"
        all_warnings.append(stability_warning)
        return {
            "intent": "generate",
            "flowchart": mermaid_code,
            "nodes": workflow_data["nodes"],
            "edges": workflow_data["edges"],
            "message": plan_data.get("plan_thought", "Generated workflow based on your request."),
            "warnings": all_warnings,
            "tool_recommendations": [],  # Legacy field
            "error": "",
            "fixed_issues": all_fixes,  # Track what was auto-fixed (none today)
            "retry_count": retry_count,  # Track how many retries were needed
        }

View File

@@ -0,0 +1,217 @@
"""
Type definitions for Vibe Workflow Generator.
This module provides:
- TypedDict classes for lightweight type hints (no runtime overhead)
- Pydantic models for runtime validation where needed
Usage:
# For type hints only (no runtime validation):
from core.workflow.generator.types import WorkflowNodeDict, WorkflowEdgeDict
# For runtime validation:
from core.workflow.generator.types import WorkflowNode, WorkflowEdge
"""
from typing import Any, TypedDict
from pydantic import BaseModel, Field
# ============================================================
# TypedDict definitions (lightweight, for type hints only)
# ============================================================
class WorkflowNodeDict(TypedDict, total=False):
    """
    Workflow node structure (TypedDict for hints).

    All keys are optional (``total=False``), so readers should use ``.get()``.

    Attributes:
        id: Unique node identifier
        type: Node type (e.g., "start", "end", "llm", "if-else", "http-request")
        title: Human-readable node title
        config: Node-specific configuration
        data: Additional node data
    """

    id: str
    type: str
    title: str
    config: dict[str, Any]
    data: dict[str, Any]


class WorkflowEdgeDict(TypedDict, total=False):
    """
    Workflow edge structure (TypedDict for hints).

    Attributes:
        source: Source node ID
        target: Target node ID
        sourceHandle: Branch handle for if-else/question-classifier nodes
    """

    source: str
    target: str
    sourceHandle: str


class AvailableModelDict(TypedDict):
    """
    Available model structure. Both keys are required (``total`` defaults to True).

    Attributes:
        provider: Model provider (e.g., "openai", "anthropic")
        model: Model name (e.g., "gpt-4", "claude-3")
    """

    provider: str
    model: str


class ToolParameterDict(TypedDict, total=False):
    """
    Tool parameter structure.

    Attributes:
        name: Parameter name
        type: Parameter type (e.g., "string", "number", "boolean")
        required: Whether parameter is required
        human_description: Human-readable description
        llm_description: LLM-oriented description
        options: Available options for enum-type parameters
    """

    name: str
    type: str
    required: bool
    human_description: str | dict[str, str]
    llm_description: str
    options: list[Any]


class AvailableToolDict(TypedDict, total=False):
    """
    Available tool structure.

    Note the paired alternative keys (tool_key/tool_name, provider_id/provider,
    tool_description/description); consumers check both.

    Attributes:
        provider_id: Tool provider ID
        provider: Tool provider name (alternative to provider_id)
        tool_key: Unique tool key
        tool_name: Tool name (alternative to tool_key)
        tool_description: Tool description
        description: Alternative description field
        is_team_authorization: Whether tool is configured/authorized
        parameters: List of tool parameters
    """

    provider_id: str
    provider: str
    tool_key: str
    tool_name: str
    tool_description: str
    description: str
    is_team_authorization: bool
    parameters: list[ToolParameterDict]


class WorkflowDataDict(TypedDict, total=False):
    """
    Complete workflow data structure.

    Attributes:
        nodes: List of workflow nodes
        edges: List of workflow edges
        warnings: List of warning messages
    """

    nodes: list[WorkflowNodeDict]
    edges: list[WorkflowEdgeDict]
    warnings: list[str]
# ============================================================
# Pydantic models (for runtime validation)
# ============================================================
class WorkflowNode(BaseModel):
    """
    Workflow node with runtime validation.

    Use this model when you need to validate node data at runtime.
    For lightweight type hints without validation, use WorkflowNodeDict.
    """

    id: str
    type: str
    title: str = ""
    config: dict[str, Any] = Field(default_factory=dict)
    data: dict[str, Any] = Field(default_factory=dict)


class WorkflowEdge(BaseModel):
    """
    Workflow edge with runtime validation.

    Use this model when you need to validate edge data at runtime.
    For lightweight type hints without validation, use WorkflowEdgeDict.
    """

    source: str
    target: str
    # camelCase mirrors the React Flow edge field name used by the frontend.
    sourceHandle: str | None = None


class AvailableModel(BaseModel):
    """
    Available model with runtime validation.

    Use this model when you need to validate model data at runtime.
    For lightweight type hints without validation, use AvailableModelDict.
    """

    provider: str
    model: str


class ToolParameter(BaseModel):
    """Tool parameter with runtime validation; all fields default to empty."""

    name: str = ""
    type: str = "string"
    required: bool = False
    human_description: str | dict[str, str] = ""
    llm_description: str = ""
    options: list[Any] = Field(default_factory=list)


class AvailableTool(BaseModel):
    """
    Available tool with runtime validation.

    Use this model when you need to validate tool data at runtime.
    For lightweight type hints without validation, use AvailableToolDict.
    """

    provider_id: str = ""
    provider: str = ""
    tool_key: str = ""
    tool_name: str = ""
    tool_description: str = ""
    description: str = ""
    is_team_authorization: bool = False
    parameters: list[ToolParameter] = Field(default_factory=list)


class WorkflowData(BaseModel):
    """
    Complete workflow data with runtime validation.

    Use this model when you need to validate workflow data at runtime.
    For lightweight type hints without validation, use WorkflowDataDict.
    """

    nodes: list[WorkflowNode] = Field(default_factory=list)
    edges: list[WorkflowEdge] = Field(default_factory=list)
    warnings: list[str] = Field(default_factory=list)

View File

@@ -0,0 +1,384 @@
"""
Edge Repair Utility for Vibe Workflow Generation.
This module provides intelligent edge repair capabilities for generated workflows.
It can detect and fix common edge issues:
- Missing edges between sequential nodes
- Incomplete branches for question-classifier and if-else nodes
- Orphaned nodes without connections
The repair logic is deterministic and doesn't require LLM calls.
"""
import logging
from dataclasses import dataclass, field
from core.workflow.generator.types import WorkflowDataDict, WorkflowEdgeDict, WorkflowNodeDict
logger = logging.getLogger(__name__)
@dataclass
class RepairResult:
    """Outcome of an edge-repair pass: the (possibly amended) graph plus logs."""

    nodes: list[WorkflowNodeDict]
    edges: list[WorkflowEdgeDict]
    repairs_made: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def was_repaired(self) -> bool:
        """True when at least one repair entry was recorded."""
        return bool(self.repairs_made)
class EdgeRepair:
    """
    Intelligent edge repair for workflow graphs.

    Repairs are applied in order:
    1. Infer linear connections from node order (if no edges exist)
    2. Add missing branch edges for question-classifier
    3. Add missing branch edges for if-else
    4. Connect orphaned nodes
    5. Connect terminal nodes (no outgoing edges) to 'end'

    All repairs are deterministic; no LLM calls are involved.
    """

    @classmethod
    def repair(cls, workflow_data: WorkflowDataDict) -> RepairResult:
        """
        Repair edges in the workflow data.

        Args:
            workflow_data: Dict containing 'nodes' and 'edges'

        Returns:
            RepairResult with repaired nodes, edges, and repair logs
        """
        nodes = list(workflow_data.get("nodes", []))
        edges = list(workflow_data.get("edges", []))
        repairs: list[str] = []
        warnings: list[str] = []
        logger.info("[EDGE REPAIR] Starting repair process for %s nodes, %s edges", len(nodes), len(edges))

        # Build node lookup (nodes without an id cannot participate in edges).
        node_map = {n.get("id"): n for n in nodes if n.get("id")}
        node_ids = set(node_map.keys())

        # 1. If no edges at all, infer a linear chain from node order.
        if not edges and len(nodes) > 1:
            edges, inferred_repairs = cls._infer_linear_chain(nodes)
            repairs.extend(inferred_repairs)

        # 2. Build edge indexes (source -> edges, target -> edges) for analysis.
        outgoing_edges: dict[str, list[WorkflowEdgeDict]] = {}
        incoming_edges: dict[str, list[WorkflowEdgeDict]] = {}
        for edge in edges:
            src = edge.get("source")
            tgt = edge.get("target")
            if src:
                outgoing_edges.setdefault(src, []).append(edge)
            if tgt:
                incoming_edges.setdefault(tgt, []).append(edge)

        # 3. Repair question-classifier branches.
        for node in nodes:
            if node.get("type") == "question-classifier":
                new_edges, branch_repairs, branch_warnings = cls._repair_classifier_branches(
                    node, edges, outgoing_edges, node_ids
                )
                edges.extend(new_edges)
                repairs.extend(branch_repairs)
                warnings.extend(branch_warnings)
                # Keep the outgoing index in sync with the edges just added.
                for edge in new_edges:
                    src = edge.get("source")
                    if src:  # never index the map with a missing/None source
                        outgoing_edges.setdefault(src, []).append(edge)

        # 4. Repair if-else branches.
        for node in nodes:
            if node.get("type") == "if-else":
                new_edges, branch_repairs, branch_warnings = cls._repair_if_else_branches(
                    node, edges, outgoing_edges, node_ids
                )
                edges.extend(new_edges)
                repairs.extend(branch_repairs)
                warnings.extend(branch_warnings)
                # Keep the outgoing index in sync with the edges just added.
                for edge in new_edges:
                    src = edge.get("source")
                    if src:
                        outgoing_edges.setdefault(src, []).append(edge)

        # 5. Connect orphaned nodes (no incoming edge, except start nodes).
        new_edges, orphan_repairs = cls._connect_orphaned_nodes(nodes, edges, outgoing_edges, incoming_edges)
        edges.extend(new_edges)
        repairs.extend(orphan_repairs)

        # 6. Connect nodes with no outgoing edge to 'end' (except end nodes).
        new_edges, terminal_repairs = cls._connect_terminal_nodes(nodes, edges, outgoing_edges)
        edges.extend(new_edges)
        repairs.extend(terminal_repairs)

        if repairs:
            logger.info("[EDGE REPAIR] Completed with %s repairs:", len(repairs))
            for i, repair in enumerate(repairs, 1):
                logger.info("[EDGE REPAIR] %s. %s", i, repair)
        else:
            logger.info("[EDGE REPAIR] Completed - no repairs needed")

        return RepairResult(
            nodes=nodes,
            edges=edges,
            repairs_made=repairs,
            warnings=warnings,
        )

    @staticmethod
    def _find_end_node_id(valid_node_ids: set[str]) -> str:
        """Pick the default branch target: the literal 'end' id when present,
        otherwise the first node id containing 'end', otherwise 'end' itself
        (matching the pre-refactor fallback even if no such node exists)."""
        if "end" in valid_node_ids:
            return "end"
        for nid in valid_node_ids:
            if "end" in nid.lower():
                return nid
        return "end"

    @classmethod
    def _infer_linear_chain(cls, nodes: list[WorkflowNodeDict]) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Infer a linear chain of edges from node order.

        Used only when the generated workflow contains no edges at all.
        """
        edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        node_ids = [n.get("id") for n in nodes if n.get("id")]
        if len(node_ids) < 2:
            return edges, repairs
        # Connect each node to its successor in declaration order.
        for i in range(len(node_ids) - 1):
            src = node_ids[i]
            tgt = node_ids[i + 1]
            edges.append({"source": src, "target": tgt})
            repairs.append(f"Inferred edge: {src} -> {tgt}")
        return edges, repairs

    @classmethod
    def _repair_classifier_branches(
        cls,
        node: WorkflowNodeDict,
        edges: list[WorkflowEdgeDict],  # unused; kept for signature stability
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
        valid_node_ids: set[str],
    ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
        """
        Repair missing branches for question-classifier nodes.

        For each class that doesn't have an edge, create one pointing to 'end'.

        Returns:
            Tuple of (new_edges, repair_messages, warning_messages).
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        warnings: list[str] = []
        node_id = node.get("id")
        if not node_id:
            return new_edges, repairs, warnings
        config = node.get("config", {})
        classes = config.get("classes", [])
        if not classes:
            return new_edges, repairs, warnings

        # Source handles already wired out of this classifier.
        existing_handles = set()
        for edge in outgoing_edges.get(node_id, []):
            handle = edge.get("sourceHandle")
            if handle:
                existing_handles.add(handle)

        end_node_id = cls._find_end_node_id(valid_node_ids)

        # Wire every class that has no edge straight to the end node.
        for cls_def in classes:
            if not isinstance(cls_def, dict):
                continue
            cls_id = cls_def.get("id")
            cls_name = cls_def.get("name", cls_id)
            if cls_id and cls_id not in existing_handles:
                new_edge = {
                    "source": node_id,
                    "sourceHandle": cls_id,
                    "target": end_node_id,
                }
                new_edges.append(new_edge)
                repairs.append(f"Added missing branch edge for class '{cls_name}' -> {end_node_id}")
                warnings.append(
                    f"Auto-connected question-classifier branch '{cls_name}' to '{end_node_id}'. "
                    "You may want to redirect this to a specific handler node."
                )
        return new_edges, repairs, warnings

    @classmethod
    def _repair_if_else_branches(
        cls,
        node: WorkflowNodeDict,
        edges: list[WorkflowEdgeDict],  # unused; kept for signature stability
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
        valid_node_ids: set[str],
    ) -> tuple[list[WorkflowEdgeDict], list[str], list[str]]:
        """
        Repair missing branches for if-else nodes.

        If-else in Dify uses case_id as sourceHandle for each condition,
        plus 'false' for the else branch.

        Returns:
            Tuple of (new_edges, repair_messages, warning_messages).
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        warnings: list[str] = []
        node_id = node.get("id")
        if not node_id:
            return new_edges, repairs, warnings

        # Source handles already wired out of this if-else node.
        existing_handles = set()
        for edge in outgoing_edges.get(node_id, []):
            handle = edge.get("sourceHandle")
            if handle:
                existing_handles.add(handle)

        end_node_id = cls._find_end_node_id(valid_node_ids)

        # Required handles: one per case_id, plus 'false' for the else branch.
        config = node.get("config", {})
        cases = config.get("cases", [])
        required_branches = set()
        for case in cases:
            case_id = case.get("case_id")
            if case_id:
                required_branches.add(case_id)
        required_branches.add("false")  # else branch

        for branch in required_branches:
            if branch not in existing_handles:
                new_edge = {
                    "source": node_id,
                    "sourceHandle": branch,
                    "target": end_node_id,
                }
                new_edges.append(new_edge)
                repairs.append(f"Added missing if-else branch '{branch}' -> {end_node_id}")
                warnings.append(
                    f"Auto-connected if-else branch '{branch}' to '{end_node_id}'. "
                    "You may want to redirect this to a specific handler node."
                )
        return new_edges, repairs, warnings

    @classmethod
    def _connect_orphaned_nodes(
        cls,
        nodes: list[WorkflowNodeDict],
        edges: list[WorkflowEdgeDict],  # unused; kept for signature stability
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],  # unused; kept for signature stability
        incoming_edges: dict[str, list[WorkflowEdgeDict]],
    ) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Connect orphaned nodes to the previous node in sequence.

        An orphaned node has no incoming edges and is not a 'start' node.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []
        node_ids = [n.get("id") for n in nodes if n.get("id")]
        node_types = {n.get("id"): n.get("type") for n in nodes}
        for i, node_id in enumerate(node_ids):
            # Start nodes legitimately have no incoming edges.
            if node_types.get(node_id) == "start":
                continue
            if node_id not in incoming_edges or not incoming_edges[node_id]:
                # Chain from the previous node in declaration order.
                if i > 0:
                    prev_node_id = node_ids[i - 1]
                    new_edge = {"source": prev_node_id, "target": node_id}
                    new_edges.append(new_edge)
                    repairs.append(f"Connected orphaned node: {prev_node_id} -> {node_id}")
                    # Record it so later iterations see this node as connected.
                    incoming_edges.setdefault(node_id, []).append(new_edge)
        return new_edges, repairs

    @classmethod
    def _connect_terminal_nodes(
        cls,
        nodes: list[WorkflowNodeDict],
        edges: list[WorkflowEdgeDict],  # unused; kept for signature stability
        outgoing_edges: dict[str, list[WorkflowEdgeDict]],
    ) -> tuple[list[WorkflowEdgeDict], list[str]]:
        """
        Connect terminal nodes (no outgoing edges) to 'end'.

        A terminal node has no outgoing edges and is not an 'end' node.
        This ensures all branches eventually reach 'end'.
        """
        new_edges: list[WorkflowEdgeDict] = []
        repairs: list[str] = []

        # Locate an end node; without one there is nothing to connect to.
        # (If several exist, the last declared one wins, as before.)
        end_node_id = None
        for n in nodes:
            if n.get("id") and n.get("type") == "end":
                end_node_id = n.get("id")
        if not end_node_id:
            return new_edges, repairs

        for node in nodes:
            node_id = node.get("id")
            # Fix: skip id-less nodes instead of emitting a source=None edge.
            if not node_id:
                continue
            if node.get("type") == "end":
                continue
            # Nodes that already have an outgoing edge are not terminal.
            if outgoing_edges.get(node_id):
                continue
            new_edge = {"source": node_id, "target": end_node_id}
            new_edges.append(new_edge)
            repairs.append(f"Connected terminal node to end: {node_id} -> {end_node_id}")
            # Record it so later iterations see this node as connected.
            outgoing_edges.setdefault(node_id, []).append(new_edge)
        return new_edges, repairs

View File

@@ -0,0 +1,280 @@
"""
Graph Validator for Workflow Generation
Validates workflow graph structure using graph algorithms:
- Reachability from start node (BFS)
- Reachability to end node (reverse BFS)
- Branch edge validation for if-else and classifier nodes
"""
import time
from collections import deque
from dataclasses import dataclass, field
@dataclass
class GraphError:
    """A single structural problem found in the workflow graph."""

    # Id/type of the offending node ("workflow" for graph-level issues).
    node_id: str
    node_type: str
    # One of: "unreachable", "dead_end", "cycle", "missing_start",
    # "missing_end", "disconnected", "missing_branch".
    error_type: str
    # Human-readable description of the problem.
    message: str
@dataclass
class GraphValidationResult:
    """Aggregate outcome of a graph validation run."""

    # True when no errors were found (warnings alone do not fail validation).
    success: bool
    errors: list[GraphError] = field(default_factory=list)
    warnings: list[GraphError] = field(default_factory=list)
    # Duration of the validation pass, in seconds.
    execution_time: float = 0.0
    # Summary counters (node/edge totals, reachability counts, ...).
    stats: dict = field(default_factory=dict)
class GraphValidator:
    """
    Validates workflow graph structure using proper graph algorithms.

    Performs:
    1. Forward reachability analysis (BFS from start)
    2. Backward reachability analysis (reverse BFS from end)
    3. Branch edge validation for if-else and classifier nodes
    """

    @staticmethod
    def _build_adjacency(
        nodes: dict[str, dict], edges: list[dict]
    ) -> tuple[dict[str, list[str]], dict[str, list[str]]]:
        """Build forward and reverse adjacency lists from edges.

        Edges referencing unknown node ids are silently dropped.
        """
        outgoing: dict[str, list[str]] = {node_id: [] for node_id in nodes}
        incoming: dict[str, list[str]] = {node_id: [] for node_id in nodes}
        for edge in edges:
            source = edge.get("source")
            target = edge.get("target")
            if source in outgoing and target in incoming:
                outgoing[source].append(target)
                incoming[target].append(source)
        return outgoing, incoming

    @staticmethod
    def _bfs_reachable(start: str, adjacency: dict[str, list[str]]) -> set[str]:
        """BFS to find all nodes reachable from start node (inclusive)."""
        if start not in adjacency:
            return set()
        visited = {start}
        queue = deque([start])
        while queue:
            current = queue.popleft()
            for neighbor in adjacency.get(current, []):
                if neighbor not in visited:
                    visited.add(neighbor)
                    queue.append(neighbor)
        return visited

    @staticmethod
    def validate(workflow_data: dict) -> GraphValidationResult:
        """Validate workflow graph structure.

        Args:
            workflow_data: Dict with 'nodes' and 'edges' lists.

        Returns:
            GraphValidationResult; success is True only when no errors were
            found (warnings alone do not fail validation).
        """
        # Fix: perf_counter is monotonic; time.time() can jump when the wall
        # clock is adjusted, yielding bogus (even negative) execution_time.
        start_time = time.perf_counter()
        errors: list[GraphError] = []
        warnings: list[GraphError] = []
        nodes_list = workflow_data.get("nodes", [])
        edges_list = workflow_data.get("edges", [])
        # Nodes without an id cannot participate in the graph; skip them.
        nodes = {n["id"]: n for n in nodes_list if n.get("id")}

        # Locate the start node (single) and end nodes (possibly several).
        start_node_id = None
        end_node_ids = []
        for node_id, node in nodes.items():
            node_type = node.get("type")
            if node_type == "start":
                start_node_id = node_id
            elif node_type == "end":
                end_node_ids.append(node_id)

        # Check start node exists
        if not start_node_id:
            errors.append(
                GraphError(
                    node_id="workflow",
                    node_type="workflow",
                    error_type="missing_start",
                    message="Workflow has no start node",
                )
            )
        # Check end node exists
        if not end_node_ids:
            errors.append(
                GraphError(
                    node_id="workflow",
                    node_type="workflow",
                    error_type="missing_end",
                    message="Workflow has no end node",
                )
            )

        # Without both endpoints, reachability analysis is meaningless.
        if not start_node_id or not end_node_ids:
            execution_time = time.perf_counter() - start_time
            return GraphValidationResult(
                success=False,
                errors=errors,
                warnings=warnings,
                execution_time=execution_time,
                stats={"nodes": len(nodes), "edges": len(edges_list)},
            )

        outgoing, incoming = GraphValidator._build_adjacency(nodes, edges_list)

        # --- FORWARD REACHABILITY: BFS from start ---
        reachable_from_start = GraphValidator._bfs_reachable(start_node_id, outgoing)
        unreachable_nodes = set(nodes.keys()) - reachable_from_start
        for node_id in unreachable_nodes:
            node = nodes[node_id]
            errors.append(
                GraphError(
                    node_id=node_id,
                    node_type=node.get("type", "unknown"),
                    error_type="unreachable",
                    message=f"Node '{node_id}' is not reachable from start node",
                )
            )

        # --- BACKWARD REACHABILITY: reverse BFS from every end node ---
        can_reach_end: set[str] = set()
        for end_id in end_node_ids:
            can_reach_end.update(GraphValidator._bfs_reachable(end_id, incoming))

        # Dead-end nodes can't reach any end node; report as warnings only.
        dead_end_nodes = set(nodes.keys()) - can_reach_end
        for node_id in dead_end_nodes:
            # Unreachable nodes were already reported as errors above.
            if node_id in unreachable_nodes:
                continue
            node = nodes[node_id]
            warnings.append(
                GraphError(
                    node_id=node_id,
                    node_type=node.get("type", "unknown"),
                    error_type="dead_end",
                    message=f"Node '{node_id}' cannot reach any end node (dead end)",
                )
            )

        # --- Start node must have at least one outgoing edge ---
        if not outgoing.get(start_node_id):
            errors.append(
                GraphError(
                    node_id=start_node_id,
                    node_type="start",
                    error_type="disconnected",
                    message="Start node has no outgoing connections",
                )
            )

        # --- Every end node must have at least one incoming edge ---
        for end_id in end_node_ids:
            if not incoming.get(end_id):
                errors.append(
                    GraphError(
                        node_id=end_id,
                        node_type="end",
                        error_type="disconnected",
                        message="End node has no incoming connections",
                    )
                )

        # --- BRANCH EDGE VALIDATION ---
        # Collect the set of sourceHandles actually wired per source node.
        edge_handles: dict[str, set[str]] = {}
        for edge in edges_list:
            source = edge.get("source")
            handle = edge.get("sourceHandle", "")
            if source:
                edge_handles.setdefault(source, set()).add(handle)

        # Check if-else and question-classifier nodes
        for node_id, node in nodes.items():
            node_type = node.get("type")
            if node_type == "if-else":
                handles = edge_handles.get(node_id, set())
                config = node.get("config", {})
                cases = config.get("cases", [])
                # Each case_id needs an edge, plus 'false' for the else branch.
                required_handles = set()
                for case in cases:
                    case_id = case.get("case_id")
                    if case_id:
                        required_handles.add(case_id)
                required_handles.add("false")
                missing = required_handles - handles
                for handle in missing:
                    errors.append(
                        GraphError(
                            node_id=node_id,
                            node_type=node_type,
                            error_type="missing_branch",
                            message=f"If-else node '{node_id}' missing edge for branch '{handle}'",
                        )
                    )
            elif node_type == "question-classifier":
                handles = edge_handles.get(node_id, set())
                config = node.get("config", {})
                classes = config.get("classes", [])
                required_handles = set()
                for cls in classes:
                    if isinstance(cls, dict):
                        cls_id = cls.get("id")
                        if cls_id:
                            required_handles.add(cls_id)
                missing = required_handles - handles
                for handle in missing:
                    # Resolve the class display name for a friendlier message.
                    cls_name = handle
                    for cls in classes:
                        if isinstance(cls, dict) and cls.get("id") == handle:
                            cls_name = cls.get("name", handle)
                            break
                    errors.append(
                        GraphError(
                            node_id=node_id,
                            node_type=node_type,
                            error_type="missing_branch",
                            message=f"Classifier '{node_id}' missing edge for class '{cls_name}'",
                        )
                    )

        execution_time = time.perf_counter() - start_time
        return GraphValidationResult(
            success=len(errors) == 0,
            errors=errors,
            warnings=warnings,
            execution_time=execution_time,
            stats={
                "nodes": len(nodes),
                "edges": len(edges_list),
                "reachable_from_start": len(reachable_from_start),
                "can_reach_end": len(can_reach_end),
                "unreachable": len(unreachable_nodes),
                "dead_ends": len(dead_end_nodes - unreachable_nodes),
            },
        )

View File

@@ -0,0 +1,113 @@
import logging
from core.workflow.generator.types import WorkflowDataDict
logger = logging.getLogger(__name__)
def generate_mermaid(workflow_data: WorkflowDataDict) -> str:
    """
    Render workflow nodes and edges as a Mermaid ``flowchart TD`` diagram.

    Nodes use the Vibe Workflow label format ``id["type=TYPE|title=TITLE"]``
    (tool nodes additionally carry ``|tool=TOOL_KEY``); edges carry the
    sourceHandle as a ``|label|`` where applicable.

    Args:
        workflow_data: Dict containing 'nodes' (list) and 'edges' (list)

    Returns:
        String containing the Mermaid flowchart syntax
    """
    nodes = workflow_data.get("nodes", [])
    edges = workflow_data.get("edges", [])
    out = ["flowchart TD"]

    def to_mermaid_id(raw_id: str) -> str:
        # 'end' and 'subgraph' are reserved Mermaid keywords; remap them.
        # Other ids are trusted as-is (Vibe generates simple identifiers).
        if raw_id == "end":
            return "end_node"
        if raw_id == "subgraph":
            return "subgraph_node"
        return raw_id

    # --- Node definitions ---
    id_map = {}
    for node in nodes:
        node_id = node.get("id")
        if not node_id:
            continue
        safe_id = to_mermaid_id(node_id)
        id_map[node_id] = safe_id
        node_type = node.get("type", "unknown")
        # Double quotes would terminate the Mermaid label; swap for singles.
        safe_title = node.get("title", "Untitled").replace('"', "'")
        if node_type == "tool":
            config = node.get("config", {})
            # The tool reference may live under several keys; take the first hit.
            tool_ref = (
                config.get("tool_key")
                or config.get("tool")
                or config.get("tool_name")
                or node.get("tool_name")
                or "unknown"
            )
            node_def = f'{safe_id}["type={node_type}|title={safe_title}|tool={tool_ref}"]'
        else:
            node_def = f'{safe_id}["type={node_type}|title={safe_title}"]'
        out.append(f" {node_def}")

    # --- Edge definitions (only between nodes defined above) ---
    defined_node_ids = {n.get("id") for n in nodes if n.get("id")}
    for edge in edges:
        source = edge.get("source")
        target = edge.get("target")
        if not source or not target:
            continue
        if source not in defined_node_ids or target not in defined_node_ids:
            continue
        handle = edge.get("sourceHandle")
        # sourceHandle encodes conditional/classifier branch labels.
        if handle == "true":
            label = "|true|"
        elif handle == "false":
            label = "|false|"
        elif handle and handle != "source":
            safe_handle = str(handle).replace('"', "'")
            label = f"|{safe_handle}|"
        else:
            label = ""
        out.append(f" {id_map.get(source, source)} -->{label} {id_map.get(target, target)}")

    return "\n".join(out)

View File

@@ -0,0 +1,304 @@
"""
Node Repair Utility for Vibe Workflow Generation.
This module provides intelligent node configuration repair capabilities.
It can detect and fix common node configuration issues:
- Invalid comparison operators in if-else nodes (normalized to the canonical operator form via OPERATOR_MAP)
"""
import copy
import logging
import uuid
from dataclasses import dataclass, field
from core.workflow.generator.types import WorkflowNodeDict
logger = logging.getLogger(__name__)
@dataclass
class NodeRepairResult:
    """Outcome of a node-configuration repair pass."""

    # Nodes after repair (a deep copy of the caller's input).
    nodes: list[WorkflowNodeDict]
    # Human-readable log entry for every repair that was applied.
    repairs_made: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)

    @property
    def was_repaired(self) -> bool:
        """True when at least one repair was applied."""
        return bool(self.repairs_made)
class NodeRepair:
    """
    Intelligent node configuration repair.

    Applies deterministic, rule-based fixes to generated node configs; an
    optional ``llm_callback`` is used only where rule-based conversion fails
    (currently: code-node outputs).
    """

    # Maps non-canonical comparison operators to the canonical frontend form.
    # NOTE(review): the replacement values for '>=', '<=' and '!=' are empty
    # strings here, which would erase the operator entirely; they look like
    # they should be the frontend's canonical operator symbols - confirm
    # against the if-else node schema.
    OPERATOR_MAP = {
        ">=": "",
        "<=": "",
        "!=": "",
        "==": "=",
    }

    # Normalizes loosely-typed output type names to supported variable types.
    TYPE_MAPPING = {
        "json": "object",
        "dict": "object",
        "dictionary": "object",
        "float": "number",
        "int": "number",
        "integer": "number",
        "double": "number",
        "str": "string",
        "text": "string",
        "bool": "boolean",
        "list": "array[object]",
        "array": "array[object]",
    }

    # node type -> name of the classmethod that repairs that node type.
    _REPAIR_HANDLERS = {
        "if-else": "_repair_if_else_operators",
        "variable-aggregator": "_repair_variable_aggregator_variables",
        "code": "_repair_code_node_config",
    }

    @classmethod
    def repair(
        cls,
        nodes: list[WorkflowNodeDict],
        llm_callback=None,
    ) -> NodeRepairResult:
        """
        Repair node configurations.

        Args:
            nodes: List of node dictionaries
            llm_callback: Optional callback(node, issue_desc) -> fixed_config_part

        Returns:
            NodeRepairResult with repaired nodes and logs
        """
        # Deep copy so the caller's node list is never mutated.
        nodes = copy.deepcopy(nodes)
        repairs: list[str] = []
        warnings: list[str] = []
        logger.info("[NODE REPAIR] Starting repair process for %s nodes", len(nodes))

        for node in nodes:
            handler_name = cls._REPAIR_HANDLERS.get(node.get("type"))
            if not handler_name:
                continue
            handler = getattr(cls, handler_name)
            # Fix: every handler now accepts **kwargs, so llm_callback can
            # always be forwarded. The previous try/except TypeError probe
            # could swallow a TypeError raised *inside* a handler and call
            # it a second time, duplicating repairs.
            handler(node, repairs, llm_callback=llm_callback)

        if repairs:
            logger.info("[NODE REPAIR] Completed with %s repairs:", len(repairs))
            for i, repair in enumerate(repairs, 1):
                logger.info("[NODE REPAIR] %s. %s", i, repair)
        else:
            logger.info("[NODE REPAIR] Completed - no repairs needed")

        return NodeRepairResult(
            nodes=nodes,
            repairs_made=repairs,
            warnings=warnings,
        )

    @classmethod
    def _repair_if_else_operators(cls, node: WorkflowNodeDict, repairs: list[str], **kwargs):
        """
        Normalize comparison operators in if-else nodes.

        Also ensures 'case_id'/'id' fields exist for cases and conditions
        (a frontend requirement) and coerces numeric condition values to
        strings.
        """
        node_id = node.get("id", "unknown")
        config = node.get("config", {})
        cases = config.get("cases", [])
        for case in cases:
            # The frontend requires a stable case identifier.
            if "case_id" not in case:
                case["case_id"] = str(uuid.uuid4())
                repairs.append(f"Generated missing case_id for case in node '{node_id}'")
            for condition in case.get("conditions", []):
                # Structural fix; not logged to avoid cluttering the repair log.
                if "id" not in condition:
                    condition["id"] = str(uuid.uuid4())
                # LLMs may emit int/float values, but the schema expects
                # str/bool/list (bool excluded: it subclasses int).
                val = condition.get("value")
                if isinstance(val, (int, float)) and not isinstance(val, bool):
                    condition["value"] = str(val)
                    repairs.append(f"Coerced numeric value to string in node '{node_id}'")
                op = condition.get("comparison_operator")
                if op in cls.OPERATOR_MAP:
                    new_op = cls.OPERATOR_MAP[op]
                    condition["comparison_operator"] = new_op
                    repairs.append(f"Normalized operator '{op}' to '{new_op}' in node '{node_id}'")

    @classmethod
    def _repair_variable_aggregator_variables(cls, node: WorkflowNodeDict, repairs: list[str], **kwargs):
        """
        Repair variable-aggregator variables format.

        Converts dict format to list[list[str]] format.
        Expected: [["node_id", "field"], ["node_id2", "field2"]]
        May receive: [{"name": "...", "value_selector": ["node_id", "field"]}, ...]
        """
        node_id = node.get("id", "unknown")
        config = node.get("config", {})
        variables = config.get("variables", [])
        if not variables:
            return
        repaired = False
        repaired_variables = []
        for var in variables:
            if isinstance(var, dict):
                # Prefer an explicit selector under any of its known aliases.
                value_selector = var.get("value_selector") or var.get("selector") or var.get("path")
                if isinstance(value_selector, list) and len(value_selector) > 0:
                    repaired_variables.append(value_selector)
                    repaired = True
                else:
                    # Fall back to parsing "node_id.field" out of the name.
                    name = var.get("name")
                    if isinstance(name, str) and "." in name:
                        parts = name.split(".", 1)
                        if len(parts) == 2:
                            repaired_variables.append([parts[0], parts[1]])
                            repaired = True
                        else:
                            logger.warning(
                                "Variable aggregator node '%s' has invalid variable format: %s",
                                node_id,
                                var,
                            )
                            repaired_variables.append([])  # Empty array as fallback
                    else:
                        # No usable selector or name: drop the variable.
                        logger.warning(
                            "Variable aggregator node '%s' has invalid variable format: %s",
                            node_id,
                            var,
                        )
            elif isinstance(var, list):
                # Already in the expected format.
                repaired_variables.append(var)
            else:
                # Unknown format: drop the variable.
                logger.warning("Variable aggregator node '%s' has unknown variable format: %s", node_id, var)
        if repaired:
            config["variables"] = repaired_variables
            repairs.append(f"Repaired variable-aggregator variables format in node '{node_id}'")

    @classmethod
    def _repair_code_node_config(cls, node: WorkflowNodeDict, repairs: list[str], llm_callback=None, **kwargs):
        """
        Repair code node configuration (outputs and variables).

        1. Outputs: converts list format to dict format AND normalizes types.
        2. Variables: ensures 'value_selector' exists on each variable.
        """
        node_id = node.get("id", "unknown")
        # Fix: setdefault (not get) so the defaults written below survive even
        # when the node had no 'config' key; get() mutated a detached dict.
        config = node.setdefault("config", {})
        if "variables" not in config:
            config["variables"] = []

        # --- Repair variables: frontend crashes without value_selector ---
        variables = config.get("variables")
        if isinstance(variables, list):
            for var in variables:
                if isinstance(var, dict) and "value_selector" not in var:
                    var["value_selector"] = []  # trivial fix; not logged

        # --- Repair outputs ---
        outputs = config.get("outputs")
        if not outputs:
            return

        def normalize_type(t: str) -> str:
            """Map loose type names (e.g. 'int') to canonical ones ('number')."""
            t_lower = str(t).lower()
            return cls.TYPE_MAPPING.get(t_lower, t)

        # 1. Dict format (standard): only normalize invalid type names.
        if isinstance(outputs, dict):
            for var_name, var_config in outputs.items():
                if isinstance(var_config, dict):
                    original_type = var_config.get("type")
                    if original_type:
                        new_type = normalize_type(original_type)
                        if new_type != original_type:
                            var_config["type"] = new_type
                            repairs.append(
                                f"Normalized type '{original_type}' to '{new_type}' "
                                f"for var '{var_name}' in node '{node_id}'"
                            )
            return

        # 2. List format: convert to the dict format the engine expects.
        if isinstance(outputs, list):
            new_outputs = {}
            for item in outputs:
                if isinstance(item, dict):
                    var_name = item.get("variable") or item.get("name")
                    var_type = item.get("type")
                    if var_name and var_type:
                        norm_type = normalize_type(var_type)
                        new_outputs[var_name] = {"type": norm_type}
                        if norm_type != var_type:
                            repairs.append(
                                f"Normalized type '{var_type}' to '{norm_type}' "
                                f"during list conversion in node '{node_id}'"
                            )
            if new_outputs:
                config["outputs"] = new_outputs
                repairs.append(f"Repaired code node outputs format in node '{node_id}'")
            else:
                # Conversion produced nothing usable; try the LLM fallback.
                if llm_callback:
                    try:
                        fixed_outputs = llm_callback(
                            node,
                            "outputs must be a dictionary like {'var_name': {'type': 'string'}}, "
                            "but got a list or valid conversion failed.",
                        )
                        if isinstance(fixed_outputs, dict) and fixed_outputs:
                            config["outputs"] = fixed_outputs
                            repairs.append(f"Repaired code node outputs format using LLM in node '{node_id}'")
                            return
                    except Exception as e:
                        logger.warning("LLM fallback repair failed for node '%s': %s", node_id, e)
                # Last resort: reset so downstream code never sees a list here.
                config["outputs"] = {}
                repairs.append(f"Reset invalid code node outputs to empty dict in node '{node_id}'")

View File

@@ -0,0 +1,101 @@
from dataclasses import dataclass
from core.workflow.generator.types import AvailableModelDict, AvailableToolDict, WorkflowDataDict
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine
from core.workflow.generator.validation.rules import Severity
@dataclass
class ValidationHint:
    """Legacy compatibility class for validation hints."""

    node_id: str
    field: str
    message: str
    severity: str  # 'error' or 'warning'
    # Fix: these default to None, so annotate them as optional instead of
    # the type-lying 'str = None'.
    suggestion: str | None = None
    node_type: str | None = None  # Added for test compatibility

    @property
    def type(self) -> str:
        """Alias for old code that used 'type' instead of 'severity'."""
        return self.severity

    @property
    def element_id(self) -> str:
        """Alias for old code that used 'element_id' instead of 'node_id'."""
        return self.node_id


# Alias kept for backward compatibility with older imports.
FriendlyHint = ValidationHint
class WorkflowValidator:
    """
    Validates the generated workflow configuration (nodes and edges).

    Thin wrapper around the newer ValidationEngine, kept so callers that
    expect a (bool, list[ValidationHint]) result keep working.
    """

    @classmethod
    def validate(
        cls,
        workflow_data: WorkflowDataDict,
        available_tools: list[AvailableToolDict],
        available_models: list[AvailableModelDict] | None = None,
    ) -> tuple[bool, list[ValidationHint]]:
        """
        Validate workflow data and return validity status and hints.

        Args:
            workflow_data: Dict containing 'nodes' and 'edges'
            available_tools: List of available tool configurations
            available_models: List of available models (added for Vibe compat)

        Returns:
            Tuple of (is_valid, hints); hints mirror every engine
            error/warning in the legacy ValidationHint shape.
        """
        nodes = workflow_data.get("nodes", [])
        edges = workflow_data.get("edges", [])

        context = ValidationContext(
            nodes=nodes,
            edges=edges,
            available_models=available_models or [],
            available_tools=available_tools or [],
        )

        engine = ValidationEngine()
        result = engine.validate(context)

        # Convert engine errors into legacy hints.
        # (Fix: dropped the error_count/warning_count tallies - they were
        # computed but never used.)
        hints: list[ValidationHint] = []
        for error in result.all_errors:
            severity = "error" if error.severity == Severity.ERROR else "warning"
            # The engine exposes no structured field name; fall back to
            # whatever the error's details dict carries.
            field_name = error.details.get("field", "unknown")
            hints.append(
                ValidationHint(
                    node_id=error.node_id,
                    field=field_name,
                    message=error.message,
                    severity=severity,
                    suggestion=error.fix_hint,
                    node_type=error.node_type,
                )
            )
        return result.is_valid, hints

View File

@@ -0,0 +1,42 @@
"""
Validation Rule Engine for Vibe Workflow Generation.
This module provides a declarative, schema-based validation system for
generated workflow nodes. It classifies errors into fixable (LLM can auto-fix)
and user-required (needs manual intervention) categories.
Usage:
from core.workflow.generator.validation import ValidationEngine, ValidationContext
context = ValidationContext(
available_models=[...],
available_tools=[...],
nodes=[...],
edges=[...],
)
engine = ValidationEngine()
result = engine.validate(context)
# Access classified errors
fixable_errors = result.fixable_errors
user_required_errors = result.user_required_errors
"""
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.engine import ValidationEngine, ValidationResult
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
ValidationRule,
)
__all__ = [
"RuleCategory",
"Severity",
"ValidationContext",
"ValidationEngine",
"ValidationError",
"ValidationResult",
"ValidationRule",
]

View File

@@ -0,0 +1,115 @@
"""
Validation Context for the Rule Engine.
The ValidationContext holds all the data needed for validation:
- Generated nodes and edges
- Available models, tools, and datasets
- Node output schemas for variable reference validation
"""
from dataclasses import dataclass, field
from core.workflow.generator.types import (
AvailableModelDict,
AvailableToolDict,
WorkflowEdgeDict,
WorkflowNodeDict,
)
@dataclass
class ValidationContext:
    """
    Context object containing all data needed for validation.

    Every validation rule receives one of these, giving it access to the
    workflow under validation (nodes + edges) and to the external resources
    available to the workspace (models, tools). Lookup tables are built
    lazily on first access and memoized on the instance.
    """

    # Generated workflow data
    nodes: list[WorkflowNodeDict] = field(default_factory=list)
    edges: list[WorkflowEdgeDict] = field(default_factory=list)
    # Available external resources
    available_models: list[AvailableModelDict] = field(default_factory=list)
    available_tools: list[AvailableToolDict] = field(default_factory=list)
    # Cached lookups (populated lazily)
    _node_map: dict[str, WorkflowNodeDict] | None = field(default=None, repr=False)
    _model_set: set[tuple[str, str]] | None = field(default=None, repr=False)
    _tool_set: set[str] | None = field(default=None, repr=False)
    _configured_tool_set: set[str] | None = field(default=None, repr=False)

    @staticmethod
    def _tool_keys(tool: AvailableToolDict) -> set[str]:
        """Derive lookup keys for one tool entry: bare key plus provider-qualified key."""
        keys: set[str] = set()
        provider = tool.get("provider_id") or tool.get("provider", "")
        tool_key = tool.get("tool_key") or tool.get("tool_name", "")
        if provider and tool_key:
            keys.add(f"{provider}/{tool_key}")
        if tool_key:
            keys.add(tool_key)
        return keys

    @property
    def node_map(self) -> dict[str, WorkflowNodeDict]:
        """Get a map of node_id -> node for quick lookup."""
        if self._node_map is None:
            mapping: dict[str, WorkflowNodeDict] = {}
            for entry in self.nodes:
                mapping[entry.get("id", "")] = entry
            self._node_map = mapping
        return self._node_map

    @property
    def model_set(self) -> set[tuple[str, str]]:
        """Get a set of (provider, model_name) tuples for quick lookup."""
        if self._model_set is None:
            self._model_set = {
                (entry.get("provider", ""), entry.get("model", ""))
                for entry in self.available_models
            }
        return self._model_set

    @property
    def tool_set(self) -> set[str]:
        """Get a set of all tool keys (both configured and unconfigured)."""
        if self._tool_set is None:
            collected: set[str] = set()
            for tool in self.available_tools:
                collected |= self._tool_keys(tool)
            self._tool_set = collected
        return self._tool_set

    @property
    def configured_tool_set(self) -> set[str]:
        """Get a set of configured (authorized) tool keys."""
        if self._configured_tool_set is None:
            collected: set[str] = set()
            for tool in self.available_tools:
                # Only tools the team has authorized count as configured.
                if tool.get("is_team_authorization", False):
                    collected |= self._tool_keys(tool)
            self._configured_tool_set = collected
        return self._configured_tool_set

    def has_model(self, provider: str, model_name: str) -> bool:
        """Check if a model is available."""
        return (provider, model_name) in self.model_set

    def has_tool(self, tool_key: str) -> bool:
        """Check if a tool exists (configured or not)."""
        return tool_key in self.tool_set

    def is_tool_configured(self, tool_key: str) -> bool:
        """Check if a tool is configured and ready to use."""
        return tool_key in self.configured_tool_set

    def get_node(self, node_id: str) -> WorkflowNodeDict | None:
        """Get a node by its ID."""
        return self.node_map.get(node_id)

    def get_node_ids(self) -> set[str]:
        """Get all node IDs in the workflow."""
        return set(self.node_map.keys())

    def get_upstream_nodes(self, node_id: str) -> list[str]:
        """Get IDs of nodes that connect to this node (upstream)."""
        return [e.get("source", "") for e in self.edges if e.get("target") == node_id]

    def get_downstream_nodes(self, node_id: str) -> list[str]:
        """Get IDs of nodes that this node connects to (downstream)."""
        return [e.get("target", "") for e in self.edges if e.get("source") == node_id]

View File

@@ -0,0 +1,260 @@
"""
Validation Engine - Core validation logic.
The ValidationEngine orchestrates rule execution and aggregates results.
It provides a clean interface for validating workflow nodes.
"""
import logging
from dataclasses import dataclass, field
from typing import Any
from core.workflow.generator.types import (
AvailableModelDict,
AvailableToolDict,
WorkflowEdgeDict,
WorkflowNodeDict,
)
from core.workflow.generator.validation.context import ValidationContext
from core.workflow.generator.validation.rules import (
RuleCategory,
Severity,
ValidationError,
get_registry,
)
logger = logging.getLogger(__name__)
@dataclass
class ValidationResult:
    """
    Result of validation containing all errors classified by fixability.

    Attributes:
        all_errors: All validation errors found
        fixable_errors: Errors that LLM can automatically fix
        user_required_errors: Errors that require user intervention
        warnings: Non-blocking warnings
        stats: Validation statistics
    """

    all_errors: list[ValidationError] = field(default_factory=list)
    fixable_errors: list[ValidationError] = field(default_factory=list)
    user_required_errors: list[ValidationError] = field(default_factory=list)
    warnings: list[ValidationError] = field(default_factory=list)
    stats: dict[str, int] = field(default_factory=dict)

    @property
    def has_errors(self) -> bool:
        """Check if there are any errors (excluding warnings)."""
        return bool(self.fixable_errors) or bool(self.user_required_errors)

    @property
    def has_fixable_errors(self) -> bool:
        """Check if there are fixable errors."""
        return bool(self.fixable_errors)

    @property
    def is_valid(self) -> bool:
        """Check if validation passed (no errors, warnings are OK)."""
        return not self.has_errors

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for API response."""
        # NOTE(review): the "all_warnings" key actually carries the messages of
        # *all* errors, not just warnings — confirm consumers rely on this
        # before renaming the key.
        return {
            "fixable": [err.to_dict() for err in self.fixable_errors],
            "user_required": [err.to_dict() for err in self.user_required_errors],
            "warnings": [err.to_dict() for err in self.warnings],
            "all_warnings": [err.message for err in self.all_errors],
            "stats": self.stats,
        }

    def get_error_messages(self) -> list[str]:
        """Get all error messages as strings."""
        return [err.message for err in self.all_errors]

    def get_fixable_by_node(self) -> dict[str, list[ValidationError]]:
        """Group fixable errors by node ID."""
        grouped: dict[str, list[ValidationError]] = {}
        for err in self.fixable_errors:
            grouped.setdefault(err.node_id, []).append(err)
        return grouped
class ValidationEngine:
    """
    The main validation engine.

    Runs every applicable rule over every node in a ValidationContext and
    aggregates the findings into a ValidationResult, classified by severity
    (error vs warning) and fixability (LLM-fixable vs user-required).

    Usage:
        engine = ValidationEngine()
        context = ValidationContext(nodes=[...], available_models=[...])
        result = engine.validate(context)
    """

    def __init__(self):
        # Shared module-level registry; rules self-register at import time.
        self._registry = get_registry()

    def validate(self, context: ValidationContext) -> ValidationResult:
        """
        Validate all nodes in the context.

        Args:
            context: ValidationContext with nodes, edges, and available resources

        Returns:
            ValidationResult with classified errors
        """
        result = ValidationResult()
        stats = {
            "total_nodes": len(context.nodes),
            "total_rules_checked": 0,
            "total_errors": 0,
            "fixable_count": 0,
            "user_required_count": 0,
            "warning_count": 0,
        }
        # Validate each node
        for node in context.nodes:
            node_type = node.get("type", "unknown")
            node_id = node.get("id", "unknown")
            # Get applicable rules for this node type
            rules = self._registry.get_rules_for_node(node_type)
            for rule in rules:
                stats["total_rules_checked"] += 1
                try:
                    errors = rule.check(node, context)
                    for error in errors:
                        result.all_errors.append(error)
                        stats["total_errors"] += 1
                        # Classify by severity and fixability:
                        # warnings never block; fixable errors go back to the
                        # LLM; the rest need user intervention.
                        if error.severity == Severity.WARNING:
                            result.warnings.append(error)
                            stats["warning_count"] += 1
                        elif error.is_fixable:
                            result.fixable_errors.append(error)
                            stats["fixable_count"] += 1
                        else:
                            result.user_required_errors.append(error)
                            stats["user_required_count"] += 1
                except Exception:
                    logger.exception(
                        "Rule '%s' failed for node '%s'",
                        rule.id,
                        node_id,
                    )
                    # Don't let a rule failure break the entire validation
                    continue
        # Validate edges separately (edge checks are not tied to a node type)
        edge_errors = self._validate_edges(context)
        for error in edge_errors:
            result.all_errors.append(error)
            stats["total_errors"] += 1
            # NOTE(review): no WARNING branch here — _validate_edges only
            # emits Severity.ERROR today.
            if error.is_fixable:
                result.fixable_errors.append(error)
                stats["fixable_count"] += 1
            else:
                result.user_required_errors.append(error)
                stats["user_required_count"] += 1
        result.stats = stats
        return result

    def _validate_edges(self, context: ValidationContext) -> list[ValidationError]:
        """Validate edge connections: every non-empty source/target must name an existing node."""
        errors: list[ValidationError] = []
        valid_node_ids = context.get_node_ids()
        for edge in context.edges:
            source = edge.get("source", "")
            target = edge.get("target", "")
            if source and source not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.source.invalid",
                        node_id=source,
                        node_type="edge",
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge source '{source}' does not exist",
                        fix_hint="Update edge to reference existing node",
                    )
                )
            if target and target not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.target.invalid",
                        node_id=target,
                        node_type="edge",
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge target '{target}' does not exist",
                        fix_hint="Update edge to reference existing node",
                    )
                )
        return errors

    def validate_single_node(
        self,
        node: WorkflowNodeDict,
        context: ValidationContext,
    ) -> list[ValidationError]:
        """
        Validate a single node.

        Useful for incremental validation when a node is added/modified.
        Unlike validate(), findings are returned unclassified.
        """
        node_type = node.get("type", "unknown")
        rules = self._registry.get_rules_for_node(node_type)
        errors: list[ValidationError] = []
        for rule in rules:
            try:
                errors.extend(rule.check(node, context))
            except Exception:
                # A broken rule is logged and skipped, matching validate().
                logger.exception("Rule '%s' failed", rule.id)
        return errors
def validate_nodes(
    nodes: list[WorkflowNodeDict],
    edges: list[WorkflowEdgeDict] | None = None,
    available_models: list[AvailableModelDict] | None = None,
    available_tools: list[AvailableToolDict] | None = None,
) -> ValidationResult:
    """
    Convenience function to validate nodes without creating engine/context manually.

    Args:
        nodes: List of workflow nodes to validate
        edges: Optional list of edges
        available_models: Optional list of available models
        available_tools: Optional list of available tools

    Returns:
        ValidationResult with classified errors
    """
    # Build a one-shot context, defaulting every optional input to empty.
    ctx = ValidationContext(
        nodes=nodes,
        edges=[] if edges is None else edges,
        available_models=[] if available_models is None else available_models,
        available_tools=[] if available_tools is None else available_tools,
    )
    return ValidationEngine().validate(ctx)

View File

@@ -0,0 +1,947 @@
"""
Validation Rules Definition and Registry.
This module defines:
- ValidationRule: The rule structure
- RuleCategory: Categories of validation rules
- Severity: Error severity levels
- ValidationError: Error output structure
- All built-in validation rules
"""
import re
from collections.abc import Callable
from dataclasses import dataclass, field
from enum import Enum
from typing import TYPE_CHECKING, Any
from core.workflow.generator.types import WorkflowNodeDict
if TYPE_CHECKING:
from core.workflow.generator.validation.context import ValidationContext
class RuleCategory(Enum):
    """Categories of validation rules.

    Used by the registry to group rules and serialized (via .value) into
    error payloads.
    """

    STRUCTURE = "structure"  # Field existence, types, formats
    SEMANTIC = "semantic"  # Variable references, edge connections
    REFERENCE = "reference"  # External resources (models, tools, datasets)
class Severity(Enum):
    """Severity levels for validation errors.

    WARNING findings are collected but treated as non-blocking by the
    validation engine; ERROR findings must be resolved.
    """

    ERROR = "error"  # Must be fixed
    WARNING = "warning"  # Should be fixed but not blocking
@dataclass
class ValidationError:
    """
    Represents a validation error found during rule execution.

    Attributes:
        rule_id: The ID of the rule that generated this error
        node_id: The ID of the node with the error
        node_type: The type of the node
        category: The rule category
        severity: Error severity
        is_fixable: Whether LLM can auto-fix this error
        message: Human-readable error message
        fix_hint: Hint for LLM to fix the error
        details: Additional error details
    """

    rule_id: str
    node_id: str
    node_type: str
    category: RuleCategory
    severity: Severity
    is_fixable: bool
    message: str
    fix_hint: str = ""
    # Free-form extra context (offending values, field paths, indices).
    details: dict[str, Any] = field(default_factory=dict)

    def to_dict(self) -> dict[str, Any]:
        """Convert to dictionary for API response."""
        # Enum members are serialized via .value so the payload is JSON-safe.
        return {
            "rule_id": self.rule_id,
            "node_id": self.node_id,
            "node_type": self.node_type,
            "category": self.category.value,
            "severity": self.severity.value,
            "is_fixable": self.is_fixable,
            "message": self.message,
            "fix_hint": self.fix_hint,
            "details": self.details,
        }
# Type alias for rule check functions.
# A check receives the node under validation plus the shared ValidationContext
# and returns zero or more errors (an empty list means the node passed).
RuleCheckFn = Callable[
    [WorkflowNodeDict, "ValidationContext"],
    list[ValidationError],
]


@dataclass
class ValidationRule:
    """
    A validation rule definition.

    Attributes:
        id: Unique rule identifier (e.g., "llm.model.required")
        node_types: List of node types this rule applies to, or ["*"] for all
        category: The rule category
        severity: Default severity for errors from this rule
        is_fixable: Whether errors from this rule can be auto-fixed by LLM
        check: The validation function
        description: Human-readable description of what this rule checks
        fix_hint: Default hint for fixing errors from this rule
    """

    id: str
    node_types: list[str]
    category: RuleCategory
    severity: Severity
    is_fixable: bool
    check: RuleCheckFn
    description: str = ""
    fix_hint: str = ""

    def applies_to(self, node_type: str) -> bool:
        """Check if this rule applies to a given node type."""
        # "*" is the wildcard: such a rule runs for every node type.
        return "*" in self.node_types or node_type in self.node_types
# =============================================================================
# Rule Registry
# =============================================================================
class RuleRegistry:
    """
    Registry for validation rules.

    Holds rules in registration order; lookups filter by node type or
    category without mutating the underlying list.
    """

    def __init__(self):
        self._rules: list[ValidationRule] = []

    def register(self, rule: ValidationRule) -> None:
        """Register a validation rule."""
        self._rules.append(rule)

    def get_rules_for_node(self, node_type: str) -> list[ValidationRule]:
        """Get all rules that apply to a given node type."""
        return [rule for rule in self._rules if rule.applies_to(node_type)]

    def get_rules_by_category(self, category: RuleCategory) -> list[ValidationRule]:
        """Get all rules in a given category."""
        return [rule for rule in self._rules if rule.category == category]

    def get_all_rules(self) -> list[ValidationRule]:
        """Get all registered rules (as a defensive shallow copy)."""
        return self._rules.copy()
# Global rule registry instance.
# Module-level singleton; all rules defined below self-register into it at
# import time.
_registry = RuleRegistry()


def register_rule(rule: ValidationRule) -> ValidationRule:
    """Decorator/function to register a rule with the global registry.

    Returns the rule unchanged so it can also be used as a decorator.
    """
    _registry.register(rule)
    return rule


def get_registry() -> RuleRegistry:
    """Get the global rule registry."""
    return _registry
# =============================================================================
# Helper Functions for Rule Implementations
# =============================================================================

# Explicit placeholder value defined in prompt contract
# See: api/core/workflow/generator/prompts/vibe_prompts.py
PLACEHOLDER_VALUE = "__PLACEHOLDER__"

# Variable reference pattern: {{#node_id.field#}}
VARIABLE_REF_PATTERN = re.compile(r"\{\{#([^.#]+)\.([^#]+)#\}\}")


def is_placeholder(value: Any) -> bool:
    """Check if a value appears to be a placeholder."""
    # Equality with PLACEHOLDER_VALUE implies substring containment, so a
    # single containment test covers both cases; non-strings never match.
    return isinstance(value, str) and PLACEHOLDER_VALUE in value


def extract_variable_refs(text: str) -> list[tuple[str, str]]:
    """
    Extract variable references from text.

    Returns list of (node_id, field_name) tuples.
    """
    return [(match.group(1), match.group(2)) for match in VARIABLE_REF_PATTERN.finditer(text)]
def check_required_field(
    config: dict[str, Any],
    field_name: str,
    node_id: str,
    node_type: str,
    rule_id: str,
    fix_hint: str = "",
) -> ValidationError | None:
    """Helper to check if a required field exists and is non-empty.

    A field counts as missing when it is absent/None, an empty string, or an
    empty list. Returns a fixable STRUCTURE error, or None when present.
    """
    value = config.get(field_name)
    is_missing = value is None or value == "" or (isinstance(value, list) and not value)
    if not is_missing:
        return None
    return ValidationError(
        rule_id=rule_id,
        node_id=node_id,
        node_type=node_type,
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        message=f"Node '{node_id}': missing required field '{field_name}'",
        fix_hint=fix_hint or f"Add '{field_name}' to the node config",
    )
# =============================================================================
# Structure Rules - Field existence, types, formats
# =============================================================================


def _check_llm_prompt_template(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that LLM node has prompt_template."""
    err = check_required_field(
        node.get("config", {}),
        "prompt_template",
        node.get("id", "unknown"),
        "llm",
        "llm.prompt_template.required",
        "Add prompt_template with system and user messages",
    )
    return [err] if err else []
def _check_http_request_url(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that http-request node has url and method.

    Emits up to two errors: one for a missing or placeholder url, and one for
    a missing method. A present-but-placeholder url gets its own rule_id so
    the fixer can substitute a concrete value or variable reference.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})
    # Check url
    url = config.get("url", "")
    if not url:
        errors.append(
            ValidationError(
                rule_id="http.url.required",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': http-request missing required 'url'",
                fix_hint="Add url - use {{#start.url#}} or a concrete URL",
            )
        )
    elif is_placeholder(url):
        errors.append(
            ValidationError(
                rule_id="http.url.placeholder",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': url contains placeholder value",
                fix_hint="Replace placeholder with actual URL or variable reference",
            )
        )
    # Check method
    method = config.get("method", "")
    if not method:
        errors.append(
            ValidationError(
                rule_id="http.method.required",
                node_id=node_id,
                node_type="http-request",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': http-request missing 'method'",
                fix_hint="Add method: GET, POST, PUT, DELETE, or PATCH",
            )
        )
    return errors
def _check_code_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that code node has code and language."""
    node_id = node.get("id", "unknown")
    config = node.get("config", {})
    # (field, rule_id, fix hint) — checked in order so error order is stable.
    required_specs = (
        ("code", "code.code.required", "Add code with a main() function that returns a dict"),
        ("language", "code.language.required", "Add language: python3 or javascript"),
    )
    errors: list[ValidationError] = []
    for field_name, rule_id, hint in required_specs:
        err = check_required_field(config, field_name, node_id, "code", rule_id, hint)
        if err:
            errors.append(err)
    return errors
def _check_question_classifier(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that question-classifier has classes."""
    err = check_required_field(
        node.get("config", {}),
        "classes",
        node.get("id", "unknown"),
        "question-classifier",
        "classifier.classes.required",
        "Add classes array with id and name for each classification",
    )
    return [err] if err else []
def _check_parameter_extractor(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that parameter-extractor has parameters and instruction.

    Missing 'parameters' is an ERROR; when parameters exist, each entry is
    additionally checked for the boolean 'required' field. A missing
    'instruction' is only a WARNING (recommended, not mandatory).
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})
    err = check_required_field(
        config,
        "parameters",
        node_id,
        "parameter-extractor",
        "extractor.parameters.required",
        "Add parameters array with name, type, description fields",
    )
    if err:
        errors.append(err)
    else:
        # Check individual parameters for required fields
        parameters = config.get("parameters", [])
        if isinstance(parameters, list):
            for i, param in enumerate(parameters):
                # Non-dict entries are silently skipped.
                if isinstance(param, dict):
                    # Check for 'required' field (boolean)
                    if "required" not in param:
                        errors.append(
                            ValidationError(
                                rule_id="extractor.param.required_field.missing",
                                node_id=node_id,
                                node_type="parameter-extractor",
                                category=RuleCategory.STRUCTURE,
                                severity=Severity.ERROR,
                                is_fixable=True,
                                message=f"Node '{node_id}': parameter[{i}] missing 'required' field",
                                fix_hint=f"Add 'required': True to parameter '{param.get('name', 'unknown')}'",
                                details={"param_index": i, "param_name": param.get("name")},
                            )
                        )
    # instruction is recommended but not strictly required
    if not config.get("instruction"):
        errors.append(
            ValidationError(
                rule_id="extractor.instruction.recommended",
                node_id=node_id,
                node_type="parameter-extractor",
                category=RuleCategory.STRUCTURE,
                severity=Severity.WARNING,
                is_fixable=True,
                message=f"Node '{node_id}': parameter-extractor should have 'instruction'",
                fix_hint="Add instruction describing what to extract",
            )
        )
    return errors
def _check_knowledge_retrieval(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that knowledge-retrieval has dataset_ids.

    Both failure modes are marked is_fixable=False: selecting a knowledge
    base is a user decision the LLM cannot make.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    config = node.get("config", {})
    dataset_ids = config.get("dataset_ids", [])
    if not dataset_ids:
        errors.append(
            ValidationError(
                rule_id="knowledge.dataset.required",
                node_id=node_id,
                node_type="knowledge-retrieval",
                category=RuleCategory.STRUCTURE,
                severity=Severity.ERROR,
                is_fixable=False,  # User must select knowledge base
                message=f"Node '{node_id}': knowledge-retrieval missing 'dataset_ids'",
                fix_hint="User must select knowledge bases in the UI",
            )
        )
    else:
        # Check for placeholder values
        # Only the first placeholder is reported (break below) — one error is
        # enough to surface the problem to the user.
        for ds_id in dataset_ids:
            if is_placeholder(ds_id):
                errors.append(
                    ValidationError(
                        rule_id="knowledge.dataset.placeholder",
                        node_id=node_id,
                        node_type="knowledge-retrieval",
                        category=RuleCategory.STRUCTURE,
                        severity=Severity.ERROR,
                        is_fixable=False,
                        message=f"Node '{node_id}': dataset_ids contains placeholder",
                        fix_hint="User must replace placeholder with actual knowledge base ID",
                        details={"placeholder_value": ds_id},
                    )
                )
                break
    return errors
def _check_end_node(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Warn (non-blocking) when an end node declares no output variables."""
    config = node.get("config", {})
    if config.get("outputs", []):
        return []
    warning = ValidationError(
        rule_id="end.outputs.recommended",
        node_id=node.get("id", "unknown"),
        node_type="end",
        category=RuleCategory.STRUCTURE,
        severity=Severity.WARNING,
        is_fixable=True,
        message="End node should define output variables",
        fix_hint="Add outputs array with variable and value_selector",
    )
    return [warning]
# =============================================================================
# Semantic Rules - Variable references, edge connections
# =============================================================================


def _check_variable_references(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that variable references point to valid nodes.

    Scans the text fields where {{#node.field#}} references appear
    (prompt_template messages, instruction, url) and reports any reference
    whose node id is not in the workflow. Only the node id is validated —
    the referenced field name is not checked here.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    config = node.get("config", {})
    # Get all valid node IDs (including 'start' which is always valid)
    valid_node_ids = ctx.get_node_ids()
    valid_node_ids.add("start")
    valid_node_ids.add("sys")  # System variables
    # Closure appending into `errors`; field_path only feeds error details.
    def check_text_for_refs(text: str, field_path: str) -> None:
        if not isinstance(text, str):
            return
        refs = extract_variable_refs(text)
        for ref_node_id, ref_field in refs:
            if ref_node_id not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="variable.ref.invalid_node",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': references non-existent node '{ref_node_id}'",
                        fix_hint=f"Change {{{{#{ref_node_id}.{ref_field}#}}}} to reference a valid node",
                        details={"field_path": field_path, "invalid_ref": ref_node_id},
                    )
                )
    # Check prompt_template for LLM nodes
    prompt_template = config.get("prompt_template", [])
    if isinstance(prompt_template, list):
        for i, msg in enumerate(prompt_template):
            if isinstance(msg, dict):
                text = msg.get("text", "")
                check_text_for_refs(text, f"prompt_template[{i}].text")
    # Check instruction field
    instruction = config.get("instruction", "")
    check_text_for_refs(instruction, "instruction")
    # Check url for http-request
    url = config.get("url", "")
    check_text_for_refs(url, "url")
    return errors
# NOTE: _check_node_has_outgoing_edge removed - handled by GraphValidator
# NOTE: _check_node_has_incoming_edge removed - handled by GraphValidator
# NOTE: _check_question_classifier_branches removed - handled by EdgeRepair
# NOTE: _check_if_else_branches removed - handled by EdgeRepair
def _check_if_else_operators(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that if-else comparison operators are valid.

    Walks every condition in every case and flags any comparison_operator
    not in the canonical operator set. Empty/absent operators are skipped.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    if node_type != "if-else":
        return errors
    # Canonical operator names accepted by the if-else node.
    # BUGFIX: the set previously contained empty strings ("") where the
    # Unicode comparison operators belong, so the valid operators "≠", "≥"
    # and "≤" (referenced by this rule's own fix_hint) were reported invalid.
    valid_operators = {
        "contains",
        "not contains",
        "start with",
        "end with",
        "is",
        "is not",
        "empty",
        "not empty",
        "in",
        "not in",
        "all of",
        "=",
        "≠",
        ">",
        "<",
        "≥",
        "≤",
        "null",
        "not null",
        "exists",
        "not exists",
    }
    config = node.get("config", {})
    cases = config.get("cases", [])
    for case in cases:
        conditions = case.get("conditions", [])
        for condition in conditions:
            op = condition.get("comparison_operator")
            if op and op not in valid_operators:
                errors.append(
                    ValidationError(
                        rule_id="ifelse.operator.invalid",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Invalid operator '{op}' in if-else node",
                        fix_hint=f"Use one of: {', '.join(sorted(valid_operators))}",
                        details={"invalid_operator": op, "field": "config.cases.conditions.comparison_operator"},
                    )
                )
    return errors
def _check_edge_targets_exist(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that edge targets reference existing nodes.

    Runs per-node: only edges whose source is this node are inspected, so
    each dangling target is attributed to the node it originates from.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    valid_node_ids = ctx.get_node_ids()
    # Check all outgoing edges from this node
    for edge in ctx.edges:
        if edge.get("source") == node_id:
            target = edge.get("target")
            if target and target not in valid_node_ids:
                errors.append(
                    ValidationError(
                        rule_id="edge.target.invalid",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.SEMANTIC,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Edge from '{node_id}' targets non-existent node '{target}'",
                        fix_hint=f"Change edge target from '{target}' to an existing node",
                        details={"invalid_target": target, "field": "edges"},
                    )
                )
    return errors
# =============================================================================
# Reference Rules - External resources (models, tools, datasets)
# =============================================================================

# Node types that require model configuration
MODEL_REQUIRED_NODE_TYPES = {"llm", "question-classifier", "parameter-extractor"}


def _check_model_config(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that model configuration is valid.

    Three failure modes, each with its own rule_id:
    - model missing entirely (fixable when models exist; user-required when
      the workspace has no models at all)
    - model config contains placeholder provider/name
    - model named but not present in the available model list
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    config = node.get("config", {})
    if node_type not in MODEL_REQUIRED_NODE_TYPES:
        return errors
    model = config.get("model")
    # Check if model config exists
    if not model:
        if ctx.available_models:
            errors.append(
                ValidationError(
                    rule_id="model.required",
                    node_id=node_id,
                    node_type=node_type,
                    category=RuleCategory.REFERENCE,
                    severity=Severity.ERROR,
                    is_fixable=True,
                    message=f"Node '{node_id}' ({node_type}): missing required 'model' configuration",
                    fix_hint="Add model config using one of the available models",
                )
            )
        else:
            errors.append(
                ValidationError(
                    rule_id="model.no_available",
                    node_id=node_id,
                    node_type=node_type,
                    category=RuleCategory.REFERENCE,
                    severity=Severity.ERROR,
                    is_fixable=False,
                    message=f"Node '{node_id}' ({node_type}): needs model but no models available",
                    fix_hint="User must configure a model provider first",
                )
            )
        return errors
    # Check if model config is valid
    # Non-dict model values fall through with no error reported.
    if isinstance(model, dict):
        provider = model.get("provider", "")
        name = model.get("name", "")
        # Check for placeholder values
        if is_placeholder(provider) or is_placeholder(name):
            if ctx.available_models:
                errors.append(
                    ValidationError(
                        rule_id="model.placeholder",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.REFERENCE,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': model config contains placeholder",
                        fix_hint="Replace placeholder with actual model from available_models",
                    )
                )
            return errors
        # Check if model exists in available_models
        # Skipped when available_models is empty (nothing to compare against).
        if ctx.available_models and provider and name:
            if not ctx.has_model(provider, name):
                errors.append(
                    ValidationError(
                        rule_id="model.not_found",
                        node_id=node_id,
                        node_type=node_type,
                        category=RuleCategory.REFERENCE,
                        severity=Severity.ERROR,
                        is_fixable=True,
                        message=f"Node '{node_id}': model '{provider}/{name}' not in available models",
                        fix_hint="Replace with a model from available_models",
                        details={"provider": provider, "model": name},
                    )
                )
    return errors
def _check_tool_reference(node: WorkflowNodeDict, ctx: "ValidationContext") -> list[ValidationError]:
    """Check that tool references are valid and configured.

    Resolution order for the tool reference: explicit tool_key, then bare
    tool_name, then a provider-qualified "provider_id/tool_name" composite.
    Emits 'tool.key.required' when nothing resolves, 'tool.not_found' when
    the reference is unknown, and a 'tool.not_configured' warning when the
    tool exists but is not authorized.
    """
    errors: list[ValidationError] = []
    node_id = node.get("id", "unknown")
    node_type = node.get("type", "unknown")
    if node_type != "tool":
        return errors
    config = node.get("config", {})
    # BUGFIX: the composite fallback previously concatenated
    # provider_id + "/" + tool_name unconditionally, producing a truthy "/"
    # (or "provider/") even when the parts were missing — which masked the
    # 'tool.key.required' error for unconfigured tool nodes. Only build the
    # composite when both parts are actually present.
    tool_ref = config.get("tool_key") or config.get("tool_name") or ""
    if not tool_ref:
        provider_id = config.get("provider_id", "")
        tool_name = config.get("tool_name", "")
        if provider_id and tool_name:
            tool_ref = f"{provider_id}/{tool_name}"
    if not tool_ref:
        errors.append(
            ValidationError(
                rule_id="tool.key.required",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.ERROR,
                is_fixable=True,
                message=f"Node '{node_id}': tool node missing tool_key",
                fix_hint="Add tool_key from available_tools",
            )
        )
        return errors
    # Check if tool exists
    if not ctx.has_tool(tool_ref):
        errors.append(
            ValidationError(
                rule_id="tool.not_found",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.ERROR,
                is_fixable=True,  # Can be replaced with http-request fallback
                message=f"Node '{node_id}': tool '{tool_ref}' not found",
                fix_hint="Use http-request or code node as fallback",
                details={"tool_ref": tool_ref},
            )
        )
    elif not ctx.is_tool_configured(tool_ref):
        errors.append(
            ValidationError(
                rule_id="tool.not_configured",
                node_id=node_id,
                node_type=node_type,
                category=RuleCategory.REFERENCE,
                severity=Severity.WARNING,
                is_fixable=False,  # User needs to configure
                message=f"Node '{node_id}': tool '{tool_ref}' requires configuration",
                fix_hint="Configure the tool in Tools settings",
                details={"tool_ref": tool_ref},
            )
        )
    return errors
# =============================================================================
# Register All Rules
# =============================================================================
# Rules self-register into the module-level registry at import time;
# registration order is the order rules run for a matching node.

# Structure Rules
register_rule(
    ValidationRule(
        id="llm.prompt_template.required",
        node_types=["llm"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_llm_prompt_template,
        description="LLM node must have prompt_template",
        fix_hint="Add prompt_template with system and user messages",
    )
)
register_rule(
    ValidationRule(
        id="http.config.required",
        node_types=["http-request"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_http_request_url,
        description="HTTP request node must have url and method",
        fix_hint="Add url and method to config",
    )
)
register_rule(
    ValidationRule(
        id="code.config.required",
        node_types=["code"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_code_node,
        description="Code node must have code and language",
        fix_hint="Add code with main() function and language",
    )
)
register_rule(
    ValidationRule(
        id="classifier.classes.required",
        node_types=["question-classifier"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_question_classifier,
        description="Question classifier must have classes",
        fix_hint="Add classes array with classification options",
    )
)
register_rule(
    ValidationRule(
        id="extractor.config.required",
        node_types=["parameter-extractor"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_parameter_extractor,
        description="Parameter extractor must have parameters",
        fix_hint="Add parameters array",
    )
)
register_rule(
    ValidationRule(
        id="knowledge.config.required",
        node_types=["knowledge-retrieval"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.ERROR,
        is_fixable=False,  # Selecting a knowledge base is a user decision.
        check=_check_knowledge_retrieval,
        description="Knowledge retrieval must have dataset_ids",
        fix_hint="User must select knowledge base",
    )
)
register_rule(
    ValidationRule(
        id="end.outputs.check",
        node_types=["end"],
        category=RuleCategory.STRUCTURE,
        severity=Severity.WARNING,
        is_fixable=True,
        check=_check_end_node,
        description="End node should have outputs",
        fix_hint="Add outputs array",
    )
)

# Semantic Rules
register_rule(
    ValidationRule(
        id="variable.references.valid",
        node_types=["*"],  # wildcard: runs for every node type
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_variable_references,
        description="Variable references must point to valid nodes",
        fix_hint="Fix variable reference to use valid node ID",
    )
)

# Edge Validation Rules
# NOTE: Edge connectivity and branch completeness are now handled by:
# - GraphValidator (BFS-based reachability analysis)
# - EdgeRepair (automatic branch edge repair)
register_rule(
    ValidationRule(
        id="edge.targets.valid",
        node_types=["*"],
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_edge_targets_exist,
        description="Edge targets must reference existing nodes",
        fix_hint="Change edge target to an existing node ID",
    )
)

# Reference Rules
register_rule(
    ValidationRule(
        id="model.config.valid",
        node_types=["llm", "question-classifier", "parameter-extractor"],
        category=RuleCategory.REFERENCE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_model_config,
        description="Model configuration must be valid",
        fix_hint="Add valid model from available_models",
    )
)
register_rule(
    ValidationRule(
        id="tool.reference.valid",
        node_types=["tool"],
        category=RuleCategory.REFERENCE,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_tool_reference,
        description="Tool reference must be valid and configured",
        fix_hint="Use valid tool or fallback node",
    )
)
register_rule(
    ValidationRule(
        id="ifelse.operator.valid",
        node_types=["if-else"],
        category=RuleCategory.SEMANTIC,
        severity=Severity.ERROR,
        is_fixable=True,
        check=_check_if_else_operators,
        description="If-else operators must be valid",
        fix_hint="Use standard operators like ≥, ≤, =, ≠",
    )
)

View File

@@ -197,6 +197,14 @@ class Node(Generic[NodeDataT]):
return None
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
"""
Get the default configuration schema for the node.
Used for LLM generation.
"""
return None
# Global registry populated via __init_subclass__
_registry: ClassVar[dict["NodeType", dict[str, type["Node"]]]] = {}

View File

@@ -1,3 +1,5 @@
from typing import Any
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -9,6 +11,24 @@ class EndNode(Node[EndNodeData]):
node_type = NodeType.END
execution_type = NodeExecutionType.RESPONSE
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow exit point - defines output variables",
"required": ["outputs"],
"parameters": {
"outputs": {
"type": "array",
"description": "Output variables to return",
"item_schema": {
"variable": "string - output variable name",
"type": "enum: string, number, object, array",
"value_selector": "array - path to source value, e.g. ['node_id', 'field']",
},
},
},
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -15,6 +15,27 @@ class StartNode(Node[StartNodeData]):
node_type = NodeType.START
execution_type = NodeExecutionType.ROOT
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Workflow entry point - defines input variables",
"required": [],
"parameters": {
"variables": {
"type": "array",
"description": "Input variables for the workflow",
"item_schema": {
"variable": "string - variable name",
"label": "string - display label",
"type": "enum: text-input, paragraph, number, select, file, file-list",
"required": "boolean",
"max_length": "number (optional)",
},
},
},
"outputs": ["All defined variables are available as {{#start.variable_name#}}"],
}
@classmethod
def version(cls) -> str:
return "1"

View File

@@ -50,6 +50,19 @@ class ToolNode(Node[ToolNodeData]):
def version(cls) -> str:
return "1"
@classmethod
def get_default_config_schema(cls) -> dict[str, Any] | None:
return {
"description": "Execute an external tool",
"required": ["provider_id", "tool_id", "tool_parameters"],
"parameters": {
"provider_id": {"type": "string"},
"provider_type": {"type": "string"},
"tool_id": {"type": "string"},
"tool_parameters": {"type": "object"},
},
}
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Run the tool node

View File

@@ -107,3 +107,24 @@ def test_host_header_preservation_with_user_header(mock_get_client):
response = make_request("GET", "http://example.com", headers={"Host": custom_host})
assert response.status_code == 200
# Verify build_request was called
mock_client.build_request.assert_called_once()
# Verify the Host header was set on the request object
assert mock_request.headers.get("Host") == custom_host
mock_client.send.assert_called_once_with(mock_request, follow_redirects=True)
@patch("core.helper.ssrf_proxy._get_ssrf_client")
@pytest.mark.parametrize("host_key", ["host", "HOST"])
def test_host_header_preservation_case_insensitive(mock_get_client, host_key):
    """Test that the Host header is preserved regardless of the key's case.

    Fix: the original assigned ``response`` but never asserted on it; the
    sibling test above verifies ``status_code == 200``, so this one now
    does too instead of leaving the variable unused.
    """
    mock_client = MagicMock()
    mock_request = MagicMock()
    mock_request.headers = {}
    mock_response = MagicMock()
    mock_response.status_code = 200
    mock_client.send.return_value = mock_response
    mock_client.build_request.return_value = mock_request
    mock_get_client.return_value = mock_client
    response = make_request("GET", "http://example.com", headers={host_key: "api.example.com"})
    # Request succeeds and the canonical "Host" header carries the
    # caller-supplied value regardless of the original key's casing.
    assert response.status_code == 200
    assert mock_request.headers.get("Host") == "api.example.com"

View File

@@ -0,0 +1,287 @@
"""
Unit tests for the Mermaid Generator.
Tests cover:
- Basic workflow rendering
- Reserved word handling ('end' -> 'end_node')
- Question classifier multi-branch edges
- If-else branch labels
- Edge validation and skipping
- Tool node formatting
"""
from core.workflow.generator.utils.mermaid_generator import generate_mermaid
class TestBasicWorkflow:
    """Mermaid rendering of minimal workflows."""

    def test_simple_start_end_workflow(self):
        """A Start -> End pair renders both nodes and the connecting edge."""
        wf = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [{"source": "start", "target": "end"}],
        }
        out = generate_mermaid(wf)
        assert "flowchart TD" in out
        assert 'start["type=start|title=Start"]' in out
        assert 'end_node["type=end|title=End"]' in out
        assert "start --> end_node" in out

    def test_start_llm_end_workflow(self):
        """A three-node chain (Start -> LLM -> End) renders both edges."""
        wf = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "llm", "type": "llm", "title": "Generate"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [
                {"source": "start", "target": "llm"},
                {"source": "llm", "target": "end"},
            ],
        }
        out = generate_mermaid(wf)
        assert 'llm["type=llm|title=Generate"]' in out
        assert "start --> llm" in out
        assert "llm --> end_node" in out

    def test_empty_workflow(self):
        """No nodes and no edges yields only the flowchart header."""
        out = generate_mermaid({"nodes": [], "edges": []})
        assert out == "flowchart TD"

    def test_missing_keys_handled(self):
        """An entirely empty dict must not raise and still emits the header."""
        out = generate_mermaid({})
        assert "flowchart TD" in out
class TestReservedWords:
    """Mermaid-reserved words used as node IDs must be rewritten."""

    def test_end_node_id_is_replaced(self):
        """A node whose ID is the reserved word 'end' renders as 'end_node'."""
        wf = {
            "nodes": [{"id": "end", "type": "end", "title": "End"}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        # The safe alias must be used while the label stays intact.
        assert "end_node[" in out
        assert '"type=end|title=End"' in out

    def test_subgraph_node_id_is_replaced(self):
        """A node whose ID is the reserved word 'subgraph' gets a safe alias."""
        wf = {
            "nodes": [{"id": "subgraph", "type": "code", "title": "Process"}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert "subgraph_node[" in out

    def test_edge_uses_safe_ids(self):
        """Edges must reference the rewritten IDs, not the raw reserved ones."""
        wf = {
            "nodes": [
                {"id": "start", "type": "start", "title": "Start"},
                {"id": "end", "type": "end", "title": "End"},
            ],
            "edges": [{"source": "start", "target": "end"}],
        }
        out = generate_mermaid(wf)
        assert "start --> end_node" in out
        assert "start --> end\n" not in out
class TestBranchEdges:
    """Edge labels derived from sourceHandle on branching nodes."""

    def test_question_classifier_source_handles(self):
        """Classifier branches carry their sourceHandle as edge labels."""
        wf = {
            "nodes": [
                {"id": "classifier", "type": "question-classifier", "title": "Classify"},
                {"id": "refund", "type": "llm", "title": "Handle Refund"},
                {"id": "inquiry", "type": "llm", "title": "Handle Inquiry"},
            ],
            "edges": [
                {"source": "classifier", "target": "refund", "sourceHandle": "refund"},
                {"source": "classifier", "target": "inquiry", "sourceHandle": "inquiry"},
            ],
        }
        out = generate_mermaid(wf)
        assert "classifier -->|refund| refund" in out
        assert "classifier -->|inquiry| inquiry" in out

    def test_if_else_true_false_handles(self):
        """If-else branches are labelled with their true/false handles."""
        wf = {
            "nodes": [
                {"id": "ifelse", "type": "if-else", "title": "Check"},
                {"id": "yes_branch", "type": "llm", "title": "Yes"},
                {"id": "no_branch", "type": "llm", "title": "No"},
            ],
            "edges": [
                {"source": "ifelse", "target": "yes_branch", "sourceHandle": "true"},
                {"source": "ifelse", "target": "no_branch", "sourceHandle": "false"},
            ],
        }
        out = generate_mermaid(wf)
        assert "ifelse -->|true| yes_branch" in out
        assert "ifelse -->|false| no_branch" in out

    def test_source_handle_source_is_ignored(self):
        """The default handle value 'source' must not produce a label."""
        wf = {
            "nodes": [
                {"id": "llm1", "type": "llm", "title": "LLM 1"},
                {"id": "llm2", "type": "llm", "title": "LLM 2"},
            ],
            "edges": [{"source": "llm1", "target": "llm2", "sourceHandle": "source"}],
        }
        out = generate_mermaid(wf)
        assert "llm1 --> llm2" in out
        assert "llm1 -->|source|" not in out
class TestEdgeValidation:
    """Invalid or incomplete edges are silently dropped from the output."""

    def test_edge_with_missing_source_is_skipped(self):
        """An edge whose source node does not exist is omitted."""
        wf = {
            "nodes": [{"id": "end", "type": "end", "title": "End"}],
            "edges": [{"source": "nonexistent", "target": "end"}],
        }
        out = generate_mermaid(wf)
        assert "nonexistent" not in out
        assert "-->" not in out or "nonexistent" not in out

    def test_edge_with_missing_target_is_skipped(self):
        """An edge whose target node does not exist is omitted."""
        wf = {
            "nodes": [{"id": "start", "type": "start", "title": "Start"}],
            "edges": [{"source": "start", "target": "nonexistent"}],
        }
        out = generate_mermaid(wf)
        assert "start --> nonexistent" not in out

    def test_edge_without_source_or_target_is_skipped(self):
        """Edges missing either endpoint (or both) produce no arrows at all."""
        wf = {
            "nodes": [{"id": "start", "type": "start", "title": "Start"}],
            "edges": [{"source": "start"}, {"target": "start"}, {}],
        }
        out = generate_mermaid(wf)
        assert out.count("-->") == 0
class TestToolNodes:
    """Tool nodes embed their tool identifier in the rendered label."""

    def test_tool_node_includes_tool_key(self):
        """config.tool_key appears as tool=<key> inside the node label."""
        wf = {
            "nodes": [
                {
                    "id": "search",
                    "type": "tool",
                    "title": "Search",
                    "config": {"tool_key": "google/search"},
                }
            ],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert 'search["type=tool|title=Search|tool=google/search"]' in out

    def test_tool_node_with_tool_name_fallback(self):
        """When tool_key is absent, config.tool_name is used instead."""
        wf = {
            "nodes": [
                {
                    "id": "tool1",
                    "type": "tool",
                    "title": "My Tool",
                    "config": {"tool_name": "my_tool"},
                }
            ],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert "tool=my_tool" in out

    def test_tool_node_missing_tool_key_shows_unknown(self):
        """With neither key nor name configured the label says 'unknown'."""
        wf = {
            "nodes": [{"id": "tool1", "type": "tool", "title": "Tool", "config": {}}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert "tool=unknown" in out
class TestNodeFormatting:
    """Node label escaping, defaults, and skipping of malformed nodes."""

    def test_quotes_in_title_are_escaped(self):
        """Double quotes in titles become single quotes in the label."""
        wf = {
            "nodes": [{"id": "llm", "type": "llm", "title": 'Say "Hello"'}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert "Say 'Hello'" in out
        assert 'Say "Hello"' not in out

    def test_node_without_id_is_skipped(self):
        """A node lacking an ID is dropped; only the header remains."""
        wf = {
            "nodes": [{"type": "llm", "title": "No ID"}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        non_blank = [ln for ln in out.split("\n") if ln.strip()]
        assert len(non_blank) == 1

    def test_node_default_values(self):
        """Missing type/title fall back to 'unknown'/'Untitled'."""
        wf = {
            "nodes": [{"id": "node1"}],
            "edges": [],
        }
        out = generate_mermaid(wf)
        assert "type=unknown" in out
        assert "title=Untitled" in out

View File

@@ -0,0 +1,81 @@
from core.workflow.generator.utils.node_repair import NodeRepair
class TestNodeRepair:
    """Behavior of the NodeRepair operator-normalization utility."""

    def test_repair_if_else_valid_operators(self):
        """Already-valid comparison operators pass through untouched."""
        input_nodes = [
            {
                "id": "node1",
                "type": "if-else",
                "config": {
                    "cases": [
                        {
                            "conditions": [
                                {"comparison_operator": "", "value": "1"},
                                {"comparison_operator": "=", "value": "2"},
                            ]
                        }
                    ]
                },
            }
        ]
        outcome = NodeRepair.repair(input_nodes)
        assert outcome.was_repaired is False
        assert outcome.nodes == input_nodes

    def test_repair_if_else_invalid_operators(self):
        """ASCII-style operators (>=, <=, !=, ==) are normalized, one repair each."""
        input_nodes = [
            {
                "id": "node1",
                "type": "if-else",
                "config": {
                    "cases": [
                        {
                            "conditions": [
                                {"comparison_operator": ">=", "value": "1"},
                                {"comparison_operator": "<=", "value": "2"},
                                {"comparison_operator": "!=", "value": "3"},
                                {"comparison_operator": "==", "value": "4"},
                            ]
                        }
                    ]
                },
            }
        ]
        outcome = NodeRepair.repair(input_nodes)
        assert outcome.was_repaired is True
        assert len(outcome.repairs_made) == 4
        fixed = outcome.nodes[0]["config"]["cases"][0]["conditions"]
        assert fixed[0]["comparison_operator"] == ""
        assert fixed[1]["comparison_operator"] == ""
        assert fixed[2]["comparison_operator"] == ""
        assert fixed[3]["comparison_operator"] == "="

    def test_repair_ignores_other_nodes(self):
        """Non if-else nodes are left completely alone."""
        input_nodes = [{"id": "node1", "type": "llm", "config": {"some_field": ">="}}]
        outcome = NodeRepair.repair(input_nodes)
        assert outcome.was_repaired is False
        assert outcome.nodes[0]["config"]["some_field"] == ">="

    def test_repair_handles_missing_config(self):
        """Nodes missing config or cases must not crash and need no repair."""
        input_nodes = [
            {"id": "node1", "type": "if-else"},  # no config at all
            {"id": "node2", "type": "if-else", "config": {}},  # no cases
        ]
        outcome = NodeRepair.repair(input_nodes)
        assert outcome.was_repaired is False

View File

@@ -0,0 +1,99 @@
"""
Tests for node schemas validation.
Ensures that the node configuration stays in sync with registered node types.
"""
from core.workflow.generator.config.node_schemas import (
get_builtin_node_schemas,
validate_node_schemas,
)
class TestNodeSchemasValidation:
    """Schema/registry consistency checks for node schemas."""

    def test_validate_node_schemas_returns_no_warnings(self):
        """Every registered node type must have a schema (or be internal)."""
        missing = validate_node_schemas()
        # A failure here means a node type was added without a schema entry.
        assert len(missing) == 0, (
            f"Missing schemas for node types: {missing}. "
            "Please add schemas for these node types in node_schemas.py "
            "or add them to _INTERNAL_NODE_TYPES if they don't need schemas."
        )

    def test_builtin_node_schemas_not_empty(self):
        """The merged schema map is non-empty and covers the core node types."""
        schemas = get_builtin_node_schemas()
        assert len(schemas) > 0
        for core_type in ("llm", "code", "http-request", "if-else"):
            assert core_type in schemas, f"Missing schema for core node type: {core_type}"

    def test_schema_structure(self):
        """Each schema has a description; parameters, if present, is a dict."""
        for node_type, schema in get_builtin_node_schemas().items():
            assert "description" in schema, f"Missing 'description' in schema for {node_type}"
            if "parameters" in schema:
                assert isinstance(schema["parameters"], dict), (
                    f"'parameters' in schema for {node_type} should be a dict"
                )
class TestNodeSchemasMerged:
    """The merged node_schemas module exposes the expected tables and helpers."""

    def test_fallback_rules_available(self):
        """FALLBACK_RULES exists and covers the usual fallback node types."""
        from core.workflow.generator.config.node_schemas import FALLBACK_RULES

        assert len(FALLBACK_RULES) > 0
        for key in ("http-request", "code", "llm"):
            assert key in FALLBACK_RULES

    def test_node_type_aliases_available(self):
        """NODE_TYPE_ALIASES maps informal names to canonical node types."""
        from core.workflow.generator.config.node_schemas import NODE_TYPE_ALIASES

        assert len(NODE_TYPE_ALIASES) > 0
        assert NODE_TYPE_ALIASES.get("gpt") == "llm"
        assert NODE_TYPE_ALIASES.get("api") == "http-request"

    def test_field_name_corrections_available(self):
        """FIELD_NAME_CORRECTIONS and its helper resolve known typos."""
        from core.workflow.generator.config.node_schemas import (
            FIELD_NAME_CORRECTIONS,
            get_corrected_field_name,
        )

        assert len(FIELD_NAME_CORRECTIONS) > 0
        assert get_corrected_field_name("http-request", "text") == "body"
        assert get_corrected_field_name("llm", "response") == "text"
        # Unknown fields pass through unchanged.
        assert get_corrected_field_name("code", "unknown") == "unknown"

    def test_config_init_exports(self):
        """The config package re-exports all public symbols."""
        from core.workflow.generator.config import (
            BUILTIN_NODE_SCHEMAS,
            FALLBACK_RULES,
            FIELD_NAME_CORRECTIONS,
            NODE_TYPE_ALIASES,
            get_corrected_field_name,
            validate_node_schemas,
        )

        assert BUILTIN_NODE_SCHEMAS is not None
        assert FALLBACK_RULES is not None
        assert FIELD_NAME_CORRECTIONS is not None
        assert NODE_TYPE_ALIASES is not None
        assert callable(get_corrected_field_name)
        assert callable(validate_node_schemas)

View File

@@ -0,0 +1,172 @@
"""
Unit tests for the Planner Prompts.
Tests cover:
- Tool formatting for planner context
- Edge cases with missing fields
- Empty tool lists
"""
from core.workflow.generator.prompts.planner_prompts import format_tools_for_planner
class TestFormatToolsForPlanner:
    """Formatting of the tool list passed into the planner context."""

    def test_empty_tools_returns_default_message(self):
        """An empty list yields the no-tools sentinel message."""
        assert format_tools_for_planner([]) == "No external tools available."

    def test_none_tools_returns_default_message(self):
        """None is treated the same as an empty list."""
        assert format_tools_for_planner(None) == "No external tools available."

    def test_single_tool_formatting(self):
        """A fully-described tool renders its key, label and description."""
        tools = [
            {
                "provider_id": "google",
                "tool_key": "search",
                "tool_label": "Google Search",
                "tool_description": "Search the web using Google",
            }
        ]
        rendered = format_tools_for_planner(tools)
        assert "[google/search]" in rendered
        assert "Google Search" in rendered
        assert "Search the web using Google" in rendered

    def test_multiple_tools_formatting(self):
        """Each tool becomes one line; both keys appear."""
        tools = [
            {"provider_id": "google", "tool_key": "search", "tool_label": "Search", "tool_description": "Web search"},
            {
                "provider_id": "slack",
                "tool_key": "send_message",
                "tool_label": "Send Message",
                "tool_description": "Send a Slack message",
            },
        ]
        rendered = format_tools_for_planner(tools)
        assert len(rendered.strip().split("\n")) == 2
        assert "[google/search]" in rendered
        assert "[slack/send_message]" in rendered

    def test_tool_without_provider_uses_key_only(self):
        """Without a provider_id the bracket holds just the tool key."""
        tools = [{"tool_key": "my_tool", "tool_label": "My Tool", "tool_description": "A custom tool"}]
        rendered = format_tools_for_planner(tools)
        assert "[my_tool]" in rendered
        assert "My Tool" in rendered

    def test_tool_with_tool_name_fallback(self):
        """tool_name / description are fallbacks for tool_key / tool_description."""
        tools = [{"tool_name": "fallback_tool", "description": "Fallback description"}]
        rendered = format_tools_for_planner(tools)
        assert "fallback_tool" in rendered
        assert "Fallback description" in rendered

    def test_tool_with_missing_description(self):
        """A missing description must not raise."""
        tools = [{"provider_id": "test", "tool_key": "tool1", "tool_label": "Tool 1"}]
        rendered = format_tools_for_planner(tools)
        assert "[test/tool1]" in rendered
        assert "Tool 1" in rendered

    def test_tool_with_all_missing_fields(self):
        """A completely empty tool dict still yields a string result."""
        rendered = format_tools_for_planner([{}])
        assert isinstance(rendered, str)

    def test_tool_uses_provider_fallback(self):
        """'provider' is accepted when 'provider_id' is absent."""
        tools = [
            {
                "provider": "openai",
                "tool_key": "dalle",
                "tool_label": "DALL-E",
                "tool_description": "Generate images",
            }
        ]
        assert "[openai/dalle]" in format_tools_for_planner(tools)

    def test_tool_label_fallback_to_key(self):
        """When tool_label is missing, the key doubles as the label."""
        tools = [{"provider_id": "test", "tool_key": "my_key", "tool_description": "Description here"}]
        rendered = format_tools_for_planner(tools)
        assert "my_key" in rendered
        assert "Description here" in rendered
class TestPlannerPromptConstants:
    """Existence and shape of the planner prompt constants."""

    def test_planner_system_prompt_exists(self):
        """PLANNER_SYSTEM_PROMPT is defined, non-empty, and parameterized."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT

        assert PLANNER_SYSTEM_PROMPT is not None
        assert len(PLANNER_SYSTEM_PROMPT) > 0
        assert "{tools_summary}" in PLANNER_SYSTEM_PROMPT

    def test_planner_user_prompt_exists(self):
        """PLANNER_USER_PROMPT is defined and takes the instruction slot."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_USER_PROMPT

        assert PLANNER_USER_PROMPT is not None
        assert "{instruction}" in PLANNER_USER_PROMPT

    def test_planner_system_prompt_has_required_sections(self):
        """The system prompt contains all required XML-style sections."""
        from core.workflow.generator.prompts.planner_prompts import PLANNER_SYSTEM_PROMPT

        for section in ("<role>", "<task>", "<available_tools>", "<response_format>"):
            assert section in PLANNER_SYSTEM_PROMPT

View File

@@ -0,0 +1,510 @@
"""
Unit tests for the Validation Rule Engine.
Tests cover:
- Structure rules (required fields, types, formats)
- Semantic rules (variable references, edge connections)
- Reference rules (model exists, tool configured, dataset valid)
- ValidationEngine integration
"""
from core.workflow.generator.validation import (
ValidationContext,
ValidationEngine,
)
from core.workflow.generator.validation.rules import (
extract_variable_refs,
is_placeholder,
)
class TestPlaceholderDetection:
    """Coverage of the is_placeholder() heuristic."""

    def test_detects_please_select(self):
        assert is_placeholder("PLEASE_SELECT_YOUR_MODEL") is True

    def test_detects_your_prefix(self):
        assert is_placeholder("YOUR_API_KEY") is True

    def test_detects_todo(self):
        assert is_placeholder("TODO: fill this in") is True

    def test_detects_placeholder(self):
        assert is_placeholder("PLACEHOLDER_VALUE") is True

    def test_detects_example_prefix(self):
        assert is_placeholder("EXAMPLE_URL") is True

    def test_detects_replace_prefix(self):
        assert is_placeholder("REPLACE_WITH_ACTUAL") is True

    def test_case_insensitive(self):
        # Detection works regardless of letter case.
        assert is_placeholder("please_select") is True
        assert is_placeholder("Please_Select") is True

    def test_valid_values_not_detected(self):
        # Real URLs, model names, and identifiers are not placeholders.
        assert is_placeholder("https://api.example.com") is False
        assert is_placeholder("gpt-4") is False
        assert is_placeholder("my_variable") is False

    def test_non_string_returns_false(self):
        # Non-string inputs are never placeholders.
        assert is_placeholder(123) is False
        assert is_placeholder(None) is False
        assert is_placeholder(["list"]) is False
class TestVariableRefExtraction:
    """Coverage of extract_variable_refs() on {{#node.field#}} syntax."""

    def test_extracts_simple_ref(self):
        assert extract_variable_refs("Hello {{#start.query#}}") == [("start", "query")]

    def test_extracts_multiple_refs(self):
        found = extract_variable_refs("{{#node1.output#}} and {{#node2.text#}}")
        assert found == [("node1", "output"), ("node2", "text")]

    def test_extracts_nested_field(self):
        assert extract_variable_refs("{{#http_request.body#}}") == [("http_request", "body")]

    def test_no_refs_returns_empty(self):
        assert extract_variable_refs("No references here") == []

    def test_handles_malformed_refs(self):
        # Unbalanced markers must not produce spurious matches.
        assert extract_variable_refs("{{#invalid}} and {{incomplete#}}") == []
class TestValidationContext:
    """Lookup helpers exposed by ValidationContext."""

    def test_node_map_lookup(self):
        """get_node resolves by ID and returns None for unknown IDs."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start"},
                {"id": "llm_1", "type": "llm"},
            ]
        )
        assert context.get_node("start") == {"id": "start", "type": "start"}
        assert context.get_node("nonexistent") is None

    def test_model_set(self):
        """has_model matches exact provider/model pairs only."""
        context = ValidationContext(
            available_models=[
                {"provider": "openai", "model": "gpt-4"},
                {"provider": "anthropic", "model": "claude-3"},
            ]
        )
        assert context.has_model("openai", "gpt-4") is True
        assert context.has_model("anthropic", "claude-3") is True
        assert context.has_model("openai", "gpt-3.5") is False

    def test_tool_set(self):
        """has_tool accepts both provider/key and bare key; configuration tracks authorization."""
        context = ValidationContext(
            available_tools=[
                {"provider_id": "google", "tool_key": "search", "is_team_authorization": True},
                {"provider_id": "slack", "tool_key": "send_message", "is_team_authorization": False},
            ]
        )
        assert context.has_tool("google/search") is True
        assert context.has_tool("search") is True
        assert context.is_tool_configured("google/search") is True
        assert context.is_tool_configured("slack/send_message") is False

    def test_upstream_downstream_nodes(self):
        """Edge lists drive the upstream/downstream neighbor queries."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start"},
                {"id": "llm", "type": "llm"},
                {"id": "end", "type": "end"},
            ],
            edges=[
                {"source": "start", "target": "llm"},
                {"source": "llm", "target": "end"},
            ],
        )
        assert context.get_upstream_nodes("llm") == ["start"]
        assert context.get_downstream_nodes("llm") == ["end"]
class TestStructureRules:
    """Structure rules: required fields, placeholders, per-type configs."""

    def test_llm_missing_prompt_template(self):
        """An LLM node without a prompt_template is a fixable error."""
        context = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
        report = ValidationEngine().validate(context)
        assert report.has_errors
        hits = [e for e in report.all_errors if e.rule_id == "llm.prompt_template.required"]
        assert len(hits) == 1
        assert hits[0].is_fixable is True

    def test_llm_with_prompt_template_passes(self):
        """A populated prompt_template produces no prompt-related errors."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {
                        "prompt_template": [
                            {"role": "system", "text": "You are helpful"},
                            {"role": "user", "text": "Hello"},
                        ]
                    },
                }
            ]
        )
        report = ValidationEngine().validate(context)
        assert len([e for e in report.all_errors if "prompt_template" in e.rule_id]) == 0

    def test_http_request_missing_url(self):
        """An HTTP node without a URL raises one fixable http.url error."""
        context = ValidationContext(nodes=[{"id": "http_1", "type": "http-request", "config": {}}])
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if "http.url" in e.rule_id]
        assert len(hits) == 1
        assert hits[0].is_fixable is True

    def test_http_request_placeholder_url(self):
        """A placeholder URL trips the placeholder rule."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "http_1",
                    "type": "http-request",
                    "config": {"url": "PLEASE_SELECT_YOUR_URL", "method": "GET"},
                }
            ]
        )
        report = ValidationEngine().validate(context)
        assert len([e for e in report.all_errors if "placeholder" in e.rule_id]) == 1

    def test_code_node_missing_fields(self):
        """A bare code node is flagged for both missing code and language."""
        context = ValidationContext(nodes=[{"id": "code_1", "type": "code", "config": {}}])
        report = ValidationEngine().validate(context)
        triggered = {e.rule_id for e in report.all_errors}
        assert "code.code.required" in triggered
        assert "code.language.required" in triggered

    def test_knowledge_retrieval_missing_dataset(self):
        """Missing datasets are a non-fixable error (user must configure)."""
        context = ValidationContext(nodes=[{"id": "kb_1", "type": "knowledge-retrieval", "config": {}}])
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if "knowledge.dataset" in e.rule_id]
        assert len(hits) == 1
        assert hits[0].is_fixable is False
class TestSemanticRules:
    """Semantic rules: variable references and edge endpoint validity."""

    def test_valid_variable_reference(self):
        """A reference to an existing node produces no variable.ref errors."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Process: {{#start.query#}}"}]},
                },
            ]
        )
        report = ValidationEngine().validate(context)
        assert len([e for e in report.all_errors if "variable.ref" in e.rule_id]) == 0

    def test_invalid_variable_reference(self):
        """A reference to a missing node yields one error naming that node."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Process: {{#nonexistent.field#}}"}]},
                },
            ]
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if "variable.ref" in e.rule_id]
        assert len(hits) == 1
        assert "nonexistent" in hits[0].message

    def test_edge_validation(self):
        """An edge from a missing node yields one error naming that node."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            edges=[
                {"source": "start", "target": "end"},
                {"source": "nonexistent", "target": "end"},
            ],
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if "edge" in e.rule_id]
        assert len(hits) == 1
        assert "nonexistent" in hits[0].message
class TestReferenceRules:
    """Reference rules: model availability and tool configuration."""

    def test_llm_missing_model_with_available(self):
        """A missing model is fixable when alternatives are available."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
                }
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if e.rule_id == "model.required"]
        assert len(hits) == 1
        assert hits[0].is_fixable is True

    def test_llm_missing_model_no_available(self):
        """With no models available the error is not auto-fixable."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
                }
            ],
            available_models=[],
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if e.rule_id == "model.no_available"]
        assert len(hits) == 1
        assert hits[0].is_fixable is False

    def test_llm_with_valid_model(self):
        """A configured, available model produces no model errors."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {
                        "prompt_template": [{"role": "user", "text": "Hi"}],
                        "model": {"provider": "openai", "name": "gpt-4"},
                    },
                }
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        report = ValidationEngine().validate(context)
        assert len([e for e in report.all_errors if "model" in e.rule_id]) == 0

    def test_llm_with_invalid_model(self):
        """An unknown model is flagged as fixable (swap to an available one)."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {
                        "prompt_template": [{"role": "user", "text": "Hi"}],
                        "model": {"provider": "openai", "name": "gpt-99"},
                    },
                }
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if e.rule_id == "model.not_found"]
        assert len(hits) == 1
        assert hits[0].is_fixable is True

    def test_tool_node_not_found(self):
        """A tool key with no matching available tool is an error."""
        context = ValidationContext(
            nodes=[{"id": "tool_1", "type": "tool", "config": {"tool_key": "nonexistent/tool"}}],
            available_tools=[],
        )
        report = ValidationEngine().validate(context)
        assert len([e for e in report.all_errors if e.rule_id == "tool.not_found"]) == 1

    def test_tool_node_not_configured(self):
        """An unauthorized tool is a non-fixable configuration error."""
        context = ValidationContext(
            nodes=[{"id": "tool_1", "type": "tool", "config": {"tool_key": "google/search"}}],
            available_tools=[{"provider_id": "google", "tool_key": "search", "is_team_authorization": False}],
        )
        report = ValidationEngine().validate(context)
        hits = [e for e in report.all_errors if e.rule_id == "tool.not_configured"]
        assert len(hits) == 1
        assert hits[0].is_fixable is False
class TestValidationResult:
    """Classification helpers on the ValidationResult object."""

    def test_has_errors(self):
        """An invalid node flips has_errors and clears is_valid."""
        context = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
        report = ValidationEngine().validate(context)
        assert report.has_errors is True
        assert report.is_valid is False

    def test_has_fixable_errors(self):
        """A missing-model error with alternatives counts as fixable."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Hi"}]},
                }
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        report = ValidationEngine().validate(context)
        assert report.has_fixable_errors is True
        assert len(report.fixable_errors) > 0

    def test_get_fixable_by_node(self):
        """Fixable errors are grouped under their originating node IDs."""
        context = ValidationContext(
            nodes=[
                {"id": "llm_1", "type": "llm", "config": {}},
                {"id": "http_1", "type": "http-request", "config": {}},
            ]
        )
        grouped = ValidationEngine().validate(context).get_fixable_by_node()
        assert "llm_1" in grouped
        assert "http_1" in grouped

    def test_to_dict(self):
        """to_dict exposes the expected top-level keys."""
        context = ValidationContext(nodes=[{"id": "llm_1", "type": "llm", "config": {}}])
        payload = ValidationEngine().validate(context).to_dict()
        for key in ("fixable", "user_required", "warnings", "all_warnings", "stats"):
            assert key in payload
class TestIntegration:
    """End-to-end runs of the full validation pipeline."""

    def test_complete_workflow_validation(self):
        """A fully wired start -> llm -> end workflow validates cleanly."""
        context = ValidationContext(
            nodes=[
                {
                    "id": "start",
                    "type": "start",
                    "config": {"variables": [{"variable": "query", "type": "text-input"}]},
                },
                {
                    "id": "llm_1",
                    "type": "llm",
                    "config": {
                        "model": {"provider": "openai", "name": "gpt-4"},
                        "prompt_template": [{"role": "user", "text": "{{#start.query#}}"}],
                    },
                },
                {
                    "id": "end",
                    "type": "end",
                    "config": {"outputs": [{"variable": "result", "value_selector": ["llm_1", "text"]}]},
                },
            ],
            edges=[
                {"source": "start", "target": "llm_1"},
                {"source": "llm_1", "target": "end"},
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        outcome = ValidationEngine().validate(context)
        # A well-formed workflow produces no errors of either category.
        assert outcome.is_valid is True
        assert len(outcome.fixable_errors) == 0
        assert len(outcome.user_required_errors) == 0

    def test_workflow_with_multiple_errors(self):
        """Distinct error categories accumulate across several misconfigured nodes."""
        context = ValidationContext(
            nodes=[
                {"id": "start", "type": "start", "config": {}},
                # Missing both prompt_template and model.
                {"id": "llm_1", "type": "llm", "config": {}},
                {
                    "id": "kb_1",
                    "type": "knowledge-retrieval",
                    "config": {"dataset_ids": ["PLEASE_SELECT_YOUR_DATASET"]},
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            available_models=[{"provider": "openai", "model": "gpt-4"}],
        )
        outcome = ValidationEngine().validate(context)
        assert outcome.has_errors is True
        assert len(outcome.fixable_errors) >= 2  # model + prompt_template
        assert len(outcome.user_required_errors) >= 1  # dataset placeholder
        # Aggregate stats reflect every node and error seen.
        assert outcome.stats["total_nodes"] == 4
        assert outcome.stats["total_errors"] >= 3

View File

@@ -0,0 +1,434 @@
"""
Unit tests for the Vibe Workflow Validator.
Tests cover:
- Basic validation function
- User-friendly validation hints
- Edge cases and error handling
"""
from core.workflow.generator.utils.workflow_validator import ValidationHint, WorkflowValidator
class TestValidationHint:
    """Behavior of the ValidationHint dataclass."""

    def test_hint_creation(self):
        """Every constructor argument lands on the matching attribute."""
        hint = ValidationHint(
            node_id="llm_1",
            field="model",
            message="Model is not configured",
            severity="error",
        )
        assert (hint.node_id, hint.field) == ("llm_1", "model")
        assert hint.message == "Model is not configured"
        assert hint.severity == "error"

    def test_hint_with_suggestion(self):
        """The optional suggestion field is retained when supplied."""
        hint = ValidationHint(
            node_id="http_1",
            field="url",
            message="URL is required",
            severity="error",
            suggestion="Add a valid URL like https://api.example.com",
        )
        assert hint.suggestion is not None
class TestWorkflowValidatorBasic:
    """Smoke tests for WorkflowValidator.validate on simple graphs."""

    def test_empty_workflow_is_valid(self):
        """A structurally empty workflow is accepted with zero hints."""
        is_valid, hints = WorkflowValidator.validate({"nodes": [], "edges": []}, [])
        assert is_valid is True
        assert len(hints) == 0

    def test_minimal_valid_workflow(self):
        """A bare start -> end pair passes validation."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "end"}],
        }
        is_valid, _hints = WorkflowValidator.validate(data, [])
        assert is_valid is True

    def test_complete_workflow_with_llm(self):
        """A wired start -> llm -> end workflow produces no error-severity hints."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {"variables": []}},
                {
                    "id": "llm",
                    "type": "llm",
                    "config": {
                        "model": {"provider": "openai", "name": "gpt-4"},
                        "prompt_template": [{"role": "user", "text": "Hello"}],
                    },
                },
                {"id": "end", "type": "end", "config": {"outputs": []}},
            ],
            "edges": [
                {"source": "start", "target": "llm"},
                {"source": "llm", "target": "end"},
            ],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        assert len([h for h in hints if h.severity == "error"]) == 0
class TestVariableReferenceValidation:
    """Validation of {{#node.field#}} variable references in prompts."""

    def test_valid_variable_reference(self):
        """A reference to an existing upstream node raises no reference hints."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "llm",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "Query: {{#start.query#}}"}]},
                },
            ],
            "edges": [{"source": "start", "target": "llm"}],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        assert len([h for h in hints if "reference" in h.message.lower()]) == 0

    def test_invalid_variable_reference(self):
        """A reference to an unknown node id is reported."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "llm",
                    "type": "llm",
                    "config": {"prompt_template": [{"role": "user", "text": "{{#nonexistent.field#}}"}]},
                },
            ],
            "edges": [{"source": "start", "target": "llm"}],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        flagged = [h for h in hints if "nonexistent" in h.message or "reference" in h.message.lower()]
        assert len(flagged) >= 1
class TestEdgeValidation:
    """Edges must reference nodes that actually exist."""

    def test_edge_with_invalid_source(self):
        """An edge whose source id is unknown is flagged."""
        data = {
            "nodes": [{"id": "end", "type": "end", "config": {}}],
            "edges": [{"source": "nonexistent", "target": "end"}],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        flagged = [h for h in hints if "edge" in h.message.lower() or "source" in h.message.lower()]
        assert len(flagged) >= 1

    def test_edge_with_invalid_target(self):
        """An edge whose target id is unknown is flagged."""
        data = {
            "nodes": [{"id": "start", "type": "start", "config": {}}],
            "edges": [{"source": "start", "target": "nonexistent"}],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        flagged = [h for h in hints if "edge" in h.message.lower() or "target" in h.message.lower()]
        assert len(flagged) >= 1
class TestToolValidation:
    """Tool nodes are checked against the available-tools catalog."""

    def test_tool_node_found_in_available(self):
        """A tool present and authorized in the catalog yields no tool errors."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "tool1",
                    "type": "tool",
                    "config": {"tool_key": "google/search"},
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "tool1"}, {"source": "tool1", "target": "end"}],
        }
        catalog = [{"provider_id": "google", "tool_key": "search", "is_team_authorization": True}]
        _is_valid, hints = WorkflowValidator.validate(data, catalog)
        assert len([h for h in hints if h.severity == "error" and "tool" in h.message.lower()]) == 0

    def test_tool_node_not_found(self):
        """A tool_key absent from the catalog produces at least one tool hint."""
        data = {
            "nodes": [
                {
                    "id": "tool1",
                    "type": "tool",
                    "config": {"tool_key": "unknown/tool"},
                },
            ],
            "edges": [],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        assert len([h for h in hints if "tool" in h.message.lower()]) >= 1
class TestQuestionClassifierValidation:
    """question-classifier nodes require classes plus a usable model."""

    def test_question_classifier_with_classes(self):
        """Two classes, a chat model, and wired branch handles produce no class errors."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "classifier",
                    "type": "question-classifier",
                    "config": {
                        "classes": [
                            {"id": "class1", "name": "Class 1"},
                            {"id": "class2", "name": "Class 2"},
                        ],
                        "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
                    },
                },
                {"id": "h1", "type": "llm", "config": {}},
                {"id": "h2", "type": "llm", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [
                {"source": "start", "target": "classifier"},
                {"source": "classifier", "sourceHandle": "class1", "target": "h1"},
                {"source": "classifier", "sourceHandle": "class2", "target": "h2"},
                {"source": "h1", "target": "end"},
                {"source": "h2", "target": "end"},
            ],
        }
        models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        _is_valid, hints = WorkflowValidator.validate(data, [], available_models=models)
        class_errors = [h for h in hints if "class" in h.message.lower() and h.severity == "error"]
        assert len(class_errors) == 0

    def test_question_classifier_missing_classes(self):
        """Omitting `classes` entirely yields at least one class-related hint."""
        data = {
            "nodes": [
                {
                    "id": "classifier",
                    "type": "question-classifier",
                    "config": {"model": {"provider": "openai", "name": "gpt-4", "mode": "chat"}},
                },
            ],
            "edges": [],
        }
        models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        _is_valid, hints = WorkflowValidator.validate(data, [], available_models=models)
        assert len([h for h in hints if "class" in h.message.lower()]) >= 1
class TestHttpRequestValidation:
    """http-request nodes must carry a URL."""

    def test_http_request_with_url(self):
        """A GET with a concrete URL raises no URL errors."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "http",
                    "type": "http-request",
                    "config": {"url": "https://api.example.com", "method": "GET"},
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "http"}, {"source": "http", "target": "end"}],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        assert len([h for h in hints if "url" in h.message.lower() and h.severity == "error"]) == 0

    def test_http_request_missing_url(self):
        """Leaving out the URL produces at least one URL hint."""
        data = {
            "nodes": [
                {
                    "id": "http",
                    "type": "http-request",
                    "config": {"method": "GET"},
                },
            ],
            "edges": [],
        }
        _is_valid, hints = WorkflowValidator.validate(data, [])
        assert len([h for h in hints if "url" in h.message.lower()]) >= 1
class TestParameterExtractorValidation:
    """parameter-extractor nodes need fully specified parameter schemas."""

    def test_parameter_extractor_valid_params(self):
        """A complete parameter definition validates without any errors."""
        data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "extractor",
                    "type": "parameter-extractor",
                    "config": {
                        "instruction": "Extract info",
                        "parameters": [
                            {
                                "name": "name",
                                "type": "string",
                                "description": "Name",
                                "required": True,
                            },
                        ],
                        "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
                    },
                },
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [{"source": "start", "target": "extractor"}, {"source": "extractor", "target": "end"}],
        }
        models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        _is_valid, hints = WorkflowValidator.validate(data, [], available_models=models)
        assert len([h for h in hints if h.severity == "error"]) == 0

    def test_parameter_extractor_missing_required_field(self):
        """Dropping the `required` key from a parameter item is an error."""
        data = {
            "nodes": [
                {
                    "id": "extractor",
                    "type": "parameter-extractor",
                    "config": {
                        "instruction": "Extract info",
                        "parameters": [
                            # Deliberately missing the 'required' key.
                            {
                                "name": "name",
                                "type": "string",
                                "description": "Name",
                            },
                        ],
                        "model": {"provider": "openai", "name": "gpt-4", "mode": "chat"},
                    },
                },
            ],
            "edges": [],
        }
        models = [{"provider": "openai", "model": "gpt-4", "mode": "chat"}]
        _is_valid, hints = WorkflowValidator.validate(data, [], available_models=models)
        errors = [h for h in hints if "required" in h.message and h.severity == "error"]
        assert len(errors) >= 1
        assert "parameter-extractor" in errors[0].node_type
class TestIfElseValidation:
    """Tests for if-else node validation."""

    def test_if_else_valid_operators(self):
        """Test if-else with valid operators."""
        # NOTE(review): the comparison_operator below is an empty string, which
        # reads like a literal lost in rendering rather than a genuinely valid
        # operator — confirm against the validator's accepted operator list.
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "ifelse",
                    "type": "if-else",
                    "config": {
                        "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": "", "value": "1"}]}]
                    },
                },
                {"id": "t", "type": "llm", "config": {}},
                {"id": "f", "type": "llm", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [
                {"source": "start", "target": "ifelse"},
                {"source": "ifelse", "sourceHandle": "true", "target": "t"},
                {"source": "ifelse", "sourceHandle": "false", "target": "f"},
                {"source": "t", "target": "end"},
                {"source": "f", "target": "end"},
            ],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])
        errors = [h for h in hints if h.severity == "error"]
        # The bare LLM nodes above may trigger unrelated model/config errors,
        # so only operator-specific messages are asserted on here.
        operator_errors = [h for h in errors if "operator" in h.message]
        assert len(operator_errors) == 0

    def test_if_else_invalid_operators(self):
        """Test if-else with invalid operators."""
        # NOTE(review): ">=" is expected to be rejected here — presumably the
        # validator accepts only its own canonical operator tokens; confirm.
        workflow_data = {
            "nodes": [
                {"id": "start", "type": "start", "config": {}},
                {
                    "id": "ifelse",
                    "type": "if-else",
                    "config": {
                        "cases": [{"case_id": "c1", "conditions": [{"comparison_operator": ">=", "value": "1"}]}]
                    },
                },
                {"id": "t", "type": "llm", "config": {}},
                {"id": "f", "type": "llm", "config": {}},
                {"id": "end", "type": "end", "config": {}},
            ],
            "edges": [
                {"source": "start", "target": "ifelse"},
                {"source": "ifelse", "sourceHandle": "true", "target": "t"},
                {"source": "ifelse", "sourceHandle": "false", "target": "f"},
                {"source": "t", "target": "end"},
                {"source": "f", "target": "end"},
            ],
        }
        is_valid, hints = WorkflowValidator.validate(workflow_data, [])
        operator_errors = [h for h in hints if "operator" in h.message and h.severity == "error"]
        assert len(operator_errors) > 0
        # NOTE(review): `"" in s` is True for every string, so this assertion is
        # vacuous — it was likely meant to check for a specific suggested
        # operator token that got lost; restore the intended literal.
        assert "" in operator_errors[0].suggestion

View File

@@ -1,326 +0,0 @@
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
import { fireEvent, render, screen } from '@testing-library/react'
import * as React from 'react'
import CommandSelector from '../../app/components/goto-anything/command-selector'
// Minimal stand-in for cmdk: Group renders a plain div; Item forwards clicks
// to onSelect and exposes a per-value test id so tests can query items.
vi.mock('cmdk', () => ({
  Command: {
    Group: ({ children, className }: any) => <div className={className}>{children}</div>,
    Item: ({ children, onSelect, value, className }: any) => (
      <div
        className={className}
        onClick={() => onSelect?.()}
        data-value={value}
        data-testid={`command-item-${value}`}
      >
        {children}
      </div>
    ),
  },
}))
describe('CommandSelector', () => {
  // Four representative actions; note 'knowledge' uses a shortcut ('@kb') that
  // differs from its key — test ids below are keyed by shortcut.
  const mockActions: Record<string, ActionItem> = {
    app: {
      key: '@app',
      shortcut: '@app',
      title: 'Search Applications',
      description: 'Search apps',
      search: vi.fn(),
    },
    knowledge: {
      key: '@knowledge',
      shortcut: '@kb',
      title: 'Search Knowledge',
      description: 'Search knowledge bases',
      search: vi.fn(),
    },
    plugin: {
      key: '@plugin',
      shortcut: '@plugin',
      title: 'Search Plugins',
      description: 'Search plugins',
      search: vi.fn(),
    },
    node: {
      key: '@node',
      shortcut: '@node',
      title: 'Search Nodes',
      description: 'Search workflow nodes',
      search: vi.fn(),
    },
  }
  const mockOnCommandSelect = vi.fn()
  const mockOnCommandValueChange = vi.fn()
  beforeEach(() => {
    vi.clearAllMocks()
  })

  // No filter / empty filter should both show the full action list.
  describe('Basic Rendering', () => {
    it('should render all actions when no filter is provided', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
    })
    it('should render empty filter as showing all actions', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
    })
  })

  // Substring matching against the action list (case-insensitive, partial).
  describe('Filtering Functionality', () => {
    it('should filter actions based on searchFilter - single match', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="k"
        />,
      )
      expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
    })
    it('should filter actions with multiple matches', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="p"
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
      expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
    })
    it('should be case-insensitive when filtering', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="APP"
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
    })
    it('should match partial strings', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="od"
        />,
      )
      expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
      expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
    })
  })

  // The "no matching commands" message appears only for a non-empty,
  // unmatched filter.
  describe('Empty State', () => {
    it('should show empty state when no matches found', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="xyz"
        />,
      )
      expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@kb')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@plugin')).not.toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@node')).not.toBeInTheDocument()
      expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
      expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
    })
    it('should not show empty state when filter is empty', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
        />,
      )
      expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
    })
  })

  // Highlight tracking: the component must move the highlighted value to the
  // first visible item when filtering removes the current one.
  describe('Selection and Highlight Management', () => {
    it('should call onCommandValueChange when filter changes and first item differs', () => {
      const { rerender } = render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
          commandValue="@app"
          onCommandValueChange={mockOnCommandValueChange}
        />,
      )
      rerender(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="k"
          commandValue="@app"
          onCommandValueChange={mockOnCommandValueChange}
        />,
      )
      expect(mockOnCommandValueChange).toHaveBeenCalledWith('@kb')
    })
    it('should not call onCommandValueChange if current value still exists', () => {
      const { rerender } = render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
          commandValue="@app"
          onCommandValueChange={mockOnCommandValueChange}
        />,
      )
      rerender(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="a"
          commandValue="@app"
          onCommandValueChange={mockOnCommandValueChange}
        />,
      )
      expect(mockOnCommandValueChange).not.toHaveBeenCalled()
    })
    it('should handle onCommandSelect callback correctly', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="k"
        />,
      )
      const knowledgeItem = screen.getByTestId('command-item-@kb')
      fireEvent.click(knowledgeItem)
      expect(mockOnCommandSelect).toHaveBeenCalledWith('@kb')
    })
  })

  describe('Edge Cases', () => {
    it('should handle empty actions object', () => {
      render(
        <CommandSelector
          actions={{}}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
        />,
      )
      expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
    })
    it('should handle special characters in filter', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="@"
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
    })
    it('should handle undefined onCommandValueChange gracefully', () => {
      const { rerender } = render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter=""
        />,
      )
      expect(() => {
        rerender(
          <CommandSelector
            actions={mockActions}
            onCommandSelect={mockOnCommandSelect}
            searchFilter="k"
          />,
        )
      }).not.toThrow()
    })
  })

  // Older call sites omit the newer optional props; they must keep working.
  describe('Backward Compatibility', () => {
    it('should work without searchFilter prop (backward compatible)', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
        />,
      )
      expect(screen.getByTestId('command-item-@app')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@plugin')).toBeInTheDocument()
      expect(screen.getByTestId('command-item-@node')).toBeInTheDocument()
    })
    it('should work without commandValue and onCommandValueChange props', () => {
      render(
        <CommandSelector
          actions={mockActions}
          onCommandSelect={mockOnCommandSelect}
          searchFilter="k"
        />,
      )
      expect(screen.getByTestId('command-item-@kb')).toBeInTheDocument()
      expect(screen.queryByTestId('command-item-@app')).not.toBeInTheDocument()
    })
  })
})

View File

@@ -1,236 +0,0 @@
import type { Mock } from 'vitest'
import type { ActionItem } from '../../app/components/goto-anything/actions/types'
// Import after mocking to get mocked version
import { matchAction } from '../../app/components/goto-anything/actions'
import { slashCommandRegistry } from '../../app/components/goto-anything/actions/commands/registry'
// Mock the entire actions module to avoid import issues
vi.mock('../../app/components/goto-anything/actions', () => ({
  matchAction: vi.fn(),
}))
vi.mock('../../app/components/goto-anything/actions/commands/registry')
// Reference implementation of matchAction, installed into the mock so the
// suite exercises the real matching rules.
const actualMatchAction = (query: string, actions: Record<string, ActionItem>) => {
  for (const action of Object.values(actions)) {
    // Slash actions match only when the query spells out a registered
    // submenu-mode command; direct-mode commands stay in the selector.
    if (action.key === '/') {
      const registered = slashCommandRegistry.getAllCommands()
      const slashMatches = registered.some((cmd) => {
        if (cmd.mode === 'direct')
          return false
        const pattern = `/${cmd.name}`
        return query === pattern || query.startsWith(`${pattern} `)
      })
      if (slashMatches)
        return action
      continue
    }
    // '@' actions match on key or shortcut, followed by whitespace or end.
    const pattern = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
    if (pattern.test(query))
      return action
  }
  return undefined
}
// Route calls through the mock to the reference implementation above.
;(matchAction as Mock).mockImplementation(actualMatchAction)
describe('matchAction Logic', () => {
  // '@a' is a distinct shortcut from the '@app' key, so both paths get covered.
  const mockActions: Record<string, ActionItem> = {
    app: {
      key: '@app',
      shortcut: '@a',
      title: 'Search Applications',
      description: 'Search apps',
      search: vi.fn(),
    },
    knowledge: {
      key: '@knowledge',
      shortcut: '@kb',
      title: 'Search Knowledge',
      description: 'Search knowledge bases',
      search: vi.fn(),
    },
    slash: {
      key: '/',
      shortcut: '/',
      title: 'Commands',
      description: 'Execute commands',
      search: vi.fn(),
    },
  }
  beforeEach(() => {
    vi.clearAllMocks()
    // Default registry: four direct-mode commands and two submenu-mode ones.
    ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
      { name: 'docs', mode: 'direct' },
      { name: 'community', mode: 'direct' },
      { name: 'feedback', mode: 'direct' },
      { name: 'account', mode: 'direct' },
      { name: 'theme', mode: 'submenu' },
      { name: 'language', mode: 'submenu' },
    ])
  })

  describe('@ Actions Matching', () => {
    it('should match @app with key', () => {
      const result = matchAction('@app', mockActions)
      expect(result).toBe(mockActions.app)
    })
    it('should match @app with shortcut', () => {
      const result = matchAction('@a', mockActions)
      expect(result).toBe(mockActions.app)
    })
    it('should match @knowledge with key', () => {
      const result = matchAction('@knowledge', mockActions)
      expect(result).toBe(mockActions.knowledge)
    })
    it('should match @knowledge with shortcut @kb', () => {
      const result = matchAction('@kb', mockActions)
      expect(result).toBe(mockActions.knowledge)
    })
    it('should match with text after action', () => {
      const result = matchAction('@app search term', mockActions)
      expect(result).toBe(mockActions.app)
    })
    it('should not match partial @ actions', () => {
      const result = matchAction('@ap', mockActions)
      expect(result).toBeUndefined()
    })
  })

  describe('Slash Commands Matching', () => {
    // Direct-mode commands execute immediately and must never claim the query.
    describe('Direct Mode Commands', () => {
      it('should not match direct mode commands', () => {
        const result = matchAction('/docs', mockActions)
        expect(result).toBeUndefined()
      })
      it('should not match direct mode with arguments', () => {
        const result = matchAction('/docs something', mockActions)
        expect(result).toBeUndefined()
      })
      it('should not match any direct mode command', () => {
        expect(matchAction('/community', mockActions)).toBeUndefined()
        expect(matchAction('/feedback', mockActions)).toBeUndefined()
        expect(matchAction('/account', mockActions)).toBeUndefined()
      })
    })
    // Submenu-mode commands take over once the full command is typed.
    describe('Submenu Mode Commands', () => {
      it('should match submenu mode commands exactly', () => {
        const result = matchAction('/theme', mockActions)
        expect(result).toBe(mockActions.slash)
      })
      it('should match submenu mode with arguments', () => {
        const result = matchAction('/theme dark', mockActions)
        expect(result).toBe(mockActions.slash)
      })
      it('should match all submenu commands', () => {
        expect(matchAction('/language', mockActions)).toBe(mockActions.slash)
        expect(matchAction('/language en', mockActions)).toBe(mockActions.slash)
      })
    })
    describe('Slash Without Command', () => {
      it('should not match single slash', () => {
        const result = matchAction('/', mockActions)
        expect(result).toBeUndefined()
      })
      it('should not match unregistered commands', () => {
        const result = matchAction('/unknown', mockActions)
        expect(result).toBeUndefined()
      })
    })
  })

  describe('Edge Cases', () => {
    it('should handle empty query', () => {
      const result = matchAction('', mockActions)
      expect(result).toBeUndefined()
    })
    it('should handle whitespace only', () => {
      const result = matchAction(' ', mockActions)
      expect(result).toBeUndefined()
    })
    it('should handle regular text without actions', () => {
      const result = matchAction('search something', mockActions)
      expect(result).toBeUndefined()
    })
    it('should handle special characters', () => {
      const result = matchAction('#tag', mockActions)
      expect(result).toBeUndefined()
    })
    it('should handle multiple @ or /', () => {
      expect(matchAction('@@app', mockActions)).toBeUndefined()
      expect(matchAction('//theme', mockActions)).toBeUndefined()
    })
  })

  describe('Mode-based Filtering', () => {
    it('should filter direct mode commands from matching', () => {
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test', mode: 'direct' },
      ])
      const result = matchAction('/test', mockActions)
      expect(result).toBeUndefined()
    })
    it('should allow submenu mode commands to match', () => {
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test', mode: 'submenu' },
      ])
      const result = matchAction('/test', mockActions)
      expect(result).toBe(mockActions.slash)
    })
    it('should treat undefined mode as submenu', () => {
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([
        { name: 'test' }, // No mode specified
      ])
      const result = matchAction('/test', mockActions)
      expect(result).toBe(mockActions.slash)
    })
  })

  // The registry must only be consulted for slash queries.
  describe('Registry Integration', () => {
    it('should call getAllCommands when matching slash', () => {
      matchAction('/theme', mockActions)
      expect(slashCommandRegistry.getAllCommands).toHaveBeenCalled()
    })
    it('should not call getAllCommands for @ actions', () => {
      matchAction('@app', mockActions)
      expect(slashCommandRegistry.getAllCommands).not.toHaveBeenCalled()
    })
    it('should handle empty command list', () => {
      ;(slashCommandRegistry.getAllCommands as Mock).mockReturnValue([])
      const result = matchAction('/anything', mockActions)
      expect(result).toBeUndefined()
    })
  })
})

View File

@@ -9,7 +9,7 @@ import type { MockedFunction } from 'vitest'
* 4. Ensure errors don't propagate to UI layer causing "search failed"
*/
import { Actions, searchAnything } from '@/app/components/goto-anything/actions'
import { appScope, knowledgeScope, pluginScope, searchAnything } from '@/app/components/goto-anything/actions'
import { fetchAppList } from '@/service/apps'
import { postMarketplace } from '@/service/base'
import { fetchDatasets } from '@/service/datasets'
@@ -30,6 +30,7 @@ vi.mock('@/service/datasets', () => ({
const mockPostMarketplace = postMarketplace as MockedFunction<typeof postMarketplace>
const mockFetchAppList = fetchAppList as MockedFunction<typeof fetchAppList>
const mockFetchDatasets = fetchDatasets as MockedFunction<typeof fetchDatasets>
const searchScopes = [appScope, knowledgeScope, pluginScope]
describe('GotoAnything Search Error Handling', () => {
beforeEach(() => {
@@ -49,10 +50,7 @@ describe('GotoAnything Search Error Handling', () => {
// Mock marketplace API failure (403 permission denied)
mockPostMarketplace.mockRejectedValue(new Error('HTTP 403: Forbidden'))
const pluginAction = Actions.plugin
// Directly call plugin action's search method
const result = await pluginAction.search('@plugin', 'test', 'en')
const result = await pluginScope.search('@plugin', 'test', 'en')
// Should return empty array instead of throwing error
expect(result).toEqual([])
@@ -72,8 +70,7 @@ describe('GotoAnything Search Error Handling', () => {
data: { plugins: [] },
})
const pluginAction = Actions.plugin
const result = await pluginAction.search('@plugin', '', 'en')
const result = await pluginScope.search('@plugin', '', 'en')
expect(result).toEqual([])
})
@@ -84,8 +81,7 @@ describe('GotoAnything Search Error Handling', () => {
data: null,
})
const pluginAction = Actions.plugin
const result = await pluginAction.search('@plugin', 'test', 'en')
const result = await pluginScope.search('@plugin', 'test', 'en')
expect(result).toEqual([])
})
@@ -96,8 +92,7 @@ describe('GotoAnything Search Error Handling', () => {
// Mock app API failure
mockFetchAppList.mockRejectedValue(new Error('API Error'))
const appAction = Actions.app
const result = await appAction.search('@app', 'test', 'en')
const result = await appScope.search('@app', 'test', 'en')
expect(result).toEqual([])
})
@@ -106,8 +101,7 @@ describe('GotoAnything Search Error Handling', () => {
// Mock knowledge API failure
mockFetchDatasets.mockRejectedValue(new Error('API Error'))
const knowledgeAction = Actions.knowledge
const result = await knowledgeAction.search('@knowledge', 'test', 'en')
const result = await knowledgeScope.search('@knowledge', 'test', 'en')
expect(result).toEqual([])
})
@@ -120,7 +114,7 @@ describe('GotoAnything Search Error Handling', () => {
mockFetchDatasets.mockResolvedValue({ data: [], has_more: false, limit: 10, page: 1, total: 0 })
mockPostMarketplace.mockRejectedValue(new Error('Plugin API failed'))
const result = await searchAnything('en', 'test')
const result = await searchAnything('en', 'test', undefined, searchScopes)
// Should return successful results even if plugin search fails
expect(result).toEqual([])
@@ -131,8 +125,7 @@ describe('GotoAnything Search Error Handling', () => {
// Mock plugin API failure
mockPostMarketplace.mockRejectedValue(new Error('Plugin service unavailable'))
const pluginAction = Actions.plugin
const result = await searchAnything('en', '@plugin test', pluginAction)
const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
// Should return empty array instead of throwing error
expect(result).toEqual([])
@@ -142,8 +135,7 @@ describe('GotoAnything Search Error Handling', () => {
// Mock app API failure
mockFetchAppList.mockRejectedValue(new Error('App service unavailable'))
const appAction = Actions.app
const result = await searchAnything('en', '@app test', appAction)
const result = await searchAnything('en', '@app test', appScope, searchScopes)
expect(result).toEqual([])
})
@@ -157,9 +149,9 @@ describe('GotoAnything Search Error Handling', () => {
mockFetchDatasets.mockRejectedValue(new Error('Dataset API failed'))
const actions = [
{ name: '@plugin', action: Actions.plugin },
{ name: '@app', action: Actions.app },
{ name: '@knowledge', action: Actions.knowledge },
{ name: '@plugin', action: pluginScope },
{ name: '@app', action: appScope },
{ name: '@knowledge', action: knowledgeScope },
]
for (const { name, action } of actions) {
@@ -173,7 +165,7 @@ describe('GotoAnything Search Error Handling', () => {
it('empty search term should be handled properly', async () => {
mockPostMarketplace.mockResolvedValue({ data: { plugins: [] } })
const result = await searchAnything('en', '@plugin ', Actions.plugin)
const result = await searchAnything('en', '@plugin ', pluginScope, searchScopes)
expect(result).toEqual([])
})
@@ -183,7 +175,7 @@ describe('GotoAnything Search Error Handling', () => {
mockPostMarketplace.mockRejectedValue(timeoutError)
const result = await searchAnything('en', '@plugin test', Actions.plugin)
const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
expect(result).toEqual([])
})
@@ -191,7 +183,7 @@ describe('GotoAnything Search Error Handling', () => {
const parseError = new SyntaxError('Unexpected token in JSON')
mockPostMarketplace.mockRejectedValue(parseError)
const result = await searchAnything('en', '@plugin test', Actions.plugin)
const result = await searchAnything('en', '@plugin test', pluginScope, searchScopes)
expect(result).toEqual([])
})
})

View File

@@ -10,9 +10,15 @@ type VersionSelectorProps = {
versionLen: number
value: number
onChange: (index: number) => void
contentClassName?: string
}
const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, onChange }) => {
const VersionSelector: React.FC<VersionSelectorProps> = ({
versionLen,
value,
onChange,
contentClassName,
}) => {
const { t } = useTranslation()
const [isOpen, {
setFalse: handleOpenFalse,
@@ -64,6 +70,7 @@ const VersionSelector: React.FC<VersionSelectorProps> = ({ versionLen, value, on
</PortalToFollowElemTrigger>
<PortalToFollowElemContent className={cn(
'z-[99]',
contentClassName,
)}
>
<div

View File

@@ -1,9 +1,10 @@
import type { ActionItem, AppSearchResult } from './types'
import type { AppSearchResult, ScopeDescriptor } from './types'
import type { App } from '@/types/app'
import { fetchAppList } from '@/service/apps'
import { getRedirectionPath } from '@/utils/app-redirection'
import { AppTypeIcon } from '../../app/type-selector'
import AppIcon from '../../base/app-icon'
import { ACTION_KEYS } from '../constants'
const parser = (apps: App[]): AppSearchResult[] => {
return apps.map(app => ({
@@ -35,9 +36,9 @@ const parser = (apps: App[]): AppSearchResult[] => {
}))
}
export const appAction: ActionItem = {
key: '@app',
shortcut: '@app',
export const appScope: ScopeDescriptor = {
id: 'app',
shortcut: ACTION_KEYS.APP,
title: 'Search Applications',
description: 'Search and navigate to your applications',
// action,

View File

@@ -0,0 +1,189 @@
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import i18n from '@/i18n-config/i18next-config'
import { bananaCommand } from './banana'
import { registerCommands, unregisterCommands } from './command-bus'
// Stub i18n so assertions can inspect the raw key plus serialized options
// instead of a translated string (e.g. 'key:{"ns":"app","lng":"en"}').
vi.mock('@/i18n-config/i18next-config', () => ({
  default: {
    t: vi.fn((key: string, options?: Record<string, unknown>) => {
      if (!options)
        return key
      return `${key}:${JSON.stringify(options)}`
    }),
  },
}))

// Keep the real constants module but make the workflow-page check controllable per test.
vi.mock('@/app/components/workflow/constants', async () => {
  const actual = await vi.importActual<typeof import('@/app/components/workflow/constants')>(
    '@/app/components/workflow/constants',
  )
  return {
    ...actual,
    isInWorkflowPage: vi.fn(),
  }
})

// Spy on command-bus registration so tests can capture the registered handlers.
vi.mock('./command-bus', () => ({
  registerCommands: vi.fn(),
  unregisterCommands: vi.fn(),
}))

const mockedIsInWorkflowPage = vi.mocked(isInWorkflowPage)
const mockedRegisterCommands = vi.mocked(registerCommands)
const mockedUnregisterCommands = vi.mocked(unregisterCommands)
const mockedT = vi.mocked(i18n.t)

// Shape of the args object accepted by the registered 'workflow.vibe' handler.
type CommandArgs = { dsl?: string }
type CommandMap = Record<string, (args?: CommandArgs) => void | Promise<void>>

// Reset call history (not implementations) between tests.
beforeEach(() => {
  vi.clearAllMocks()
})
// Command availability, search, and registration behavior for banana command.
describe('bananaCommand', () => {
// Command metadata mirrors the static definition.
describe('metadata', () => {
it('should expose name, mode, and description', () => {
// Assert
expect(bananaCommand.name).toBe('banana')
expect(bananaCommand.mode).toBe('submenu')
expect(bananaCommand.description).toContain('gotoAnything.actions.vibeDesc')
})
})
// Availability mirrors workflow page detection.
describe('availability', () => {
it('should return true when on workflow page', () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const available = bananaCommand.isAvailable?.()
// Assert
expect(available).toBe(true)
expect(mockedIsInWorkflowPage).toHaveBeenCalledTimes(1)
})
it('should return false when not on workflow page', () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(false)
// Act
const available = bananaCommand.isAvailable?.()
// Assert
expect(available).toBe(false)
expect(mockedIsInWorkflowPage).toHaveBeenCalledTimes(1)
})
})
// Search results depend on provided arguments.
describe('search', () => {
it('should return hint description when args are empty', async () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const result = await bananaCommand.search(' ')
// Assert
expect(result).toHaveLength(1)
const [item] = result
expect(item.description).toContain('gotoAnything.actions.vibeHint')
expect(item.data?.args?.dsl).toBe('')
expect(item.data?.command).toBe('workflow.vibe')
expect(mockedT).toHaveBeenCalledWith(
'gotoAnything.actions.vibeTitle',
expect.objectContaining({ lng: 'en', ns: 'app' }),
)
expect(mockedT).toHaveBeenCalledWith(
'gotoAnything.actions.vibeHint',
expect.objectContaining({ prompt: expect.any(String), lng: 'en', ns: 'app' }),
)
})
it('should return default description when args are provided', async () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const result = await bananaCommand.search(' make a flow ', 'fr')
// Assert
expect(result).toHaveLength(1)
const [item] = result
expect(item.description).toContain('gotoAnything.actions.vibeDesc')
expect(item.data?.args?.dsl).toBe('make a flow')
expect(item.data?.command).toBe('workflow.vibe')
expect(mockedT).toHaveBeenCalledWith(
'gotoAnything.actions.vibeTitle',
expect.objectContaining({ lng: 'fr', ns: 'app' }),
)
expect(mockedT).toHaveBeenCalledWith(
'gotoAnything.actions.vibeDesc',
expect.objectContaining({ lng: 'fr', ns: 'app' }),
)
})
it('should fall back to Banana when title translation is empty', async () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
mockedT.mockImplementationOnce(() => '')
// Act
const result = await bananaCommand.search('make a plan')
// Assert
expect(result).toHaveLength(1)
expect(result[0]?.title).toBe('Banana')
})
})
// Command registration and event dispatching.
describe('registration', () => {
it('should register the workflow vibe command', () => {
// Act
expect(bananaCommand.register).toBeDefined()
bananaCommand.register?.({})
// Assert
expect(mockedRegisterCommands).toHaveBeenCalledTimes(1)
const commands = mockedRegisterCommands.mock.calls[0]?.[0] as CommandMap
expect(commands['workflow.vibe']).toEqual(expect.any(Function))
})
it('should dispatch vibe event when command handler runs', async () => {
// Arrange
const dispatchSpy = vi.spyOn(document, 'dispatchEvent')
expect(bananaCommand.register).toBeDefined()
bananaCommand.register?.({})
expect(mockedRegisterCommands).toHaveBeenCalledTimes(1)
const commands = mockedRegisterCommands.mock.calls[0]?.[0] as CommandMap
try {
// Act
await commands['workflow.vibe']?.({ dsl: 'hello' })
// Assert
expect(dispatchSpy).toHaveBeenCalledTimes(1)
const event = dispatchSpy.mock.calls[0][0] as CustomEvent
expect(event.type).toBe(VIBE_COMMAND_EVENT)
expect(event.detail).toEqual({ dsl: 'hello' })
}
finally {
dispatchSpy.mockRestore()
}
})
it('should unregister workflow vibe command', () => {
// Act
expect(bananaCommand.unregister).toBeDefined()
bananaCommand.unregister?.()
// Assert
expect(mockedUnregisterCommands).toHaveBeenCalledWith(['workflow.vibe'])
})
})
})

View File

@@ -0,0 +1,59 @@
import type { SlashCommandHandler } from './types'
import { RiSparklingFill } from '@remixicon/react'
import * as React from 'react'
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import i18n from '@/i18n-config/i18next-config'
import { registerCommands, unregisterCommands } from './command-bus'
type BananaDeps = Record<string, never>
const BANANA_PROMPT_EXAMPLE = 'Summarize a document, classify sentiment, then notify Slack'
// Broadcast the vibe input to the workflow page as a DOM CustomEvent.
// Safe no-op outside the browser (SSR), where `document` is undefined.
const dispatchVibeCommand = (input?: string) => {
  if (typeof document !== 'undefined') {
    const vibeEvent = new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: input } })
    document.dispatchEvent(vibeEvent)
  }
}
/**
 * Submenu slash command (`/banana`) that forwards free-form input to the
 * workflow vibe generator via the `workflow.vibe` command-bus entry.
 * Only offered on workflow pages (guarded by `isInWorkflowPage`).
 */
export const bananaCommand: SlashCommandHandler<BananaDeps> = {
  name: 'banana',
  // NOTE(review): evaluated once at module load, so it uses the locale active
  // at import time; per-result descriptions below re-translate with `lng`.
  description: i18n.t('gotoAnything.actions.vibeDesc', { ns: 'app' }),
  mode: 'submenu',
  isAvailable: () => isInWorkflowPage(),
  // Always returns exactly one result; the description switches between a
  // usage hint (empty input) and the standard description (non-empty input).
  async search(args: string, locale: string = 'en') {
    const trimmed = args.trim()
    const hasInput = !!trimmed
    return [{
      id: 'banana-vibe',
      // Fall back to a literal label when the translation resolves to ''.
      title: i18n.t('gotoAnything.actions.vibeTitle', { ns: 'app', lng: locale }) || 'Banana',
      description: hasInput
        ? i18n.t('gotoAnything.actions.vibeDesc', { ns: 'app', lng: locale })
        : i18n.t('gotoAnything.actions.vibeHint', { ns: 'app', lng: locale, prompt: BANANA_PROMPT_EXAMPLE }),
      type: 'command' as const,
      icon: (
        <div className="flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg">
          <RiSparklingFill className="h-4 w-4 text-text-tertiary" />
        </div>
      ),
      // Executing this result triggers `workflow.vibe` with the trimmed input.
      data: {
        command: 'workflow.vibe',
        args: { dsl: trimmed },
      },
    }]
  },
  // Wire the `workflow.vibe` handler into the global command bus; no external deps.
  register(_deps: BananaDeps) {
    registerCommands({
      'workflow.vibe': async (args) => {
        dispatchVibeCommand(args?.dsl)
      },
    })
  },
  unregister() {
    unregisterCommands(['workflow.vibe'])
  },
}

View File

@@ -9,7 +9,7 @@ export {
export { slashCommandRegistry, SlashCommandRegistry } from './registry'
// Command system exports
export { slashAction } from './slash'
export { slashScope } from './slash'
export { registerSlashCommands, SlashCommandProvider, unregisterSlashCommands } from './slash'
export type { SlashCommandHandler } from './types'

View File

@@ -1,12 +1,13 @@
import type { CommandSearchResult } from '../types'
import type { SlashCommandHandler } from './types'
import type { Locale } from '@/i18n-config/language'
import i18n from '@/i18n-config/i18next-config'
import { languages } from '@/i18n-config/language'
import { registerCommands, unregisterCommands } from './command-bus'
// Language dependency types
type LanguageDeps = {
setLocale?: (locale: string) => Promise<void>
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}
const buildLanguageCommands = (query: string): CommandSearchResult[] => {

View File

@@ -6,20 +6,21 @@ import type { SlashCommandHandler } from './types'
* Responsible for managing registration, lookup, and search of all slash commands
*/
export class SlashCommandRegistry {
private commands = new Map<string, SlashCommandHandler>()
private commandDeps = new Map<string, any>()
private commands = new Map<string, SlashCommandHandler<unknown>>()
private commandDeps = new Map<string, unknown>()
/**
* Register command handler
*/
register<TDeps = any>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
register<TDeps = unknown>(handler: SlashCommandHandler<TDeps>, deps?: TDeps) {
// Register main command name
this.commands.set(handler.name, handler)
// Cast to unknown first, then to SlashCommandHandler<unknown> to handle generic type variance
this.commands.set(handler.name, handler as SlashCommandHandler<unknown>)
// Register aliases
if (handler.aliases) {
handler.aliases.forEach((alias) => {
this.commands.set(alias, handler)
this.commands.set(alias, handler as SlashCommandHandler<unknown>)
})
}
@@ -57,7 +58,7 @@ export class SlashCommandRegistry {
/**
* Find command handler
*/
findCommand(commandName: string): SlashCommandHandler | undefined {
findCommand(commandName: string): SlashCommandHandler<unknown> | undefined {
return this.commands.get(commandName)
}
@@ -65,7 +66,7 @@ export class SlashCommandRegistry {
* Smart partial command matching
* Prioritize alias matching, then match command name prefix
*/
private findBestPartialMatch(partialName: string): SlashCommandHandler | undefined {
private findBestPartialMatch(partialName: string): SlashCommandHandler<unknown> | undefined {
const lowerPartial = partialName.toLowerCase()
// First check if any alias starts with this
@@ -81,7 +82,7 @@ export class SlashCommandRegistry {
/**
* Find handler by alias prefix
*/
private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler | undefined {
private findHandlerByAliasPrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
for (const handler of this.getAllCommands()) {
if (handler.aliases?.some(alias => alias.toLowerCase().startsWith(prefix)))
return handler
@@ -92,7 +93,7 @@ export class SlashCommandRegistry {
/**
* Find handler by name prefix
*/
private findHandlerByNamePrefix(prefix: string): SlashCommandHandler | undefined {
private findHandlerByNamePrefix(prefix: string): SlashCommandHandler<unknown> | undefined {
return this.getAllCommands().find(handler =>
handler.name.toLowerCase().startsWith(prefix),
)
@@ -101,8 +102,8 @@ export class SlashCommandRegistry {
/**
* Get all registered commands (deduplicated)
*/
getAllCommands(): SlashCommandHandler[] {
const uniqueCommands = new Map<string, SlashCommandHandler>()
getAllCommands(): SlashCommandHandler<unknown>[] {
const uniqueCommands = new Map<string, SlashCommandHandler<unknown>>()
this.commands.forEach((handler) => {
uniqueCommands.set(handler.name, handler)
})
@@ -113,7 +114,7 @@ export class SlashCommandRegistry {
* Get all available commands in current context (deduplicated and filtered)
* Commands without isAvailable method are considered always available
*/
getAvailableCommands(): SlashCommandHandler[] {
getAvailableCommands(): SlashCommandHandler<unknown>[] {
return this.getAllCommands().filter(handler => this.isCommandAvailable(handler))
}
@@ -228,7 +229,7 @@ export class SlashCommandRegistry {
/**
* Get command dependencies
*/
getCommandDependencies(commandName: string): any {
getCommandDependencies(commandName: string): unknown {
return this.commandDeps.get(commandName)
}
@@ -236,7 +237,7 @@ export class SlashCommandRegistry {
* Determine if a command is available in the current context.
* Defaults to true when a handler does not implement the guard.
*/
private isCommandAvailable(handler: SlashCommandHandler) {
private isCommandAvailable(handler: SlashCommandHandler<unknown>) {
return handler.isAvailable?.() ?? true
}
}

View File

@@ -1,11 +1,13 @@
'use client'
import type { ActionItem } from '../types'
import type { ScopeDescriptor } from '../types'
import type { SlashCommandDependencies } from './types'
import { useTheme } from 'next-themes'
import { useEffect } from 'react'
import { setLocaleOnClient } from '@/i18n-config'
import i18n from '@/i18n-config/i18next-config'
import { ACTION_KEYS } from '../../constants'
import { accountCommand } from './account'
import { executeCommand } from './command-bus'
import { bananaCommand } from './banana'
import { communityCommand } from './community'
import { docsCommand } from './docs'
import { forumCommand } from './forum'
@@ -14,17 +16,11 @@ import { slashCommandRegistry } from './registry'
import { themeCommand } from './theme'
import { zenCommand } from './zen'
export const slashAction: ActionItem = {
key: '/',
shortcut: '/',
export const slashScope: ScopeDescriptor = {
id: 'slash',
shortcut: ACTION_KEYS.SLASH,
title: i18n.t('gotoAnything.actions.slashTitle', { ns: 'app' }),
description: i18n.t('gotoAnything.actions.slashDesc', { ns: 'app' }),
action: (result) => {
if (result.type !== 'command')
return
const { command, args } = result.data
executeCommand(command, args)
},
search: async (query, _searchTerm = '') => {
// Delegate all search logic to the command registry system
return slashCommandRegistry.search(query, i18n.language)
@@ -32,7 +28,7 @@ export const slashAction: ActionItem = {
}
// Register/unregister default handlers for slash commands with external dependencies.
export const registerSlashCommands = (deps: Record<string, any>) => {
export const registerSlashCommands = (deps: SlashCommandDependencies) => {
// Register command handlers to the registry system with their respective dependencies
slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })
@@ -41,6 +37,7 @@ export const registerSlashCommands = (deps: Record<string, any>) => {
slashCommandRegistry.register(communityCommand, {})
slashCommandRegistry.register(accountCommand, {})
slashCommandRegistry.register(zenCommand, {})
slashCommandRegistry.register(bananaCommand, {})
}
export const unregisterSlashCommands = () => {
@@ -52,6 +49,7 @@ export const unregisterSlashCommands = () => {
slashCommandRegistry.unregister('community')
slashCommandRegistry.unregister('account')
slashCommandRegistry.unregister('zen')
slashCommandRegistry.unregister('banana')
}
export const SlashCommandProvider = () => {

View File

@@ -1,10 +1,11 @@
import type { CommandSearchResult } from '../types'
import type { Locale } from '@/i18n-config/language'
/**
* Slash command handler interface
* Each slash command should implement this interface
*/
export type SlashCommandHandler<TDeps = any> = {
export type SlashCommandHandler<TDeps = unknown> = {
/** Command name (e.g., 'theme', 'language') */
name: string
@@ -51,3 +52,31 @@ export type SlashCommandHandler<TDeps = any> = {
*/
unregister?: () => void
}
/**
* Theme command dependencies
*/
export type ThemeCommandDeps = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
}
/**
* Language command dependencies
*/
export type LanguageCommandDeps = {
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}
/**
* Commands without external dependencies
*/
export type NoDepsCommandDeps = Record<string, never>
/**
* Union type of all slash command dependencies
* Used for type-safe dependency injection in registerSlashCommands
*/
export type SlashCommandDependencies = {
setTheme?: (value: 'light' | 'dark' | 'system') => void
setLocale?: (locale: Locale, reloadPage?: boolean) => Promise<void>
}

View File

@@ -0,0 +1,150 @@
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import i18n from '@/i18n-config/i18next-config'
import { registerCommands, unregisterCommands } from './command-bus'
import { vibeCommand } from './vibe'
// Stub i18n so assertions can inspect the raw key plus serialized options
// instead of a translated string.
vi.mock('@/i18n-config/i18next-config', () => ({
  default: {
    t: vi.fn((key: string, options?: Record<string, unknown>) => {
      if (!options)
        return key
      return `${key}:${JSON.stringify(options)}`
    }),
  },
}))

// Keep the real constants module but make the workflow-page check controllable per test.
vi.mock('@/app/components/workflow/constants', async () => {
  const actual = await vi.importActual<typeof import('@/app/components/workflow/constants')>(
    '@/app/components/workflow/constants',
  )
  return {
    ...actual,
    isInWorkflowPage: vi.fn(),
  }
})

// Spy on command-bus registration so tests can capture the registered handlers.
vi.mock('./command-bus', () => ({
  registerCommands: vi.fn(),
  unregisterCommands: vi.fn(),
}))

const mockedIsInWorkflowPage = vi.mocked(isInWorkflowPage)
const mockedRegisterCommands = vi.mocked(registerCommands)
const mockedUnregisterCommands = vi.mocked(unregisterCommands)
const mockedT = vi.mocked(i18n.t)

// Shape of the args object accepted by the registered 'workflow.vibe' handler.
type CommandArgs = { dsl?: string }
type CommandMap = Record<string, (args?: CommandArgs) => void | Promise<void>>

// Reset call history (not implementations) between tests.
beforeEach(() => {
  vi.clearAllMocks()
})
// Command availability, search, and registration behavior for workflow vibe.
describe('vibeCommand', () => {
// Availability mirrors workflow page detection.
describe('availability', () => {
it('should return true when on workflow page', () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const available = vibeCommand.isAvailable?.()
// Assert
expect(available).toBe(true)
})
it('should return false when not on workflow page', () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(false)
// Act
const available = vibeCommand.isAvailable?.()
// Assert
expect(available).toBe(false)
})
})
// Search results depend on provided arguments.
describe('search', () => {
it('should return hint description when args are empty', async () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const result = await vibeCommand.search(' ', 'en')
// Assert
expect(result).toHaveLength(1)
const [item] = result
expect(item.description).toContain('gotoAnything.actions.vibeHint')
expect(item.data?.args?.dsl).toBe('')
expect(mockedT).toHaveBeenCalledWith(
'gotoAnything.actions.vibeHint',
expect.objectContaining({ prompt: expect.any(String), lng: 'en', ns: 'app' }),
)
})
it('should return default description when args are provided', async () => {
// Arrange
mockedIsInWorkflowPage.mockReturnValue(true)
// Act
const result = await vibeCommand.search(' make a flow ', 'en')
// Assert
expect(result).toHaveLength(1)
const [item] = result
expect(item.description).toContain('gotoAnything.actions.vibeDesc')
expect(item.data?.args?.dsl).toBe('make a flow')
})
})
// Command registration and event dispatching.
describe('registration', () => {
it('should register the workflow vibe command', () => {
// Act
expect(vibeCommand.register).toBeDefined()
vibeCommand.register?.({})
// Assert
expect(mockedRegisterCommands).toHaveBeenCalledTimes(1)
const commands = mockedRegisterCommands.mock.calls[0]?.[0] as CommandMap
expect(commands['workflow.vibe']).toEqual(expect.any(Function))
})
it('should dispatch vibe event when command handler runs', async () => {
// Arrange
const dispatchSpy = vi.spyOn(document, 'dispatchEvent')
expect(vibeCommand.register).toBeDefined()
vibeCommand.register?.({})
expect(mockedRegisterCommands).toHaveBeenCalledTimes(1)
const commands = mockedRegisterCommands.mock.calls[0]?.[0] as CommandMap
try {
// Act
await commands['workflow.vibe']?.({ dsl: 'hello' })
// Assert
expect(dispatchSpy).toHaveBeenCalledTimes(1)
const event = dispatchSpy.mock.calls[0][0] as CustomEvent
expect(event.type).toBe(VIBE_COMMAND_EVENT)
expect(event.detail).toEqual({ dsl: 'hello' })
}
finally {
dispatchSpy.mockRestore()
}
})
it('should unregister workflow vibe command', () => {
// Act
expect(vibeCommand.unregister).toBeDefined()
vibeCommand.unregister?.()
// Assert
expect(mockedUnregisterCommands).toHaveBeenCalledWith(['workflow.vibe'])
})
})
})

View File

@@ -0,0 +1,59 @@
import type { SlashCommandHandler } from './types'
import { RiSparklingFill } from '@remixicon/react'
import * as React from 'react'
import { isInWorkflowPage, VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import i18n from '@/i18n-config/i18next-config'
import { registerCommands, unregisterCommands } from './command-bus'
type VibeDeps = Record<string, never>
const VIBE_PROMPT_EXAMPLE = 'Summarize a document, classify sentiment, then notify Slack'
// Broadcast the vibe input to the workflow page as a DOM CustomEvent.
// Safe no-op outside the browser (SSR), where `document` is undefined.
const dispatchVibeCommand = (input?: string) => {
  if (typeof document !== 'undefined') {
    const vibeEvent = new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: input } })
    document.dispatchEvent(vibeEvent)
  }
}
/**
 * Submenu slash command (`/vibe`) that forwards free-form input to the
 * workflow vibe generator via the `workflow.vibe` command-bus entry.
 * Only offered on workflow pages (guarded by `isInWorkflowPage`).
 */
export const vibeCommand: SlashCommandHandler<VibeDeps> = {
  name: 'vibe',
  // NOTE(review): evaluated once at module load, so it uses the locale active
  // at import time; per-result descriptions below re-translate with `lng`.
  description: i18n.t('gotoAnything.actions.vibeDesc', { ns: 'app' }),
  mode: 'submenu',
  isAvailable: () => isInWorkflowPage(),
  // Always returns exactly one result; the description switches between a
  // usage hint (empty input) and the standard description (non-empty input).
  async search(args: string, locale: string = 'en') {
    const trimmed = args.trim()
    const hasInput = !!trimmed
    return [{
      id: 'vibe',
      // Fall back to a literal label when the translation resolves to ''.
      title: i18n.t('gotoAnything.actions.vibeTitle', { ns: 'app', lng: locale }) || 'Vibe',
      description: hasInput
        ? i18n.t('gotoAnything.actions.vibeDesc', { ns: 'app', lng: locale })
        : i18n.t('gotoAnything.actions.vibeHint', { ns: 'app', lng: locale, prompt: VIBE_PROMPT_EXAMPLE }),
      type: 'command' as const,
      icon: (
        <div className="flex h-6 w-6 items-center justify-center rounded-md border-[0.5px] border-divider-regular bg-components-panel-bg">
          <RiSparklingFill className="h-4 w-4 text-text-tertiary" />
        </div>
      ),
      // Executing this result triggers `workflow.vibe` with the trimmed input.
      data: {
        command: 'workflow.vibe',
        args: { dsl: trimmed },
      },
    }]
  },
  // Wire the `workflow.vibe` handler into the global command bus; no external deps.
  register(_deps: VibeDeps) {
    registerCommands({
      'workflow.vibe': async (args) => {
        dispatchVibeCommand(args?.dsl)
      },
    })
  },
  unregister() {
    unregisterCommands(['workflow.vibe'])
  },
}

View File

@@ -3,228 +3,66 @@
*
* This file defines the action registry for the goto-anything search system.
* Actions handle different types of searches: apps, knowledge bases, plugins, workflow nodes, and commands.
*
* ## How to Add a New Slash Command
*
* 1. **Create Command Handler File** (in `./commands/` directory):
* ```typescript
* // commands/my-command.ts
* import type { SlashCommandHandler } from './types'
* import type { CommandSearchResult } from '../types'
* import { registerCommands, unregisterCommands } from './command-bus'
*
* interface MyCommandDeps {
* myService?: (data: any) => Promise<void>
* }
*
* export const myCommand: SlashCommandHandler<MyCommandDeps> = {
* name: 'mycommand',
* aliases: ['mc'], // Optional aliases
* description: 'My custom command description',
*
* async search(args: string, locale: string = 'en') {
* // Return search results based on args
* return [{
* id: 'my-result',
* title: 'My Command Result',
* description: 'Description of the result',
* type: 'command' as const,
* data: { command: 'my.action', args: { value: args } }
* }]
* },
*
* register(deps: MyCommandDeps) {
* registerCommands({
* 'my.action': async (args) => {
* await deps.myService?.(args?.value)
* }
* })
* },
*
* unregister() {
* unregisterCommands(['my.action'])
* }
* }
* ```
*
* **Example for Self-Contained Command (no external dependencies):**
* ```typescript
* // commands/calculator-command.ts
* export const calculatorCommand: SlashCommandHandler = {
* name: 'calc',
* aliases: ['calculator'],
* description: 'Simple calculator',
*
* async search(args: string) {
* if (!args.trim()) return []
* try {
* // Safe math evaluation (implement proper parser in real use)
* const result = Function('"use strict"; return (' + args + ')')()
* return [{
* id: 'calc-result',
* title: `${args} = ${result}`,
* description: 'Calculator result',
* type: 'command' as const,
* data: { command: 'calc.copy', args: { result: result.toString() } }
* }]
* } catch {
* return [{
* id: 'calc-error',
* title: 'Invalid expression',
* description: 'Please enter a valid math expression',
* type: 'command' as const,
* data: { command: 'calc.noop', args: {} }
* }]
* }
* },
*
* register() {
* registerCommands({
* 'calc.copy': (args) => navigator.clipboard.writeText(args.result),
* 'calc.noop': () => {} // No operation
* })
* },
*
* unregister() {
* unregisterCommands(['calc.copy', 'calc.noop'])
* }
* }
* ```
*
* 2. **Register Command** (in `./commands/slash.tsx`):
* ```typescript
* import { myCommand } from './my-command'
* import { calculatorCommand } from './calculator-command' // For self-contained commands
*
* export const registerSlashCommands = (deps: Record<string, any>) => {
* slashCommandRegistry.register(themeCommand, { setTheme: deps.setTheme })
* slashCommandRegistry.register(languageCommand, { setLocale: deps.setLocale })
* slashCommandRegistry.register(myCommand, { myService: deps.myService }) // With dependencies
* slashCommandRegistry.register(calculatorCommand) // Self-contained, no dependencies
* }
*
* export const unregisterSlashCommands = () => {
* slashCommandRegistry.unregister('theme')
* slashCommandRegistry.unregister('language')
* slashCommandRegistry.unregister('mycommand')
* slashCommandRegistry.unregister('calc') // Add this line
* }
* ```
*
*
* 3. **Update SlashCommandProvider** (in `./commands/slash.tsx`):
* ```typescript
* export const SlashCommandProvider = () => {
* const theme = useTheme()
* const myService = useMyService() // Add external dependency if needed
*
* useEffect(() => {
* registerSlashCommands({
* setTheme: theme.setTheme, // Required for theme command
* setLocale: setLocaleOnClient, // Required for language command
* myService: myService, // Required for your custom command
* // Note: calculatorCommand doesn't need dependencies, so not listed here
* })
* return () => unregisterSlashCommands()
* }, [theme.setTheme, myService]) // Update dependency array for all dynamic deps
*
* return null
* }
* ```
*
* **Note:** Self-contained commands (like calculator) don't require dependencies but are
* still registered through the same system for consistent lifecycle management.
*
* 4. **Usage**: Users can now type `/mycommand` or `/mc` to use your command
*
* ## Command System Architecture
* - Commands are registered via `SlashCommandRegistry`
* - Each command is self-contained with its own dependencies
* - Commands support aliases for easier access
* - Command execution is handled by the command bus system
* - All commands should be registered through `SlashCommandProvider` for consistent lifecycle management
*
* ## Command Types
* **Commands with External Dependencies:**
* - Require external services, APIs, or React hooks
* - Must provide dependencies in `SlashCommandProvider`
* - Example: theme commands (needs useTheme), API commands (needs service)
*
* **Self-Contained Commands:**
* - Pure logic operations, no external dependencies
* - Still recommended to register through `SlashCommandProvider` for consistency
* - Example: calculator, text manipulation commands
*
* ## Available Actions
* - `@app` - Search applications
* - `@knowledge` / `@kb` - Search knowledge bases
* - `@plugin` - Search plugins
* - `@node` - Search workflow nodes (workflow pages only)
* - `/` - Execute slash commands (theme, language, etc.)
*/
import type { ActionItem, SearchResult } from './types'
import { appAction } from './app'
import { slashAction } from './commands'
import type { ScopeContext, ScopeDescriptor, SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { appScope } from './app'
import { slashScope } from './commands'
import { slashCommandRegistry } from './commands/registry'
import { knowledgeAction } from './knowledge'
import { pluginAction } from './plugin'
import { ragPipelineNodesAction } from './rag-pipeline-nodes'
import { workflowNodesAction } from './workflow-nodes'
import { knowledgeScope } from './knowledge'
import { pluginScope } from './plugin'
import { registerRagPipelineNodeScope } from './rag-pipeline-nodes'
import { scopeRegistry, useScopeRegistry } from './scope-registry'
import { registerWorkflowNodeScope } from './workflow-nodes'
// Create dynamic Actions based on context
export const createActions = (isWorkflowPage: boolean, isRagPipelinePage: boolean) => {
const baseActions = {
slash: slashAction,
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
}
let scopesInitialized = false
// Add appropriate node search based on context
if (isRagPipelinePage) {
return {
...baseActions,
node: ragPipelineNodesAction,
}
}
else if (isWorkflowPage) {
return {
...baseActions,
node: workflowNodesAction,
}
}
export const initGotoAnythingScopes = () => {
if (scopesInitialized)
return
// Default actions without node search
return baseActions
scopesInitialized = true
scopeRegistry.register(slashScope)
scopeRegistry.register(appScope)
scopeRegistry.register(knowledgeScope)
scopeRegistry.register(pluginScope)
registerWorkflowNodeScope()
registerRagPipelineNodeScope()
}
// Legacy export for backward compatibility
export const Actions = {
slash: slashAction,
app: appAction,
knowledge: knowledgeAction,
plugin: pluginAction,
node: workflowNodesAction,
export const useGotoAnythingScopes = (context: ScopeContext) => {
initGotoAnythingScopes()
return useScopeRegistry(context)
}
const isSlashScope = (scope: ScopeDescriptor) => {
if (scope.shortcut === ACTION_KEYS.SLASH)
return true
return scope.aliases?.includes(ACTION_KEYS.SLASH) ?? false
}
const getScopeShortcuts = (scope: ScopeDescriptor) => [scope.shortcut, ...(scope.aliases ?? [])]
export const searchAnything = async (
locale: string,
query: string,
actionItem?: ActionItem,
dynamicActions?: Record<string, ActionItem>,
scope: ScopeDescriptor | undefined,
scopes: ScopeDescriptor[],
): Promise<SearchResult[]> => {
const trimmedQuery = query.trim()
if (actionItem) {
if (scope) {
const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
const prefixPattern = new RegExp(`^(${escapeRegExp(actionItem.key)}|${escapeRegExp(actionItem.shortcut)})\\s*`)
const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
const prefixPattern = new RegExp(`^(${shortcuts.join('|')})\\s*`)
const searchTerm = trimmedQuery.replace(prefixPattern, '').trim()
try {
return await actionItem.search(query, searchTerm, locale)
return await scope.search(query, searchTerm, locale)
}
catch (error) {
console.warn(`Search failed for ${actionItem.key}:`, error)
console.warn(`Search failed for ${scope.id}:`, error)
return []
}
}
@@ -232,19 +70,19 @@ export const searchAnything = async (
if (trimmedQuery.startsWith('@') || trimmedQuery.startsWith('/'))
return []
const globalSearchActions = Object.values(dynamicActions || Actions)
// Exclude slash commands from general search results
.filter(action => action.key !== '/')
// Filter out slash commands from general search
const searchScopes = scopes.filter(scope => !isSlashScope(scope))
// Use Promise.allSettled to handle partial failures gracefully
const searchPromises = globalSearchActions.map(async (action) => {
const searchPromises = searchScopes.map(async (action) => {
const actionId = action.id
try {
const results = await action.search(query, query, locale)
return { success: true, data: results, actionType: action.key }
return { success: true, data: results, actionType: actionId }
}
catch (error) {
console.warn(`Search failed for ${action.key}:`, error)
return { success: false, data: [], actionType: action.key, error }
console.warn(`Search failed for ${actionId}:`, error)
return { success: false, data: [], actionType: actionId, error }
}
})
@@ -258,7 +96,7 @@ export const searchAnything = async (
allResults.push(...result.value.data)
}
else {
const actionKey = globalSearchActions[index]?.key || 'unknown'
const actionKey = searchScopes[index]?.id || 'unknown'
failedActions.push(actionKey)
}
})
@@ -269,31 +107,31 @@ export const searchAnything = async (
return allResults
}
export const matchAction = (query: string, actions: Record<string, ActionItem>) => {
return Object.values(actions).find((action) => {
// Special handling for slash commands
if (action.key === '/') {
// Get all registered commands from the registry
const allCommands = slashCommandRegistry.getAllCommands()
// ...
// Check if query matches any registered command
export const matchAction = (query: string, scopes: ScopeDescriptor[]) => {
const escapeRegExp = (value: string) => value.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
return scopes.find((scope) => {
// Special handling for slash commands
if (isSlashScope(scope)) {
const allCommands = slashCommandRegistry.getAllCommands()
return allCommands.some((cmd) => {
const cmdPattern = `/${cmd.name}`
// For direct mode commands, don't match (keep in command selector)
if (cmd.mode === 'direct')
return false
// For submenu mode commands, match when complete command is entered
return query === cmdPattern || query.startsWith(`${cmdPattern} `)
})
}
const reg = new RegExp(`^(${action.key}|${action.shortcut})(?:\\s|$)`)
// Check if query matches shortcut (exact or prefix)
// Only match if it's the full shortcut followed by space
const shortcuts = getScopeShortcuts(scope).map(escapeRegExp)
const reg = new RegExp(`^(${shortcuts.join('|')})(?:\\s|$)`)
return reg.test(query)
})
}
export * from './commands'
export * from './scope-registry'
export * from './types'
export { appAction, knowledgeAction, pluginAction, workflowNodesAction }
export { appScope, knowledgeScope, pluginScope }

View File

@@ -1,8 +1,9 @@
import type { ActionItem, KnowledgeSearchResult } from './types'
import type { KnowledgeSearchResult, ScopeDescriptor } from './types'
import type { DataSet } from '@/models/datasets'
import { fetchDatasets } from '@/service/datasets'
import { cn } from '@/utils/classnames'
import { Folder } from '../../base/icons/src/vender/solid/files'
import { ACTION_KEYS } from '../constants'
const EXTERNAL_PROVIDER = 'external' as const
const isExternalProvider = (provider: string): boolean => provider === EXTERNAL_PROVIDER
@@ -30,9 +31,10 @@ const parser = (datasets: DataSet[]): KnowledgeSearchResult[] => {
})
}
export const knowledgeAction: ActionItem = {
key: '@knowledge',
shortcut: '@kb',
export const knowledgeScope: ScopeDescriptor = {
id: 'knowledge',
shortcut: ACTION_KEYS.KNOWLEDGE,
aliases: ['@kb'],
title: 'Search Knowledge Bases',
description: 'Search and navigate to your knowledge bases',
// action,

View File

@@ -1,9 +1,10 @@
import type { Plugin, PluginsFromMarketplaceResponse } from '../../plugins/types'
import type { ActionItem, PluginSearchResult } from './types'
import type { PluginSearchResult, ScopeDescriptor } from './types'
import { renderI18nObject } from '@/i18n-config'
import { postMarketplace } from '@/service/base'
import Icon from '../../plugins/card/base/card-icon'
import { getPluginIconInMarketplace } from '../../plugins/marketplace/utils'
import { ACTION_KEYS } from '../constants'
const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
return plugins.map((plugin) => {
@@ -18,9 +19,9 @@ const parser = (plugins: Plugin[], locale: string): PluginSearchResult[] => {
})
}
export const pluginAction: ActionItem = {
key: '@plugin',
shortcut: '@plugin',
export const pluginScope: ScopeDescriptor = {
id: 'plugin',
shortcut: ACTION_KEYS.PLUGIN,
title: 'Search Plugins',
description: 'Search and navigate to your plugins',
search: async (_, searchTerm = '', locale) => {

View File

@@ -1,24 +1,41 @@
import type { ActionItem } from './types'
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
// Create the RAG pipeline nodes action
export const ragPipelineNodesAction: ActionItem = {
key: '@node',
shortcut: '@node',
title: 'Search RAG Pipeline Nodes',
description: 'Find and jump to nodes in the current RAG pipeline by name or type',
searchFn: undefined, // Will be set by useRagPipelineSearch hook
search: async (_, searchTerm = '', _locale) => {
const scopeId = 'rag-pipeline-node'
let scopeRegistered = false
const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
return async (_, searchTerm = '', _locale) => {
try {
// Use the searchFn if available (set by useRagPipelineSearch hook)
if (ragPipelineNodesAction.searchFn)
return ragPipelineNodesAction.searchFn(searchTerm)
// If not in RAG pipeline context, return empty array
if (searchFn)
return searchFn(searchTerm)
return []
}
catch (error) {
console.warn('RAG pipeline nodes search failed:', error)
return []
}
},
}
}
export const registerRagPipelineNodeScope = () => {
if (scopeRegistered)
return
scopeRegistered = true
scopeRegistry.register({
id: scopeId,
shortcut: ACTION_KEYS.NODE,
title: 'Search RAG Pipeline Nodes',
description: 'Find and jump to nodes in the current RAG pipeline by name or type',
isAvailable: context => context.isRagPipelinePage,
search: buildSearchHandler(),
})
}
export const setRagPipelineNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
registerRagPipelineNodeScope()
scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
}

View File

@@ -0,0 +1,123 @@
import type { SearchResult } from './types'
import { useCallback, useMemo, useSyncExternalStore } from 'react'
/**
 * Page context used to decide which scopes are currently available
 * (see `ScopeDescriptor.isAvailable`).
 */
export type ScopeContext = {
  isWorkflowPage: boolean
  isRagPipelinePage: boolean
  isAdmin?: boolean
}
/**
 * Search handler signature shared by all scopes.
 *
 * @param query Full raw query, including any scope prefix (e.g. '@app foo').
 * @param searchTerm Query with the scope prefix already stripped.
 * @param locale Optional UI locale for localized results.
 * @returns Search results, synchronously or as a promise.
 */
export type ScopeSearchHandler = (
  query: string,
  searchTerm: string,
  locale?: string,
) => Promise<SearchResult[]> | SearchResult[]
/** A searchable scope registered with the goto-anything scope registry. */
export type ScopeDescriptor = {
  /**
   * Unique identifier for the scope (e.g. 'app', 'plugin')
   */
  id: string
  /**
   * Shortcut to trigger this scope (e.g. '@app')
   */
  shortcut: string
  /**
   * Additional shortcuts that map to this scope (e.g. ['@kb'])
   */
  aliases?: string[]
  /**
   * I18n key or string for the scope title
   */
  title: string
  /**
   * Description for help text
   */
  description: string
  /**
   * Search handler function
   */
  search: ScopeSearchHandler
  /**
   * Predicate to check if this scope is available in current context.
   * When omitted the scope is available everywhere.
   */
  isAvailable?: (context: ScopeContext) => boolean
}
type Listener = () => void
/**
 * Mutable registry of goto-anything scopes.
 *
 * Keeps scopes keyed by id, filters them per page context, and notifies
 * subscribers (with a monotonically increasing version counter) whenever
 * the set of scopes or a scope's search handler changes.
 */
class ScopeRegistry {
  private readonly scopes = new Map<string, ScopeDescriptor>()
  private readonly listeners = new Set<Listener>()
  private version = 0

  /** Add or replace a scope keyed by its id and notify subscribers. */
  register(scope: ScopeDescriptor) {
    this.scopes.set(scope.id, scope)
    this.notify()
  }

  /** Remove a scope; subscribers are only notified when something was removed. */
  unregister(id: string) {
    const removed = this.scopes.delete(id)
    if (removed)
      this.notify()
  }

  /** Look up a single scope by id (undefined when not registered). */
  getScope(id: string) {
    return this.scopes.get(id)
  }

  /** Scopes available in `context`, ordered by their shortcut string. */
  getScopes(context: ScopeContext): ScopeDescriptor[] {
    const available: ScopeDescriptor[] = []
    for (const scope of this.scopes.values()) {
      if (!scope.isAvailable || scope.isAvailable(context))
        available.push(scope)
    }
    return available.sort((a, b) => a.shortcut.localeCompare(b.shortcut))
  }

  /** Swap the search handler of an existing scope; no-op for unknown ids. */
  updateSearchHandler(id: string, search: ScopeSearchHandler) {
    const existing = this.scopes.get(id)
    if (!existing)
      return
    this.scopes.set(id, { ...existing, search })
    this.notify()
  }

  /** Current change counter; bumped on every mutation. */
  getVersion() {
    return this.version
  }

  /** Subscribe to registry changes; returns an unsubscribe function. */
  subscribe(listener: Listener) {
    this.listeners.add(listener)
    return () => {
      this.listeners.delete(listener)
    }
  }

  // Bump the version and fan out to all listeners.
  private notify() {
    this.version += 1
    for (const listener of this.listeners)
      listener()
  }
}
export const scopeRegistry = new ScopeRegistry()
/**
 * React hook returning the scopes available for `context`, re-rendering
 * whenever the registry mutates (register/unregister/handler update).
 */
export const useScopeRegistry = (context: ScopeContext) => {
  // Stable subscribe/getSnapshot references so useSyncExternalStore does not
  // resubscribe on every render.
  const subscribe = useCallback(
    (listener: Listener) => scopeRegistry.subscribe(listener),
    [],
  )
  const getSnapshot = useCallback(
    () => scopeRegistry.getVersion(),
    [],
  )
  // The registry bumps an integer version on every mutation; using that as
  // the store snapshot keeps the snapshot cheap and referentially stable
  // (no array-identity churn between renders).
  const version = useSyncExternalStore(
    subscribe,
    getSnapshot,
    getSnapshot, // server snapshot: same version getter
  )
  // Recompute the filtered scope list only when the registry version or the
  // individual context flags change — the context object's identity is
  // deliberately ignored so callers may pass a fresh literal each render.
  return useMemo(
    () => scopeRegistry.getScopes(context),
    // `version` is intentionally a dependency purely to invalidate the memo.
    [version, context.isWorkflowPage, context.isRagPipelinePage, context.isAdmin],
  )
}

View File

@@ -1,5 +1,4 @@
import type { ReactNode } from 'react'
import type { TypeWithI18N } from '../../base/form/types'
import type { Plugin } from '../../plugins/types'
import type { CommonNodeType } from '../../workflow/types'
import type { DataSet } from '@/models/datasets'
@@ -7,7 +6,7 @@ import type { App } from '@/types/app'
export type SearchResultType = 'app' | 'knowledge' | 'plugin' | 'workflow-node' | 'command'
export type BaseSearchResult<T = any> = {
export type BaseSearchResult<T = unknown> = {
id: string
title: string
description?: string
@@ -39,20 +38,8 @@ export type WorkflowNodeSearchResult = {
export type CommandSearchResult = {
type: 'command'
} & BaseSearchResult<{ command: string, args?: Record<string, any> }>
} & BaseSearchResult<{ command: string, args?: Record<string, unknown> }>
export type SearchResult = AppSearchResult | PluginSearchResult | KnowledgeSearchResult | WorkflowNodeSearchResult | CommandSearchResult
export type ActionItem = {
key: '@app' | '@knowledge' | '@plugin' | '@node' | '/'
shortcut: string
title: string | TypeWithI18N
description: string
action?: (data: SearchResult) => void
searchFn?: (searchTerm: string) => SearchResult[]
search: (
query: string,
searchTerm: string,
locale?: string,
) => (Promise<SearchResult[]> | SearchResult[])
}
export type { ScopeContext, ScopeDescriptor } from './scope-registry'

View File

@@ -1,24 +1,41 @@
import type { ActionItem } from './types'
import type { ScopeSearchHandler } from './scope-registry'
import type { SearchResult } from './types'
import { ACTION_KEYS } from '../constants'
import { scopeRegistry } from './scope-registry'
// Create the workflow nodes action
export const workflowNodesAction: ActionItem = {
key: '@node',
shortcut: '@node',
title: 'Search Workflow Nodes',
description: 'Find and jump to nodes in the current workflow by name or type',
searchFn: undefined, // Will be set by useWorkflowSearch hook
search: async (_, searchTerm = '', _locale) => {
const scopeId = 'workflow-node'
let scopeRegistered = false
const buildSearchHandler = (searchFn?: (searchTerm: string) => SearchResult[]): ScopeSearchHandler => {
return async (_, searchTerm = '', _locale) => {
try {
// Use the searchFn if available (set by useWorkflowSearch hook)
if (workflowNodesAction.searchFn)
return workflowNodesAction.searchFn(searchTerm)
// If not in workflow context, return empty array
if (searchFn)
return searchFn(searchTerm)
return []
}
catch (error) {
console.warn('Workflow nodes search failed:', error)
return []
}
},
}
}
export const registerWorkflowNodeScope = () => {
if (scopeRegistered)
return
scopeRegistered = true
scopeRegistry.register({
id: scopeId,
shortcut: ACTION_KEYS.NODE,
title: 'Search Workflow Nodes',
description: 'Find and jump to nodes in the current workflow by name or type',
isAvailable: context => context.isWorkflowPage,
search: buildSearchHandler(),
})
}
export const setWorkflowNodesSearchFn = (fn: (searchTerm: string) => SearchResult[]) => {
registerWorkflowNodeScope()
scopeRegistry.updateSearchHandler(scopeId, buildSearchHandler(fn))
}

View File

@@ -1,5 +1,5 @@
import type { ActionItem } from './actions/types'
import { render, screen } from '@testing-library/react'
import type { ScopeDescriptor } from './actions/scope-registry'
import { render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import { Command } from 'cmdk'
import * as React from 'react'
@@ -22,63 +22,229 @@ vi.mock('./actions/commands/registry', () => ({
},
}))
const createActions = (): Record<string, ActionItem> => ({
app: {
key: '@app',
type CommandSelectorProps = React.ComponentProps<typeof CommandSelector>
const mockScopes: ScopeDescriptor[] = [
{
id: 'app',
shortcut: '@app',
title: 'Apps',
title: 'Search Applications',
description: 'Search apps',
search: vi.fn(),
description: '',
} as ActionItem,
plugin: {
key: '@plugin',
},
{
id: 'knowledge',
shortcut: '@knowledge',
title: 'Search Knowledge Bases',
description: 'Search knowledge bases',
search: vi.fn(),
},
{
id: 'plugin',
shortcut: '@plugin',
title: 'Plugins',
title: 'Search Plugins',
description: 'Search plugins',
search: vi.fn(),
description: '',
} as ActionItem,
})
},
{
id: 'workflow-node',
shortcut: '@node',
title: 'Search Nodes',
description: 'Search workflow nodes',
search: vi.fn(),
},
]
const mockOnCommandSelect = vi.fn()
const mockOnCommandValueChange = vi.fn()
const buildCommandSelector = (props: Partial<CommandSelectorProps> = {}) => (
<Command>
<Command.List>
<CommandSelector
scopes={mockScopes}
onCommandSelect={mockOnCommandSelect}
{...props}
/>
</Command.List>
</Command>
)
const renderCommandSelector = (props: Partial<CommandSelectorProps> = {}) => {
return render(buildCommandSelector(props))
}
describe('CommandSelector', () => {
it('should list contextual search actions and notify selection', async () => {
const actions = createActions()
const onSelect = vi.fn()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="app"
originalQuery="@app"
/>
</Command>,
)
const actionButton = screen.getByText('app.gotoAnything.actions.searchApplicationsDesc')
await userEvent.click(actionButton)
expect(onSelect).toHaveBeenCalledWith('@app')
beforeEach(() => {
vi.clearAllMocks()
})
it('should render slash commands when query starts with slash', async () => {
const actions = createActions()
const onSelect = vi.fn()
describe('Basic Rendering', () => {
it('should render all scopes when no filter is provided', () => {
renderCommandSelector()
render(
<Command>
<CommandSelector
actions={actions}
onCommandSelect={onSelect}
searchFilter="zen"
originalQuery="/zen"
/>
</Command>,
)
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
await userEvent.click(slashItem)
it('should render empty filter as showing all scopes', () => {
renderCommandSelector({ searchFilter: '' })
expect(onSelect).toHaveBeenCalledWith('/zen')
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
})
describe('Filtering Functionality', () => {
it('should filter scopes based on searchFilter - single match', () => {
renderCommandSelector({ searchFilter: 'k' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
})
it('should filter scopes with multiple matches', () => {
renderCommandSelector({ searchFilter: 'p' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
})
it('should be case-insensitive when filtering', () => {
renderCommandSelector({ searchFilter: 'APP' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
})
it('should match partial strings', () => {
renderCommandSelector({ searchFilter: 'od' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
})
describe('Empty State', () => {
it('should show empty state when no matches found', () => {
renderCommandSelector({ searchFilter: 'xyz' })
expect(screen.queryByText('@app')).not.toBeInTheDocument()
expect(screen.queryByText('@knowledge')).not.toBeInTheDocument()
expect(screen.queryByText('@plugin')).not.toBeInTheDocument()
expect(screen.queryByText('@node')).not.toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
expect(screen.getByText('app.gotoAnything.tryDifferentSearch')).toBeInTheDocument()
})
it('should not show empty state when filter is empty', () => {
renderCommandSelector({ searchFilter: '' })
expect(screen.queryByText('app.gotoAnything.noMatchingCommands')).not.toBeInTheDocument()
})
})
describe('Selection and Highlight Management', () => {
it('should call onCommandValueChange when filter changes and first item differs', async () => {
const { rerender } = renderCommandSelector({
searchFilter: '',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
})
rerender(buildCommandSelector({
searchFilter: 'k',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
}))
await waitFor(() => {
expect(mockOnCommandValueChange).toHaveBeenCalledWith('@knowledge')
})
})
it('should not call onCommandValueChange if current value still exists', async () => {
const { rerender } = renderCommandSelector({
searchFilter: '',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
})
rerender(buildCommandSelector({
searchFilter: 'a',
commandValue: '@app',
onCommandValueChange: mockOnCommandValueChange,
}))
await waitFor(() => {
expect(mockOnCommandValueChange).not.toHaveBeenCalled()
})
})
it('should handle onCommandSelect callback correctly', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'k' })
await user.click(screen.getByText('@knowledge'))
expect(mockOnCommandSelect).toHaveBeenCalledWith('@knowledge')
})
})
describe('Edge Cases', () => {
it('should handle empty scopes array', () => {
renderCommandSelector({ scopes: [] })
expect(screen.getByText('app.gotoAnything.noMatchingCommands')).toBeInTheDocument()
})
it('should handle special characters in filter', () => {
renderCommandSelector({ searchFilter: '@' })
expect(screen.getByText('@app')).toBeInTheDocument()
expect(screen.getByText('@knowledge')).toBeInTheDocument()
expect(screen.getByText('@plugin')).toBeInTheDocument()
expect(screen.getByText('@node')).toBeInTheDocument()
})
it('should handle undefined onCommandValueChange gracefully', () => {
const { rerender } = renderCommandSelector({ searchFilter: '' })
expect(() => {
rerender(buildCommandSelector({ searchFilter: 'k' }))
}).not.toThrow()
})
})
describe('User Interactions', () => {
it('should list contextual scopes and notify selection', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'app', originalQuery: '@app' })
await user.click(screen.getByText('app.gotoAnything.actions.searchApplicationsDesc'))
expect(mockOnCommandSelect).toHaveBeenCalledWith('@app')
})
it('should render slash commands when query starts with slash', async () => {
const user = userEvent.setup()
renderCommandSelector({ searchFilter: 'zen', originalQuery: '/zen' })
const slashItem = await screen.findByText('app.gotoAnything.actions.zenDesc')
await user.click(slashItem)
expect(mockOnCommandSelect).toHaveBeenCalledWith('/zen')
})
})
})

View File

@@ -1,13 +1,14 @@
import type { FC } from 'react'
import type { ActionItem } from './actions/types'
import type { ScopeDescriptor } from './actions/scope-registry'
import { Command } from 'cmdk'
import { usePathname } from 'next/navigation'
import { useEffect, useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { slashCommandRegistry } from './actions/commands/registry'
import { ACTION_KEYS } from './constants'
type Props = {
actions: Record<string, ActionItem>
scopes: ScopeDescriptor[]
onCommandSelect: (commandKey: string) => void
searchFilter?: string
commandValue?: string
@@ -15,7 +16,7 @@ type Props = {
originalQuery?: string
}
const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
const CommandSelector: FC<Props> = ({ scopes, onCommandSelect, searchFilter, commandValue, onCommandValueChange, originalQuery }) => {
const { t } = useTranslation()
const pathname = usePathname()
@@ -43,22 +44,31 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
}))
}, [isSlashMode, searchFilter, pathname])
const filteredActions = useMemo(() => {
const filteredScopes = useMemo(() => {
if (isSlashMode)
return []
return Object.values(actions).filter((action) => {
return scopes.filter((scope) => {
// Exclude slash action when in @ mode
if (action.key === '/')
if (scope.id === 'slash' || scope.shortcut === ACTION_KEYS.SLASH)
return false
if (!searchFilter)
return true
const filterLower = searchFilter.toLowerCase()
return action.shortcut.toLowerCase().includes(filterLower)
})
}, [actions, searchFilter, isSlashMode])
const allItems = isSlashMode ? slashCommands : filteredActions
// Match against shortcut/aliases or title
const filterLower = searchFilter.toLowerCase()
const shortcuts = [scope.shortcut, ...(scope.aliases || [])]
return shortcuts.some(shortcut => shortcut.toLowerCase().includes(filterLower))
|| scope.title.toLowerCase().includes(filterLower)
}).map(scope => ({
key: scope.shortcut, // Map to shortcut for UI display consistency
shortcut: scope.shortcut,
title: scope.title,
description: scope.description,
}))
}, [scopes, searchFilter, isSlashMode])
const allItems = isSlashMode ? slashCommands : filteredScopes
useEffect(() => {
if (allItems.length > 0 && onCommandValueChange) {
@@ -116,6 +126,7 @@ const CommandSelector: FC<Props> = ({ actions, onCommandSelect, searchFilter, co
'/docs': 'gotoAnything.actions.docDesc',
'/community': 'gotoAnything.actions.communityDesc',
'/zen': 'gotoAnything.actions.zenDesc',
'/banana': 'gotoAnything.actions.vibeDesc',
} as const
return t(slashKeyMap[item.key as keyof typeof slashKeyMap] || item.description, { ns: 'app' })
})()

View File

@@ -0,0 +1,20 @@
/**
* Goto Anything Constants
* Centralized constants for action keys
*/
/**
 * Scope trigger prefixes typed by the user in the Goto Anything input.
 * `@`-prefixed keys select a search scope; `/` enters slash-command mode.
 */
export const ACTION_KEYS = {
  APP: '@app',
  KNOWLEDGE: '@knowledge',
  PLUGIN: '@plugin',
  NODE: '@node', // shared by workflow and RAG-pipeline node scopes
  SLASH: '/',
} as const
/**
 * Type-safe action key union type
 */
export type ActionKey = typeof ACTION_KEYS[keyof typeof ACTION_KEYS]

View File

@@ -0,0 +1,93 @@
import { keepPreviousData, useQuery } from '@tanstack/react-query'
import { useDebounce } from 'ahooks'
import { useMemo } from 'react'
import { useGetLanguage } from '@/context/i18n'
import { matchAction, searchAnything, useGotoAnythingScopes } from '../actions'
import { ACTION_KEYS } from '../constants'
import { useGotoAnythingContext } from '../context'
/**
 * Search hook backing the Goto Anything modal.
 *
 * Resolves the scopes available for the current page context, debounces the
 * raw input, classifies the query into a search mode, runs the search via
 * react-query, and returns de-duplicated results.
 *
 * @param searchQuery Raw input text (may carry an `@`/`/` scope prefix).
 */
export const useSearch = (searchQuery: string) => {
  const defaultLocale = useGetLanguage()
  const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
  // Fetch scopes from registry based on context
  const scopes = useGotoAnythingScopes({ isWorkflowPage, isRagPipelinePage })
  // Debounce so we don't fire a search on every keystroke.
  const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
    wait: 300,
  })
  // "Commands mode": the user typed a bare prefix ('@' or '/'), or a prefixed
  // string that doesn't resolve to any registered scope yet — the UI then
  // shows the command selector instead of search results.
  // NOTE: intentionally based on the *un-debounced* query for responsiveness.
  const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
    || (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), scopes))
    || (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), scopes))
  // Classify the query: 'scopes' / 'commands' while picking a command,
  // 'general' for plain text, '@command' for a matched slash scope, or the
  // matched scope's shortcut (e.g. '@app') otherwise.
  const searchMode = useMemo(() => {
    if (isCommandsMode) {
      // Distinguish between @ (scopes) and / (commands) mode
      if (searchQuery.trim().startsWith('@'))
        return 'scopes'
      else if (searchQuery.trim().startsWith('/'))
        return 'commands'
      return 'commands' // default fallback
    }
    const query = searchQueryDebouncedValue.toLowerCase()
    const action = matchAction(query, scopes)
    if (!action)
      return 'general'
    if (action.id === 'slash' || action.shortcut === ACTION_KEYS.SLASH)
      return '@command'
    return action.shortcut
  }, [searchQueryDebouncedValue, scopes, isCommandsMode, searchQuery])
  const { data: searchResults = [], isLoading, isError, error } = useQuery(
    {
      queryKey: [
        'goto-anything',
        'search-result',
        searchQueryDebouncedValue,
        searchMode,
        isWorkflowPage,
        isRagPipelinePage,
        defaultLocale,
        // Include the scope set so cached results refresh when scopes
        // register/unregister at runtime.
        scopes.map(s => s.id).sort().join(','),
      ],
      queryFn: async () => {
        const query = searchQueryDebouncedValue.toLowerCase()
        const scope = matchAction(query, scopes)
        return await searchAnything(defaultLocale, query, scope, scopes)
      },
      // Skip searching while the command selector is open or input is empty.
      enabled: !!searchQueryDebouncedValue && !isCommandsMode,
      staleTime: 30000,
      gcTime: 300000,
      // Keep showing previous results while a new query is in flight.
      placeholderData: keepPreviousData,
    },
  )
  // Drop duplicate results (same type + id), and clear everything as soon as
  // the input is emptied — before the debounced value catches up.
  const dedupedResults = useMemo(() => {
    if (!searchQuery.trim())
      return []
    const seen = new Set<string>()
    return searchResults.filter((result) => {
      const key = `${result.type}-${result.id}`
      if (seen.has(key))
        return false
      seen.add(key)
      return true
    })
  }, [searchResults, searchQuery])
  return {
    scopes,
    searchResults: dedupedResults,
    isLoading,
    isError,
    error,
    searchMode,
    isCommandsMode,
  }
}

View File

@@ -1,4 +1,5 @@
import type { ActionItem, SearchResult } from './actions/types'
import type { ScopeDescriptor } from './actions/scope-registry'
import type { SearchResult } from './actions/types'
import { act, render, screen } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import * as React from 'react'
@@ -35,6 +36,7 @@ const triggerKeyPress = (combo: string) => {
let mockQueryResult = { data: [] as SearchResult[], isLoading: false, isError: false, error: null as Error | null }
vi.mock('@tanstack/react-query', () => ({
useQuery: () => mockQueryResult,
keepPreviousData: (data: unknown) => data,
}))
vi.mock('@/context/i18n', () => ({
@@ -47,33 +49,30 @@ vi.mock('./context', () => ({
GotoAnythingProvider: ({ children }: { children: React.ReactNode }) => <>{children}</>,
}))
const createActionItem = (key: ActionItem['key'], shortcut: string): ActionItem => ({
key,
shortcut,
title: `${key} title`,
description: `${key} desc`,
action: vi.fn(),
search: vi.fn(),
type MatchAction = typeof import('./actions').matchAction
type SearchAnything = typeof import('./actions').searchAnything
const mockState = vi.hoisted(() => {
const state = {
scopes: [] as ScopeDescriptor[],
useGotoAnythingScopesMock: vi.fn(() => state.scopes),
matchActionMock: vi.fn<MatchAction>(() => undefined),
searchAnythingMock: vi.fn<SearchAnything>(async () => []),
}
return state
})
const actionsMock = {
slash: createActionItem('/', '/'),
app: createActionItem('@app', '@app'),
plugin: createActionItem('@plugin', '@plugin'),
}
const createActionsMock = vi.fn(() => actionsMock)
const matchActionMock = vi.fn(() => undefined)
const searchAnythingMock = vi.fn(async () => mockQueryResult.data)
vi.mock('./actions', () => ({
createActions: () => createActionsMock(),
matchAction: () => matchActionMock(),
searchAnything: () => searchAnythingMock(),
__esModule: true,
matchAction: (...args: Parameters<MatchAction>) => mockState.matchActionMock(...args),
searchAnything: (...args: Parameters<SearchAnything>) => mockState.searchAnythingMock(...args),
useGotoAnythingScopes: () => mockState.useGotoAnythingScopesMock(),
}))
vi.mock('./actions/commands', () => ({
SlashCommandProvider: () => null,
executeCommand: vi.fn(),
}))
vi.mock('./actions/commands/registry', () => ({
@@ -84,6 +83,20 @@ vi.mock('./actions/commands/registry', () => ({
},
}))
const createScope = (id: ScopeDescriptor['id'], shortcut: string): ScopeDescriptor => ({
id,
shortcut,
title: `${id} title`,
description: `${id} desc`,
search: vi.fn(),
})
const scopesMock = [
createScope('slash', '/'),
createScope('app', '@app'),
createScope('plugin', '@plugin'),
]
vi.mock('@/app/components/workflow/utils/common', () => ({
getKeyboardKeyCodeBySystem: () => 'ctrl',
isEventTargetInputArea: () => false,
@@ -108,8 +121,10 @@ describe('GotoAnything', () => {
routerPush.mockClear()
Object.keys(keyPressHandlers).forEach(key => delete keyPressHandlers[key])
mockQueryResult = { data: [], isLoading: false, isError: false, error: null }
matchActionMock.mockReset()
searchAnythingMock.mockClear()
mockState.scopes = scopesMock
mockState.matchActionMock.mockReset()
mockState.searchAnythingMock.mockClear()
mockState.searchAnythingMock.mockImplementation(async () => mockQueryResult.data)
})
it('should open modal via shortcut and navigate to selected result', async () => {

View File

@@ -4,23 +4,22 @@ import type { FC } from 'react'
import type { Plugin } from '../plugins/types'
import type { SearchResult } from './actions'
import { RiSearchLine } from '@remixicon/react'
import { useQuery } from '@tanstack/react-query'
import { useDebounce, useKeyPress } from 'ahooks'
import { useKeyPress } from 'ahooks'
import { Command } from 'cmdk'
import { useRouter } from 'next/navigation'
import { useCallback, useEffect, useMemo, useRef, useState } from 'react'
import { useTranslation } from 'react-i18next'
import Input from '@/app/components/base/input'
import Modal from '@/app/components/base/modal'
import { VIBE_COMMAND_EVENT } from '@/app/components/workflow/constants'
import { getKeyboardKeyCodeBySystem, isEventTargetInputArea, isMac } from '@/app/components/workflow/utils/common'
import { selectWorkflowNode } from '@/app/components/workflow/utils/node-navigation'
import { useGetLanguage } from '@/context/i18n'
import InstallFromMarketplace from '../plugins/install-plugin/install-from-marketplace'
import { createActions, matchAction, searchAnything } from './actions'
import { SlashCommandProvider } from './actions/commands'
import { executeCommand, SlashCommandProvider } from './actions/commands'
import { slashCommandRegistry } from './actions/commands/registry'
import CommandSelector from './command-selector'
import { GotoAnythingProvider, useGotoAnythingContext } from './context'
import { GotoAnythingProvider } from './context'
import { useSearch } from './hooks/use-search'
type Props = {
onHide?: () => void
@@ -29,19 +28,21 @@ const GotoAnything: FC<Props> = ({
onHide,
}) => {
const router = useRouter()
const defaultLocale = useGetLanguage()
const { isWorkflowPage, isRagPipelinePage } = useGotoAnythingContext()
const { t } = useTranslation()
const [show, setShow] = useState<boolean>(false)
const [searchQuery, setSearchQuery] = useState<string>('')
const [cmdVal, setCmdVal] = useState<string>('_')
const inputRef = useRef<HTMLInputElement>(null)
// Filter actions based on context
const Actions = useMemo(() => {
// Create actions based on current page context
return createActions(isWorkflowPage, isRagPipelinePage)
}, [isWorkflowPage, isRagPipelinePage])
const {
scopes,
searchResults: dedupedResults,
isLoading,
isError,
error,
searchMode,
isCommandsMode,
} = useSearch(searchQuery)
const [activePlugin, setActivePlugin] = useState<Plugin>()
@@ -73,56 +74,6 @@ const GotoAnything: FC<Props> = ({
}
})
const searchQueryDebouncedValue = useDebounce(searchQuery.trim(), {
wait: 300,
})
const isCommandsMode = searchQuery.trim() === '@' || searchQuery.trim() === '/'
|| (searchQuery.trim().startsWith('@') && !matchAction(searchQuery.trim(), Actions))
|| (searchQuery.trim().startsWith('/') && !matchAction(searchQuery.trim(), Actions))
const searchMode = useMemo(() => {
if (isCommandsMode) {
// Distinguish between @ (scopes) and / (commands) mode
if (searchQuery.trim().startsWith('@'))
return 'scopes'
else if (searchQuery.trim().startsWith('/'))
return 'commands'
return 'commands' // default fallback
}
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
if (!action)
return 'general'
return action.key === '/' ? '@command' : action.key
}, [searchQueryDebouncedValue, Actions, isCommandsMode, searchQuery])
const { data: searchResults = [], isLoading, isError, error } = useQuery(
{
queryKey: [
'goto-anything',
'search-result',
searchQueryDebouncedValue,
searchMode,
isWorkflowPage,
isRagPipelinePage,
defaultLocale,
Object.keys(Actions).sort().join(','),
],
queryFn: async () => {
const query = searchQueryDebouncedValue.toLowerCase()
const action = matchAction(query, Actions)
return await searchAnything(defaultLocale, query, action, Actions)
},
enabled: !!searchQueryDebouncedValue && !isCommandsMode,
staleTime: 30000,
gcTime: 300000,
},
)
// Prevent automatic selection of the first option when cmdVal is not set
const clearSelection = () => {
setCmdVal('_')
@@ -158,9 +109,25 @@ const GotoAnything: FC<Props> = ({
switch (result.type) {
case 'command': {
// Execute slash commands
const action = Actions.slash
action?.action?.(result)
if (result.data.command === 'workflow.vibe') {
if (typeof document !== 'undefined') {
document.dispatchEvent(new CustomEvent(VIBE_COMMAND_EVENT, { detail: { dsl: result.data.args?.dsl } }))
}
break
}
// Execute slash commands using the command bus
// This handles both direct execution and submenu commands with args
const { command, args } = result.data
// Try executing via command bus first (preferred for submenu commands with args)
// We can't cheaply verify that the command exists in the bus without risking
// executing it, but search results normally point to valid bus commands.
executeCommand(command, args)
// Note: We previously checked slashCommandRegistry handlers here, but search results
// should return executable command strings (like 'theme.set') that are registered in the bus.
// The registry is mainly for the top-level command matching (e.g. /theme).
break
}
case 'plugin':
@@ -178,17 +145,6 @@ const GotoAnything: FC<Props> = ({
}
}, [router])
const dedupedResults = useMemo(() => {
const seen = new Set<string>()
return searchResults.filter((result) => {
const key = `${result.type}-${result.id}`
if (seen.has(key))
return false
seen.add(key)
return true
})
}, [searchResults])
// Group results by type
const groupedResults = useMemo(() => dedupedResults.reduce((acc, result) => {
if (!acc[result.type])
@@ -250,12 +206,12 @@ const GotoAnything: FC<Props> = ({
<div className="mt-1 text-xs text-text-quaternary">
{isCommandSearch
? t('gotoAnything.emptyState.tryDifferentTerm', { ns: 'app' })
: t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts: Object.values(Actions).map(action => action.shortcut).join(', ') })}
: t('gotoAnything.emptyState.trySpecificSearch', { ns: 'app', shortcuts: scopes.map(s => s.shortcut).join(', ') })}
</div>
</div>
</div>
)
}, [dedupedResults, searchQuery, Actions, searchMode, isLoading, isError, isCommandsMode])
}, [dedupedResults, searchQuery, scopes, searchMode, isLoading, isError, isCommandsMode])
const defaultUI = useMemo(() => {
if (searchQuery.trim())
@@ -273,7 +229,7 @@ const GotoAnything: FC<Props> = ({
</div>
</div>
)
}, [searchQuery, Actions])
}, [searchQuery, scopes])
useEffect(() => {
if (show) {
@@ -380,7 +336,7 @@ const GotoAnything: FC<Props> = ({
<div>
<div className="text-sm font-medium text-red-500">{t('gotoAnything.searchFailed', { ns: 'app' })}</div>
<div className="mt-1 text-xs text-text-quaternary">
{error.message}
{error?.message}
</div>
</div>
</div>
@@ -390,7 +346,7 @@ const GotoAnything: FC<Props> = ({
{isCommandsMode
? (
<CommandSelector
actions={Actions}
scopes={scopes}
onCommandSelect={handleCommandSelect}
searchFilter={searchQuery.trim().substring(1)}
commandValue={cmdVal}

View File

@@ -5,7 +5,7 @@ import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
import type { ToolNodeType } from '@/app/components/workflow/nodes/tool/types'
import type { CommonNodeType } from '@/app/components/workflow/types'
import { useCallback, useEffect, useMemo } from 'react'
import { ragPipelineNodesAction } from '@/app/components/goto-anything/actions/rag-pipeline-nodes'
import { setRagPipelineNodesSearchFn } from '@/app/components/goto-anything/actions/rag-pipeline-nodes'
import BlockIcon from '@/app/components/workflow/block-icon'
import { useNodesInteractions } from '@/app/components/workflow/hooks/use-nodes-interactions'
import { useGetToolIcon } from '@/app/components/workflow/hooks/use-tool-icon'
@@ -153,16 +153,15 @@ export const useRagPipelineSearch = () => {
return results
}, [searchableNodes, calculateScore])
// Directly set the search function on the action object
// Directly set the search function using the setter
useEffect(() => {
if (searchableNodes.length > 0) {
// Set the search function directly on the action
ragPipelineNodesAction.searchFn = searchRagPipelineNodes
setRagPipelineNodesSearchFn(searchRagPipelineNodes)
}
return () => {
// Clean up when component unmounts
ragPipelineNodesAction.searchFn = undefined
setRagPipelineNodesSearchFn(() => [])
}
}, [searchableNodes, searchRagPipelineNodes])

View File

@@ -9,6 +9,8 @@ export const NODE_WIDTH = 240
export const X_OFFSET = 60
export const NODE_WIDTH_X_OFFSET = NODE_WIDTH + X_OFFSET
export const Y_OFFSET = 39
export const VIBE_COMMAND_EVENT = 'workflow-vibe-command'
export const VIBE_APPLY_EVENT = 'workflow-vibe-apply'
export const START_INITIAL_POSITION = { x: 80, y: 282 }
export const AUTO_LAYOUT_OFFSET = {
x: -42,

View File

@@ -0,0 +1,80 @@
import { describe, expect, it } from 'vitest'
import { replaceVariableReferences } from '../use-workflow-vibe'
// Mock types needed for the test
type NodeData = {
title: string
[key: string]: any
}
describe('use-workflow-vibe', () => {
  describe('replaceVariableReferences', () => {
    // Remaps `{{#nodeId.field#}}` references inside plain string fields.
    it('should replace variable references in strings', () => {
      const data = {
        title: 'Test Node',
        prompt: 'Hello {{#old_id.query#}}',
      }
      const nodeIdMap = new Map<string, any>()
      nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
      const result = replaceVariableReferences(data, nodeIdMap) as NodeData
      expect(result.prompt).toBe('Hello {{#new_uuid.query#}}')
    })
    // Every occurrence in the same string must be rewritten, not just the first.
    it('should handle multiple references in one string', () => {
      const data = {
        title: 'Test Node',
        text: '{{#node1.out#}} and {{#node2.out#}}',
      }
      const nodeIdMap = new Map<string, any>()
      nodeIdMap.set('node1', { id: 'uuid1', data: { type: 'llm' } })
      nodeIdMap.set('node2', { id: 'uuid2', data: { type: 'llm' } })
      const result = replaceVariableReferences(data, nodeIdMap) as NodeData
      expect(result.text).toBe('{{#uuid1.out#}} and {{#uuid2.out#}}')
    })
    // value_selector arrays carry [nodeId, field]; the nodeId element is remapped.
    it('should replace variable references in value_selector arrays', () => {
      const data = {
        title: 'End Node',
        outputs: [
          {
            variable: 'result',
            value_selector: ['old_id', 'text'],
          },
        ],
      }
      const nodeIdMap = new Map<string, any>()
      nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
      const result = replaceVariableReferences(data, nodeIdMap) as NodeData
      expect(result.outputs[0].value_selector).toEqual(['new_uuid', 'text'])
    })
    // Replacement must recurse into arbitrarily nested objects.
    it('should handle nested objects recursively', () => {
      const data = {
        config: {
          model: {
            prompt: '{{#old_id.text#}}',
          },
        },
      }
      const nodeIdMap = new Map<string, any>()
      nodeIdMap.set('old_id', { id: 'new_uuid', data: { type: 'llm' } })
      const result = replaceVariableReferences(data, nodeIdMap) as any
      expect(result.config.model.prompt).toBe('{{#new_uuid.text#}}')
    })
    // Unknown node ids are left untouched rather than dropped or rewritten.
    it('should ignore missing node mappings', () => {
      const data = {
        text: '{{#missing_id.text#}}',
      }
      const nodeIdMap = new Map<string, any>()
      // missing_id is not in map
      const result = replaceVariableReferences(data, nodeIdMap) as NodeData
      expect(result.text).toBe('{{#missing_id.text#}}')
    })
  })
})

View File

@@ -24,3 +24,5 @@ export * from './use-workflow-run'
export * from './use-workflow-search'
export * from './use-workflow-start-run'
export * from './use-workflow-variables'
export * from './use-workflow-vibe'
export * from './use-workflow-vibe-config'

View File

@@ -160,7 +160,7 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
}
}
else {
usedVars = getNodeUsedVars(node).filter(v => v.length > 0)
usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0)
}
if (node.type === CUSTOM_NODE) {
@@ -359,7 +359,7 @@ export const useChecklistBeforePublish = () => {
}
}
else {
usedVars = getNodeUsedVars(node).filter(v => v.length > 0)
usedVars = getNodeUsedVars(node).filter(v => v && v.length > 0)
}
const checkData = getCheckData(node.data, datasets)
const { errorMessage } = nodesExtraData![node.data.type as BlockEnum].checkValid(checkData, t, moreDataForCheckValid)

View File

@@ -5,7 +5,7 @@ import type { CommonNodeType } from '../types'
import type { Emoji } from '@/app/components/tools/types'
import { useCallback, useEffect, useMemo } from 'react'
import { useNodes } from 'reactflow'
import { workflowNodesAction } from '@/app/components/goto-anything/actions/workflow-nodes'
import { setWorkflowNodesSearchFn } from '@/app/components/goto-anything/actions/workflow-nodes'
import { CollectionType } from '@/app/components/tools/types'
import BlockIcon from '@/app/components/workflow/block-icon'
import {
@@ -183,16 +183,15 @@ export const useWorkflowSearch = () => {
return results
}, [searchableNodes, calculateScore])
// Directly set the search function on the action object
// Directly set the search function using the setter
useEffect(() => {
if (searchableNodes.length > 0) {
// Set the search function directly on the action
workflowNodesAction.searchFn = searchWorkflowNodes
setWorkflowNodesSearchFn(searchWorkflowNodes)
}
return () => {
// Clean up when component unmounts
workflowNodesAction.searchFn = undefined
setWorkflowNodesSearchFn(() => [])
}
}, [searchableNodes, searchWorkflowNodes])

View File

@@ -0,0 +1,99 @@
/**
* Vibe Workflow Generator Configuration
*
* This module centralizes configuration for the Vibe workflow generation feature,
* including node type aliases and field name corrections.
*
* Note: These definitions are mirrored in the backend at:
* api/core/workflow/generator/config/node_schemas.py
* When updating these values, also update the backend file.
*/
/**
 * Node type aliases for inference from natural language.
 * Maps common terms to canonical node type names.
 */
export const NODE_TYPE_ALIASES: Record<string, string> = {
  // Entry point of a workflow
  start: 'start',
  begin: 'start',
  input: 'start',
  // Terminal output of a workflow
  end: 'end',
  finish: 'end',
  output: 'end',
  // Large-language-model invocation
  llm: 'llm',
  ai: 'llm',
  gpt: 'llm',
  model: 'llm',
  chat: 'llm',
  // Script execution
  code: 'code',
  script: 'code',
  python: 'code',
  javascript: 'code',
  // Outbound HTTP call
  'http-request': 'http-request',
  http: 'http-request',
  request: 'http-request',
  api: 'http-request',
  fetch: 'http-request',
  webhook: 'http-request',
  // Conditional branching
  'if-else': 'if-else',
  condition: 'if-else',
  branch: 'if-else',
  switch: 'if-else',
  // Iteration / looping ('loop' stays a distinct canonical type)
  iteration: 'iteration',
  loop: 'loop',
  foreach: 'iteration',
  // Tool invocation
  tool: 'tool',
}
/**
 * Field name corrections for LLM-generated node configs.
 * Maps incorrect field names to correct ones for specific node types.
 */
export const FIELD_NAME_CORRECTIONS: Record<string, Record<string, string>> = {
  // HTTP nodes expose their payload as "body"
  'http-request': {
    text: 'body',
    content: 'body',
    response: 'body',
  },
  // Code nodes expose their output as "result"
  code: {
    text: 'result',
    output: 'result',
  },
  // LLM nodes expose their completion as "text"
  llm: {
    response: 'text',
    answer: 'text',
  },
}

/**
 * Correct a field name based on node type.
 * LLM sometimes generates wrong field names (e.g., "text" instead of "body" for HTTP nodes).
 *
 * @param field - The field name to correct
 * @param nodeType - The type of the node
 * @returns The corrected field name, or the original if no correction applies
 */
export const correctFieldName = (field: string, nodeType: string): string =>
  FIELD_NAME_CORRECTIONS[nodeType]?.[field] ?? field
/**
 * Resolve an alias to its canonical node type (case-insensitive).
 *
 * @param alias - The alias to look up
 * @returns The canonical node type, or undefined when the alias is unknown
 */
export const getCanonicalNodeType = (alias: string): string | undefined => {
  const normalized = alias.toLowerCase()
  return NODE_TYPE_ALIASES[normalized]
}

File diff suppressed because it is too large Load Diff

View File

@@ -471,12 +471,14 @@ export const useNodesReadOnly = () => {
const workflowRunningData = useStore(s => s.workflowRunningData)
const historyWorkflowData = useStore(s => s.historyWorkflowData)
const isRestoring = useStore(s => s.isRestoring)
// const showVibePanel = useStore(s => s.showVibePanel)
const getNodesReadOnly = useCallback((): boolean => {
const {
workflowRunningData,
historyWorkflowData,
isRestoring,
// showVibePanel,
} = workflowStore.getState()
return !!(workflowRunningData?.result.status === WorkflowRunningStatus.Running || historyWorkflowData || isRestoring)

View File

@@ -68,6 +68,7 @@ import {
useWorkflow,
useWorkflowReadOnly,
useWorkflowRefreshDraft,
useWorkflowVibe,
} from './hooks'
import { HooksStoreContextProvider, useHooksStore } from './hooks-store'
import { useWorkflowSearch } from './hooks/use-workflow-search'
@@ -319,6 +320,7 @@ export const Workflow: FC<WorkflowProps> = memo(({
useShortcuts()
// Initialize workflow node search functionality
useWorkflowSearch()
useWorkflowVibe()
// Set up scroll to node event listener using the utility function
useEffect(() => {

View File

@@ -33,9 +33,9 @@ const FileUploadSetting: FC<Props> = ({
const { t } = useTranslation()
const {
allowed_file_upload_methods,
allowed_file_upload_methods = [],
max_length,
allowed_file_types,
allowed_file_types = [],
allowed_file_extensions,
} = payload
const { data: fileUploadConfigResponse } = useFileUploadConfig()

View File

@@ -1391,9 +1391,9 @@ export const getNodeUsedVars = (node: Node): ValueSelector[] => {
payload.url,
payload.headers,
payload.params,
typeof payload.body.data === 'string'
typeof payload.body?.data === 'string'
? payload.body.data
: payload.body.data.map(d => d.value).join(''),
: (payload.body?.data?.map(d => d.value).join('') ?? ''),
])
break
}

View File

@@ -5,6 +5,9 @@ import { useCallback, useEffect, useState } from 'react'
const UNIQUE_ID_PREFIX = 'key-value-'
const strToKeyValueList = (value: string) => {
if (typeof value !== 'string' || !value)
return []
return value.split('\n').map((item) => {
const [key, ...others] = item.split(':')
return {
@@ -16,7 +19,7 @@ const strToKeyValueList = (value: string) => {
}
const useKeyValueList = (value: string, onChange: (value: string) => void, noFilter?: boolean) => {
const [list, doSetList] = useState<KeyValue[]>(() => value ? strToKeyValueList(value) : [])
const [list, doSetList] = useState<KeyValue[]>(() => typeof value === 'string' && value ? strToKeyValueList(value) : [])
const setList = (l: KeyValue[]) => {
doSetList(l.map((item) => {
return {

View File

@@ -49,7 +49,7 @@ const ConditionValue = ({
if (value === true || value === false)
return value ? 'True' : 'False'
return value.replace(/\{\{#([^#]*)#\}\}/g, (a, b) => {
return String(value).replace(/\{\{#([^#]*)#\}\}/g, (a, b) => {
const arr: string[] = b.split('.')
if (isSystemVar(arr))
return `{{${b}}}`

View File

@@ -127,23 +127,30 @@ const NodeGroupItem = ({
!!item.variables.length && (
<div className="space-y-0.5">
{
item.variables.map((variable = [], index) => {
const isSystem = isSystemVar(variable)
item.variables
.map((variable = [], index) => {
// Ensure variable is an array
const safeVariable = Array.isArray(variable) ? variable : []
if (!safeVariable.length)
return null
const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === variable[0])
const varName = isSystem ? `sys.${variable[variable.length - 1]}` : variable.slice(1).join('.')
const isException = isExceptionVariable(varName, node?.data.type)
const isSystem = isSystemVar(safeVariable)
return (
<VariableLabelInNode
key={index}
variables={variable}
nodeType={node?.data.type}
nodeTitle={node?.data.title}
isExceptionVariable={isException}
/>
)
})
const node = isSystem ? nodes.find(node => node.data.type === BlockEnum.Start) : nodes.find(node => node.id === safeVariable[0])
const varName = isSystem ? `sys.${safeVariable[safeVariable.length - 1]}` : safeVariable.slice(1).join('.')
const isException = isExceptionVariable(varName, node?.data.type)
return (
<VariableLabelInNode
key={index}
variables={safeVariable}
nodeType={node?.data.type}
nodeTitle={node?.data.title}
isExceptionVariable={isException}
/>
)
})
.filter(Boolean)
}
</div>
)

View File

@@ -8,6 +8,7 @@ import { cn } from '@/utils/classnames'
import { Panel as NodePanel } from '../nodes'
import { useStore } from '../store'
import EnvPanel from './env-panel'
import VibePanel from './vibe-panel'
const VersionHistoryPanel = dynamic(() => import('@/app/components/workflow/panel/version-history-panel'), {
ssr: false,
@@ -85,6 +86,7 @@ const Panel: FC<PanelProps> = ({
const showEnvPanel = useStore(s => s.showEnvPanel)
const isRestoring = useStore(s => s.isRestoring)
const showWorkflowVersionHistoryPanel = useStore(s => s.showWorkflowVersionHistoryPanel)
const showVibePanel = useStore(s => s.showVibePanel)
// widths used for adaptive layout
const workflowCanvasWidth = useStore(s => s.workflowCanvasWidth)
@@ -124,33 +126,36 @@ const Panel: FC<PanelProps> = ({
)
return (
<div
ref={rightPanelRef}
tabIndex={-1}
className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')}
key={`${isRestoring}`}
>
{components?.left}
{!!selectedNode && <NodePanel {...selectedNode} />}
<>
<div
className="relative"
ref={otherPanelRef}
ref={rightPanelRef}
tabIndex={-1}
className={cn('absolute bottom-1 right-0 top-14 z-10 flex outline-none')}
key={`${isRestoring}`}
>
{
components?.right
}
{
showWorkflowVersionHistoryPanel && (
<VersionHistoryPanel {...versionHistoryPanelProps} />
)
}
{
showEnvPanel && (
<EnvPanel />
)
}
{components?.left}
{!!selectedNode && <NodePanel {...selectedNode} />}
<div
className="relative"
ref={otherPanelRef}
>
{
components?.right
}
{
showWorkflowVersionHistoryPanel && (
<VersionHistoryPanel {...versionHistoryPanelProps} />
)
}
{
showEnvPanel && (
<EnvPanel />
)
}
</div>
</div>
</div>
{showVibePanel && <VibePanel />}
</>
)
}

View File

@@ -0,0 +1,333 @@
/**
* VibePanel Component Tests
*
* Covers rendering states, user interactions, and edge cases for the vibe panel.
*/
import type { Shape as WorkflowState } from '@/app/components/workflow/store/workflow'
import type { Edge, Node } from '@/app/components/workflow/types'
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
import userEvent from '@testing-library/user-event'
import Toast from '@/app/components/base/toast'
import { WorkflowContext } from '@/app/components/workflow/context'
import { HooksStoreContext } from '@/app/components/workflow/hooks-store/provider'
import { createHooksStore } from '@/app/components/workflow/hooks-store/store'
import { createWorkflowStore } from '@/app/components/workflow/store/workflow'
import { BlockEnum } from '@/app/components/workflow/types'
import { VIBE_APPLY_EVENT, VIBE_COMMAND_EVENT } from '../../constants'
import VibePanel from './index'
// ============================================================================
// Mocks
// ============================================================================
const mockCopy = vi.hoisted(() => vi.fn())
const mockUseVibeFlowData = vi.hoisted(() => vi.fn())
vi.mock('copy-to-clipboard', () => ({
default: mockCopy,
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/hooks', () => ({
useModelListAndDefaultModelAndCurrentProviderAndModel: () => ({ defaultModel: null }),
}))
vi.mock('@/app/components/header/account-setting/model-provider-page/model-parameter-modal', () => ({
__esModule: true,
default: ({ modelId, provider }: { modelId: string, provider: string }) => (
<div data-testid="model-parameter-modal" data-model-id={modelId} data-provider={provider} />
),
}))
vi.mock('@/app/components/workflow/hooks/use-workflow-vibe', () => ({
useVibeFlowData: () => mockUseVibeFlowData(),
}))
vi.mock('@/app/components/workflow/workflow-preview', () => ({
__esModule: true,
default: ({ nodes, edges }: { nodes: Node[], edges: Edge[] }) => (
<div data-testid="workflow-preview" data-nodes-count={nodes.length} data-edges-count={edges.length} />
),
}))
// ============================================================================
// Test Utilities
// ============================================================================
type FlowGraph = {
nodes: Node[]
edges: Edge[]
}
type VibeFlowData = {
versions: FlowGraph[]
currentVersionIndex: number
setCurrentVersionIndex: (index: number) => void
current?: FlowGraph
}
// Build a minimal workflow Node fixture; individual fields can be overridden per test.
const createMockNode = (overrides: Partial<Node> = {}): Node => {
  const base: Node = {
    id: 'node-1',
    type: 'custom',
    position: { x: 0, y: 0 },
    data: {
      title: 'Start',
      desc: '',
      type: BlockEnum.Start,
    },
  }
  return { ...base, ...overrides }
}
// Build a minimal workflow Edge fixture connecting node-1 -> node-2; overridable per test.
const createMockEdge = (overrides: Partial<Edge> = {}): Edge => {
  const base: Edge = {
    id: 'edge-1',
    source: 'node-1',
    target: 'node-2',
    data: {
      sourceType: BlockEnum.Start,
      targetType: BlockEnum.End,
    },
  }
  return { ...base, ...overrides }
}
// Build an empty flow graph; pass nodes/edges to populate it.
const createFlowGraph = (overrides: Partial<FlowGraph> = {}): FlowGraph => {
  return { nodes: [], edges: [], ...overrides }
}
// Build a default VibeFlowData payload; the setter is a spy so tests can assert calls.
const createVibeFlowData = (overrides: Partial<VibeFlowData> = {}): VibeFlowData => {
  const defaults: VibeFlowData = {
    versions: [],
    currentVersionIndex: 0,
    setCurrentVersionIndex: vi.fn(),
    current: undefined,
  }
  return { ...defaults, ...overrides }
}
// Render <VibePanel /> wrapped in fresh workflow + hooks store providers.
// - `vibeFlowData` (when given) is fed to the mocked useVibeFlowData hook AND
//   mirrored into the store's vibe-flow fields so both read paths agree.
// - `workflowState` is spread last, so tests can override any store field.
// Returns the created workflowStore alongside Testing Library's render result.
const renderVibePanel = ({
  workflowState,
  vibeFlowData,
}: {
  workflowState?: Partial<WorkflowState>
  vibeFlowData?: VibeFlowData
} = {}) => {
  if (vibeFlowData)
    mockUseVibeFlowData.mockReturnValue(vibeFlowData)
  const workflowStore = createWorkflowStore({})
  // Mirror the hook payload into store-side fields (versions / index / current).
  const vibeFlowState = vibeFlowData
    ? {
      vibeFlowVersions: vibeFlowData.versions,
      vibeFlowCurrentIndex: vibeFlowData.currentVersionIndex,
      currentVibeFlow: vibeFlowData.current,
    }
    : {}
  workflowStore.setState({
    // Defaults: panel visible, idle, no instruction/mermaid yet; overrides win.
    showVibePanel: true,
    isVibeGenerating: false,
    vibePanelInstruction: '',
    vibePanelMermaidCode: '',
    ...vibeFlowState,
    ...workflowState,
  })
  const hooksStore = createHooksStore({})
  return {
    workflowStore,
    ...render(
      <WorkflowContext.Provider value={workflowStore}>
        <HooksStoreContext.Provider value={hooksStore}>
          <VibePanel />
        </HooksStoreContext.Provider>
      </WorkflowContext.Provider>,
    ),
  }
}
// Locate the icon-only copy button: it renders an svg icon but has no text label.
const getCopyButton = () => {
  for (const candidate of screen.getAllByRole('button')) {
    const isIconOnly = candidate.textContent?.trim() === '' && !!candidate.querySelector('svg')
    if (isIconOnly)
      return candidate
  }
  throw new Error('Copy button not found')
}
// ============================================================================
// Tests
// ============================================================================
describe('VibePanel', () => {
let toastNotifySpy: ReturnType<typeof vi.spyOn>
beforeEach(() => {
vi.clearAllMocks()
mockUseVibeFlowData.mockReturnValue(createVibeFlowData())
toastNotifySpy = vi.spyOn(Toast, 'notify').mockImplementation(() => ({ clear: vi.fn() }))
})
afterEach(() => {
toastNotifySpy.mockRestore()
})
// --------------------------------------------------------------------------
// Rendering: default visibility and primary view states.
// --------------------------------------------------------------------------
describe('Rendering', () => {
it('should render nothing when panel is hidden', () => {
renderVibePanel({ workflowState: { showVibePanel: false } })
expect(screen.queryByText(/app\.gotoAnything\.actions\.vibeTitle/i)).not.toBeInTheDocument()
})
it('should render placeholder when no preview data and not generating', () => {
renderVibePanel({
workflowState: { showVibePanel: true, isVibeGenerating: false },
vibeFlowData: createVibeFlowData({ current: undefined }),
})
expect(screen.getByText(/appDebug\.generate\.newNoDataLine1/i)).toBeInTheDocument()
})
it('should render loading state when generating', () => {
renderVibePanel({
workflowState: { showVibePanel: true, isVibeGenerating: true },
})
expect(screen.getByText(/workflow\.vibe\.generatingFlowchart/i)).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'appDebug.generate.generate' })).toBeDisabled()
})
it('should render preview panel when nodes exist', () => {
const flowGraph = createFlowGraph({
nodes: [createMockNode()],
edges: [createMockEdge()],
})
renderVibePanel({
vibeFlowData: createVibeFlowData({
current: flowGraph,
versions: [flowGraph],
}),
})
expect(screen.getByTestId('workflow-preview')).toBeInTheDocument()
expect(screen.getByRole('button', { name: 'workflow.vibe.apply' })).toBeInTheDocument()
expect(screen.getByText(/appDebug\.generate\.version/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// Props: store-driven inputs that toggle behavior.
// --------------------------------------------------------------------------
describe('Props', () => {
it('should render modal content when showVibePanel is true', () => {
renderVibePanel({ workflowState: { showVibePanel: true } })
expect(screen.getByText(/app\.gotoAnything\.actions\.vibeTitle/i)).toBeInTheDocument()
})
})
// --------------------------------------------------------------------------
// User Interactions: input edits and action triggers.
// --------------------------------------------------------------------------
describe('User Interactions', () => {
it('should update instruction in store when typing', async () => {
const { workflowStore } = renderVibePanel()
const textarea = screen.getByPlaceholderText('workflow.vibe.missingInstruction')
fireEvent.change(textarea, { target: { value: 'Build a vibe flow' } })
await waitFor(() => {
expect(workflowStore.getState().vibePanelInstruction).toBe('Build a vibe flow')
})
})
it('should dispatch command event with instruction when generate clicked', async () => {
const user = userEvent.setup()
const { workflowStore } = renderVibePanel({
workflowState: { vibePanelInstruction: 'Generate a workflow' },
})
const handler = vi.fn()
document.addEventListener(VIBE_COMMAND_EVENT, handler)
await user.click(screen.getByRole('button', { name: 'appDebug.generate.generate' }))
expect(handler).toHaveBeenCalledTimes(1)
const event = handler.mock.calls[0][0] as CustomEvent<{ dsl?: string }>
expect(event.detail).toEqual({ dsl: workflowStore.getState().vibePanelInstruction })
document.removeEventListener(VIBE_COMMAND_EVENT, handler)
})
it('should close panel when dismiss clicked', async () => {
const user = userEvent.setup()
const { workflowStore } = renderVibePanel({
workflowState: {
vibePanelMermaidCode: 'graph TD',
isVibeGenerating: true,
},
})
await user.click(screen.getByRole('button', { name: 'appDebug.generate.dismiss' }))
const state = workflowStore.getState()
expect(state.showVibePanel).toBe(false)
expect(state.vibePanelMermaidCode).toBe('')
expect(state.isVibeGenerating).toBe(false)
})
it('should dispatch apply event and close panel when apply clicked', async () => {
const user = userEvent.setup()
const flowGraph = createFlowGraph({
nodes: [createMockNode()],
edges: [createMockEdge()],
})
const { workflowStore } = renderVibePanel({
workflowState: { vibePanelMermaidCode: 'graph TD' },
vibeFlowData: createVibeFlowData({
current: flowGraph,
versions: [flowGraph],
}),
})
const handler = vi.fn()
document.addEventListener(VIBE_APPLY_EVENT, handler)
await user.click(screen.getByRole('button', { name: 'workflow.vibe.apply' }))
expect(handler).toHaveBeenCalledTimes(1)
const state = workflowStore.getState()
expect(state.showVibePanel).toBe(false)
expect(state.vibePanelMermaidCode).toBe('')
expect(state.isVibeGenerating).toBe(false)
document.removeEventListener(VIBE_APPLY_EVENT, handler)
})
it('should copy mermaid and notify when copy clicked', async () => {
const user = userEvent.setup()
const flowGraph = createFlowGraph({
nodes: [createMockNode()],
edges: [createMockEdge()],
})
renderVibePanel({
workflowState: { vibePanelMermaidCode: 'graph TD' },
vibeFlowData: createVibeFlowData({
current: flowGraph,
versions: [flowGraph],
}),
})
await user.click(getCopyButton())
expect(mockCopy).toHaveBeenCalledWith('graph TD')
expect(toastNotifySpy).toHaveBeenCalledWith(expect.objectContaining({
type: 'success',
message: 'common.actionMsg.copySuccessfully',
}))
})
})
})

View File

@@ -0,0 +1,298 @@
'use client'
import type { FC } from 'react'
import type { FormValue } from '@/app/components/header/account-setting/model-provider-page/declarations'
import type { CompletionParams, Model } from '@/types/app'
import { RiClipboardLine, RiInformation2Line } from '@remixicon/react'
import copy from 'copy-to-clipboard'
import { useCallback, useEffect, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import ResPlaceholder from '@/app/components/app/configuration/config/automatic/res-placeholder'
import VersionSelector from '@/app/components/app/configuration/config/automatic/version-selector'
import Button from '@/app/components/base/button'
import { Generator } from '@/app/components/base/icons/src/vender/other'
import Loading from '@/app/components/base/loading'
import Modal from '@/app/components/base/modal'
import Textarea from '@/app/components/base/textarea'
import Toast from '@/app/components/base/toast'
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelListAndDefaultModelAndCurrentProviderAndModel } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelParameterModal from '@/app/components/header/account-setting/model-provider-page/model-parameter-modal'
import { ModelModeType } from '@/types/app'
import { VIBE_APPLY_EVENT, VIBE_COMMAND_EVENT } from '../../constants'
import { useStore, useWorkflowStore } from '../../store'
import WorkflowPreview from '../../workflow-preview'
// Single source of truth for the persisted model-selection key.
const AUTO_GEN_MODEL_KEY = 'auto-gen-model'

/**
 * Read the model selection persisted by previous generations.
 * Corrupted or missing localStorage data must never crash the render, so
 * JSON.parse failures are swallowed and treated as "nothing stored".
 */
const readStoredModel = (): Model | null => {
  const raw = localStorage.getItem(AUTO_GEN_MODEL_KEY)
  if (!raw)
    return null
  try {
    return JSON.parse(raw) as Model
  }
  catch {
    return null
  }
}

/**
 * Modal panel for "vibe" workflow generation: the user describes a workflow in
 * natural language, generation is triggered by dispatching VIBE_COMMAND_EVENT
 * (handled elsewhere), the resulting flow graph versions are previewed here,
 * and applying dispatches VIBE_APPLY_EVENT for the canvas to consume.
 */
const VibePanel: FC = () => {
  const { t } = useTranslation()
  const workflowStore = useWorkflowStore()
  const showVibePanel = useStore(s => s.showVibePanel)
  const setShowVibePanel = useStore(s => s.setShowVibePanel)
  const isVibeGenerating = useStore(s => s.isVibeGenerating)
  const setIsVibeGenerating = useStore(s => s.setIsVibeGenerating)
  const vibePanelInstruction = useStore(s => s.vibePanelInstruction)
  const vibePanelMermaidCode = useStore(s => s.vibePanelMermaidCode)
  const setVibePanelMermaidCode = useStore(s => s.setVibePanelMermaidCode)
  const currentFlowGraph = useStore(s => s.currentVibeFlow)
  const versions = useStore(s => s.vibeFlowVersions)
  const currentVersionIndex = useStore(s => s.vibeFlowCurrentIndex)
  const vibePanelPreviewNodes = currentFlowGraph?.nodes || []
  const vibePanelPreviewEdges = currentFlowGraph?.edges || []
  const setVibePanelInstruction = useStore(s => s.setVibePanelInstruction)
  const vibePanelIntent = useStore(s => s.vibePanelIntent)
  const setVibePanelIntent = useStore(s => s.setVibePanelIntent)
  const vibePanelMessage = useStore(s => s.vibePanelMessage)
  const setVibePanelMessage = useStore(s => s.setVibePanelMessage)
  const vibePanelSuggestions = useStore(s => s.vibePanelSuggestions)
  const setVibePanelSuggestions = useStore(s => s.setVibePanelSuggestions)
  // Lazy initializer: localStorage is read/parsed once on mount (previously it
  // was re-read and re-parsed on every render), with a safe fallback when
  // nothing valid is stored.
  const [model, setModel] = useState<Model>(() => readStoredModel() || {
    name: '',
    provider: '',
    mode: ModelModeType.chat,
    completion_params: {} as CompletionParams,
  })
  const { defaultModel } = useModelListAndDefaultModelAndCurrentProviderAndModel(ModelTypeEnum.textGeneration)
  // Once the workspace default model resolves, prefer the persisted choice;
  // otherwise adopt the default while keeping any current completion params.
  useEffect(() => {
    if (defaultModel) {
      const storedModel = readStoredModel()
      if (storedModel) {
        setModel(storedModel)
      }
      else {
        setModel(prev => ({
          ...prev,
          name: defaultModel.model,
          provider: defaultModel.provider.provider,
        }))
      }
    }
  }, [defaultModel])
  // Persist every model change so the next panel open restores it.
  const handleModelChange = useCallback((newValue: { modelId: string, provider: string, mode?: string, features?: string[] }) => {
    const newModel = {
      ...model,
      provider: newValue.provider,
      name: newValue.modelId,
      mode: newValue.mode as ModelModeType,
    }
    setModel(newModel)
    localStorage.setItem(AUTO_GEN_MODEL_KEY, JSON.stringify(newModel))
  }, [model])
  const handleCompletionParamsChange = useCallback((newParams: FormValue) => {
    const newModel = {
      ...model,
      completion_params: newParams as CompletionParams,
    }
    setModel(newModel)
    localStorage.setItem(AUTO_GEN_MODEL_KEY, JSON.stringify(newModel))
  }, [model])
  const handleInstructionChange = useCallback((e: React.ChangeEvent<HTMLTextAreaElement>) => {
    workflowStore.setState(state => ({
      ...state,
      vibePanelInstruction: e.target.value,
    }))
  }, [workflowStore])
  // Closing resets every piece of transient panel state, not just visibility.
  const handleClose = useCallback(() => {
    setShowVibePanel(false)
    setVibePanelMermaidCode('')
    setIsVibeGenerating(false)
    setVibePanelIntent('')
    setVibePanelMessage('')
    setVibePanelSuggestions([])
  }, [setShowVibePanel, setVibePanelMermaidCode, setIsVibeGenerating, setVibePanelIntent, setVibePanelMessage, setVibePanelSuggestions])
  // Generation is handled by a document-level listener elsewhere; the panel
  // only dispatches the command event with the instruction payload.
  const handleGenerate = useCallback(() => {
    const event = new CustomEvent(VIBE_COMMAND_EVENT, {
      detail: { dsl: vibePanelInstruction },
    })
    document.dispatchEvent(event)
  }, [vibePanelInstruction])
  const handleAccept = useCallback(() => {
    const event = new CustomEvent(VIBE_APPLY_EVENT)
    document.dispatchEvent(event)
    handleClose()
  }, [handleClose])
  const handleCopyMermaid = useCallback(() => {
    copy(vibePanelMermaidCode)
    Toast.notify({ type: 'success', message: t('actionMsg.copySuccessfully', { ns: 'common' }) })
  }, [vibePanelMermaidCode, t])
  // Clicking a suggestion adopts it as the instruction and regenerates
  // immediately, so the detail carries the suggestion (not stale store state).
  const handleSuggestionClick = useCallback((suggestion: string) => {
    setVibePanelInstruction(suggestion)
    const event = new CustomEvent(VIBE_COMMAND_EVENT, {
      detail: { dsl: suggestion },
    })
    document.dispatchEvent(event)
  }, [setVibePanelInstruction])
  const handleVersionChange = useCallback((index: number) => {
    const { setVibeFlowCurrentIndex } = workflowStore.getState()
    setVibeFlowCurrentIndex(index)
  }, [workflowStore])
  // Button label - always use "Generate" (refinement mode removed)
  const generateButtonLabel = useMemo(() => {
    return t('generate.generate', { ns: 'appDebug' })
  }, [t])
  if (!showVibePanel)
    return null
  const renderLoading = (
    <div className="flex h-full w-0 grow flex-col items-center justify-center space-y-3">
      <Loading />
      <div className="text-[13px] text-text-tertiary">{t('vibe.generatingFlowchart', { ns: 'workflow' })}</div>
    </div>
  )
  // Shown when the backend classifies the request as unrelated to workflow
  // building; offers the backend-provided suggestions as one-click retries.
  const renderOffTopic = (
    <div className="flex h-full w-0 grow flex-col items-center justify-center bg-background-default-subtle p-6">
      <div className="flex max-w-[400px] flex-col items-center text-center">
        <div className="mb-4 flex h-12 w-12 items-center justify-center rounded-full bg-state-warning-hover">
          <RiInformation2Line className="h-6 w-6 text-text-warning" />
        </div>
        <div className="mb-2 text-base font-semibold text-text-primary">
          {t('vibe.offTopicTitle', { ns: 'workflow' })}
        </div>
        <div className="mb-6 text-sm text-text-secondary">
          {vibePanelMessage || t('vibe.offTopicDefault', { ns: 'workflow' })}
        </div>
        {vibePanelSuggestions.length > 0 && (
          <div className="w-full">
            <div className="mb-3 text-xs font-medium text-text-tertiary">
              {t('vibe.trySuggestion', { ns: 'workflow' })}
            </div>
            <div className="flex flex-col gap-2">
              {vibePanelSuggestions.map((suggestion, index) => (
                <button
                  key={index}
                  onClick={() => handleSuggestionClick(suggestion)}
                  className="w-full rounded-lg border border-divider-regular bg-components-panel-bg px-4 py-2.5 text-left text-sm text-text-secondary transition-colors hover:border-components-button-primary-border hover:bg-state-accent-hover"
                >
                  {suggestion}
                </button>
              ))}
            </div>
          </div>
        )}
      </div>
    </div>
  )
  return (
    <Modal
      isShow={showVibePanel}
      onClose={handleClose}
      className="min-w-[1140px] !p-0"
      clickOutsideNotClose
    >
      <div className="flex h-[680px] flex-wrap">
        {/* Left column: model picker, instruction input, generate/dismiss. */}
        <div className="h-full w-[300px] shrink-0 overflow-y-auto border-r border-divider-regular p-6">
          <div className="mb-5">
            <div className="text-lg font-bold leading-[28px] text-text-primary">{t('gotoAnything.actions.vibeTitle', { ns: 'app' })}</div>
            <div className="mt-1 text-[13px] font-normal text-text-tertiary">{t('gotoAnything.actions.vibeDesc', { ns: 'app' })}</div>
          </div>
          <div>
            <ModelParameterModal
              popupClassName="!w-[520px]"
              portalToFollowElemContentClassName="z-[1000]"
              isAdvancedMode={true}
              provider={model.provider}
              completionParams={model.completion_params}
              modelId={model.name}
              setModel={handleModelChange}
              onCompletionParamsChange={handleCompletionParamsChange}
              hideDebugWithMultipleModel
            />
          </div>
          <div className="mt-4">
            <div className="system-sm-semibold-uppercase mb-1.5 text-text-secondary">{t('generate.instruction', { ns: 'appDebug' })}</div>
            <Textarea
              className="min-h-[240px] resize-none rounded-[10px] px-4 pt-3"
              placeholder={t('vibe.missingInstruction', { ns: 'workflow' })}
              value={vibePanelInstruction}
              onChange={handleInstructionChange}
            />
          </div>
          <div className="mt-7 flex justify-end space-x-2">
            <Button onClick={handleClose}>{t('generate.dismiss', { ns: 'appDebug' })}</Button>
            <Button
              className="flex space-x-1"
              variant="primary"
              onClick={handleGenerate}
              disabled={isVibeGenerating}
            >
              <Generator className="h-4 w-4" />
              <span className="system-xs-semibold">{generateButtonLabel}</span>
            </Button>
          </div>
        </div>
        {/* Right column: exactly one of off-topic notice / preview / loading / placeholder. */}
        {!isVibeGenerating && vibePanelIntent === 'off_topic' && renderOffTopic}
        {!isVibeGenerating && vibePanelIntent !== 'off_topic' && (vibePanelPreviewNodes.length > 0 || vibePanelMermaidCode) && (
          <div className="h-full w-0 grow bg-background-default-subtle p-6 pb-0">
            <div className="flex h-full flex-col">
              <div className="mb-3 flex shrink-0 items-center justify-between">
                <div className="flex shrink-0 flex-col">
                  <div className="system-xl-semibold text-text-secondary">{t('vibe.panelTitle', { ns: 'workflow' })}</div>
                  <VersionSelector
                    versionLen={versions.length}
                    value={currentVersionIndex}
                    onChange={handleVersionChange}
                    contentClassName="z-[1200]"
                  />
                </div>
                <div className="flex items-center space-x-2">
                  <Button
                    variant="secondary"
                    size="medium"
                    onClick={handleCopyMermaid}
                    className="px-2"
                  >
                    <RiClipboardLine className="h-4 w-4" />
                  </Button>
                  <Button
                    variant="primary"
                    size="medium"
                    onClick={handleAccept}
                  >
                    {t('vibe.apply', { ns: 'workflow' })}
                  </Button>
                </div>
              </div>
              <div className="flex grow flex-col overflow-hidden pb-6">
                {/* key forces a remount so the preview re-lays-out per version. */}
                <WorkflowPreview
                  key={currentVersionIndex}
                  nodes={vibePanelPreviewNodes}
                  edges={vibePanelPreviewEdges}
                  viewport={{ x: 0, y: 0, zoom: 1 }}
                  className="rounded-lg border border-divider-subtle"
                />
              </div>
            </div>
          </div>
        )}
        {isVibeGenerating && renderLoading}
        {!isVibeGenerating && vibePanelIntent !== 'off_topic' && vibePanelPreviewNodes.length === 0 && !vibePanelMermaidCode && <ResPlaceholder />}
      </div>
    </Modal>
  )
}
export default VibePanel

View File

@@ -12,6 +12,7 @@ import type { NodeSliceShape } from './node-slice'
import type { PanelSliceShape } from './panel-slice'
import type { ToolSliceShape } from './tool-slice'
import type { VersionSliceShape } from './version-slice'
import type { VibeWorkflowSliceShape } from './vibe-workflow-slice'
import type { WorkflowDraftSliceShape } from './workflow-draft-slice'
import type { WorkflowSliceShape } from './workflow-slice'
import type { RagPipelineSliceShape } from '@/app/components/rag-pipeline/store'
@@ -34,6 +35,7 @@ import { createNodeSlice } from './node-slice'
import { createPanelSlice } from './panel-slice'
import { createToolSlice } from './tool-slice'
import { createVersionSlice } from './version-slice'
import { createVibeWorkflowSlice } from './vibe-workflow-slice'
import { createWorkflowDraftSlice } from './workflow-draft-slice'
import { createWorkflowSlice } from './workflow-slice'
@@ -56,6 +58,7 @@ export type Shape
& InspectVarsSliceShape
& LayoutSliceShape
& SliceFromInjection
& VibeWorkflowSliceShape
export type InjectWorkflowStoreSliceFn = StateCreator<SliceFromInjection>
@@ -80,6 +83,7 @@ export const createWorkflowStore = (params: CreateWorkflowStoreParams) => {
...createWorkflowSlice(...args),
...createInspectVarsSlice(...args),
...createLayoutSlice(...args),
...createVibeWorkflowSlice(...args),
...(injectWorkflowStoreSliceFn?.(...args) || {} as SliceFromInjection),
}))
}

View File

@@ -1,4 +1,7 @@
import type { StateCreator } from 'zustand'
import type { BackendEdgeSpec, BackendNodeSpec } from '@/service/debug'
export type VibeIntent = 'generate' | 'off_topic' | 'error' | ''
export type PanelSliceShape = {
panelWidth: number
@@ -24,6 +27,26 @@ export type PanelSliceShape = {
setShowVariableInspectPanel: (showVariableInspectPanel: boolean) => void
initShowLastRunTab: boolean
setInitShowLastRunTab: (initShowLastRunTab: boolean) => void
showVibePanel: boolean
setShowVibePanel: (showVibePanel: boolean) => void
vibePanelMermaidCode: string
setVibePanelMermaidCode: (vibePanelMermaidCode: string) => void
vibePanelBackendNodes?: BackendNodeSpec[]
setVibePanelBackendNodes: (nodes?: BackendNodeSpec[]) => void
vibePanelBackendEdges?: BackendEdgeSpec[]
setVibePanelBackendEdges: (edges?: BackendEdgeSpec[]) => void
isVibeGenerating: boolean
setIsVibeGenerating: (isVibeGenerating: boolean) => void
vibePanelInstruction: string
setVibePanelInstruction: (vibePanelInstruction: string) => void
vibePanelIntent: VibeIntent
setVibePanelIntent: (vibePanelIntent: VibeIntent) => void
vibePanelMessage: string
setVibePanelMessage: (vibePanelMessage: string) => void
vibePanelSuggestions: string[]
setVibePanelSuggestions: (vibePanelSuggestions: string[]) => void
vibePanelLastWarnings: string[]
setVibePanelLastWarnings: (vibePanelLastWarnings: string[]) => void
}
export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({
@@ -44,4 +67,24 @@ export const createPanelSlice: StateCreator<PanelSliceShape> = set => ({
setShowVariableInspectPanel: showVariableInspectPanel => set(() => ({ showVariableInspectPanel })),
initShowLastRunTab: false,
setInitShowLastRunTab: initShowLastRunTab => set(() => ({ initShowLastRunTab })),
showVibePanel: false,
setShowVibePanel: showVibePanel => set(() => ({ showVibePanel })),
vibePanelMermaidCode: '',
setVibePanelMermaidCode: vibePanelMermaidCode => set(() => ({ vibePanelMermaidCode })),
vibePanelBackendNodes: undefined,
setVibePanelBackendNodes: vibePanelBackendNodes => set(() => ({ vibePanelBackendNodes })),
vibePanelBackendEdges: undefined,
setVibePanelBackendEdges: vibePanelBackendEdges => set(() => ({ vibePanelBackendEdges })),
isVibeGenerating: false,
setIsVibeGenerating: isVibeGenerating => set(() => ({ isVibeGenerating })),
vibePanelInstruction: '',
setVibePanelInstruction: vibePanelInstruction => set(() => ({ vibePanelInstruction })),
vibePanelIntent: '',
setVibePanelIntent: vibePanelIntent => set(() => ({ vibePanelIntent })),
vibePanelMessage: '',
setVibePanelMessage: vibePanelMessage => set(() => ({ vibePanelMessage })),
vibePanelSuggestions: [],
setVibePanelSuggestions: vibePanelSuggestions => set(() => ({ vibePanelSuggestions })),
vibePanelLastWarnings: [],
setVibePanelLastWarnings: vibePanelLastWarnings => set(() => ({ vibePanelLastWarnings })),
})

View File

@@ -0,0 +1,78 @@
import type { StateCreator } from 'zustand'
import type { Edge, Node } from '../../types'
/** A generated workflow graph snapshot (nodes plus connecting edges). */
export type FlowGraph = {
  nodes: Node[]
  edges: Edge[]
}
/**
 * Shape of the vibe-workflow store slice: the mermaid source of the latest
 * generation, the in-flight flag, the user's instruction text, and a version
 * history of generated graphs with the currently selected one.
 */
export type VibeWorkflowSliceShape = {
  vibePanelMermaidCode: string
  setVibePanelMermaidCode: (vibePanelMermaidCode: string) => void
  // True while a generation request is in flight.
  isVibeGenerating: boolean
  setIsVibeGenerating: (isVibeGenerating: boolean) => void
  vibePanelInstruction: string
  setVibePanelInstruction: (vibePanelInstruction: string) => void
  // All graphs generated in this session, oldest first.
  vibeFlowVersions: FlowGraph[]
  setVibeFlowVersions: (versions: FlowGraph[]) => void
  vibeFlowCurrentIndex: number
  setVibeFlowCurrentIndex: (index: number) => void
  // Appends a version and selects it (empty graphs clear the selection).
  addVibeFlowVersion: (version: FlowGraph) => void
  // Derived: the graph at vibeFlowCurrentIndex, kept in sync by the setters.
  currentVibeFlow: FlowGraph | undefined
}
/**
 * Resolve the flow graph selected by `currentIndex`.
 * Empty version list or a negative index yields undefined; an index past the
 * end of the list falls back to the newest (last) version.
 */
const getCurrentVibeFlow = (versions: FlowGraph[], currentIndex: number): FlowGraph | undefined => {
  if (!versions || versions.length === 0)
    return undefined
  const index = currentIndex ?? 0
  if (index < 0)
    return undefined
  return versions[index] || versions[versions.length - 1]
}
/**
 * Store slice for vibe (natural-language) workflow generation: mermaid
 * source, generation flag, instruction text and a version history of
 * generated flow graphs with a current selection.
 *
 * Invariant maintained by every setter: `currentVibeFlow` is always the
 * graph addressed by `vibeFlowCurrentIndex` (or undefined).
 */
export const createVibeWorkflowSlice: StateCreator<VibeWorkflowSliceShape> = (set, get) => ({
  vibePanelMermaidCode: '',
  setVibePanelMermaidCode: vibePanelMermaidCode => set(() => ({ vibePanelMermaidCode })),
  isVibeGenerating: false,
  setIsVibeGenerating: isVibeGenerating => set(() => ({ isVibeGenerating })),
  vibePanelInstruction: '',
  setVibePanelInstruction: vibePanelInstruction => set(() => ({ vibePanelInstruction })),
  vibeFlowVersions: [],
  setVibeFlowVersions: versions => set((state) => {
    const total = versions?.length ?? 0
    // Fix: clamp the persisted index into the new list's bounds. Previously a
    // stale out-of-range (or -1) index survived a version-list replacement,
    // letting vibeFlowCurrentIndex and currentVibeFlow disagree.
    const vibeFlowCurrentIndex = total === 0
      ? 0
      : Math.min(Math.max(state.vibeFlowCurrentIndex ?? 0, 0), total - 1)
    return {
      vibeFlowVersions: versions,
      vibeFlowCurrentIndex,
      currentVibeFlow: getCurrentVibeFlow(versions, vibeFlowCurrentIndex),
    }
  }),
  vibeFlowCurrentIndex: 0,
  setVibeFlowCurrentIndex: (index) => {
    const state = get()
    const versions = state.vibeFlowVersions || []
    if (versions.length === 0) {
      set({ vibeFlowCurrentIndex: 0, currentVibeFlow: undefined })
      return
    }
    // Clamp caller-supplied indices into the valid range.
    const normalizedIndex = Math.min(Math.max(index, 0), versions.length - 1)
    set({ vibeFlowCurrentIndex: normalizedIndex, currentVibeFlow: getCurrentVibeFlow(versions, normalizedIndex) })
  },
  addVibeFlowVersion: (version) => {
    // An empty graph marks an invalid generation: clear the selection
    // instead of recording it as a version.
    if (!version || !version.nodes || version.nodes.length === 0) {
      set({ vibeFlowCurrentIndex: -1, currentVibeFlow: undefined })
      return
    }
    set((state) => {
      const newVersions = [...(state.vibeFlowVersions || []), version]
      const newIndex = newVersions.length - 1
      return {
        vibeFlowVersions: newVersions,
        vibeFlowCurrentIndex: newIndex,
        currentVibeFlow: getCurrentVibeFlow(newVersions, newIndex),
      }
    })
  },
  currentVibeFlow: undefined,
})

View File

@@ -111,8 +111,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
const currentNode = nodes[i] as Node<IterationNodeType | LoopNodeType>
if (currentNode.data.type === BlockEnum.Iteration) {
if (currentNode.data.start_node_id) {
if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_ITERATION_START_NODE)
if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) {
if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_ITERATION_START_NODE)
iterationNodesWithStartNode.push(currentNode)
}
else {
@@ -121,8 +121,8 @@ export const preprocessNodesAndEdges = (nodes: Node[], edges: Edge[]) => {
}
if (currentNode.data.type === BlockEnum.Loop) {
if (currentNode.data.start_node_id) {
if (nodesMap[currentNode.data.start_node_id]?.type !== CUSTOM_LOOP_START_NODE)
if (currentNode.data.start_node_id && nodesMap[currentNode.data.start_node_id]) {
if (nodesMap[currentNode.data.start_node_id].type !== CUSTOM_LOOP_START_NODE)
loopNodesWithStartNode.push(currentNode)
}
else {

View File

@@ -65,7 +65,7 @@ const IfElseNode: FC<NodeProps<IfElseNodeType>> = (props) => {
</div>
<div className="space-y-0.5">
{caseItem.conditions.map((condition, i) => (
<div key={condition.id} className="relative">
<div key={condition.id || i} className="relative">
{
checkIsConditionSet(condition)
? (

View File

@@ -75,6 +75,9 @@
"gotoAnything.actions.themeLightDesc": "Use light appearance",
"gotoAnything.actions.themeSystem": "System Theme",
"gotoAnything.actions.themeSystemDesc": "Follow your OS appearance",
"gotoAnything.actions.vibeDesc": "Generate workflow from natural language",
"gotoAnything.actions.vibeHint": "Try: {{prompt}}",
"gotoAnything.actions.vibeTitle": "Vibe",
"gotoAnything.actions.zenDesc": "Toggle canvas focus mode",
"gotoAnything.actions.zenTitle": "Zen Mode",
"gotoAnything.clearToSearchAll": "Clear @ to search all",

View File

@@ -1047,5 +1047,26 @@
"versionHistory.nameThisVersion": "Name this version",
"versionHistory.releaseNotesPlaceholder": "Describe what changed",
"versionHistory.restorationTip": "After version restoration, the current draft will be overwritten.",
"versionHistory.title": "Versions"
"versionHistory.title": "Versions",
"vibe.apply": "Apply",
"vibe.generateError": "Failed to generate workflow. Please try again.",
"vibe.generatingFlowchart": "Generating flowchart preview...",
"vibe.invalidFlowchart": "The generated flowchart could not be parsed.",
"vibe.missingFlowchart": "No flowchart was generated.",
"vibe.missingInstruction": "Describe the workflow you want to build.",
"vibe.modelUnavailable": "No model available for flowchart generation.",
"vibe.noFlowchart": "No flowchart provided",
"vibe.noFlowchartYet": "No flowchart preview available",
"vibe.nodeTypeUnavailable": "Node type \"{{type}}\" is not available in this workflow.",
"vibe.nodesUnavailable": "Workflow nodes are not available yet.",
"vibe.offTopicDefault": "I'm the Dify workflow design assistant. I can help you create AI automation workflows, but I can't answer general questions. Would you like to create a workflow instead?",
"vibe.offTopicTitle": "Off-Topic Request",
"vibe.panelTitle": "Workflow Preview",
"vibe.readOnly": "This workflow is read-only.",
"vibe.regenerate": "Regenerate",
"vibe.regenerateReminder": "Please verify your input and re-generate.",
"vibe.toolUnavailable": "Tool \"{{tool}}\" is not available in this workspace.",
"vibe.trySuggestion": "Try one of these suggestions:",
"vibe.unknownNodeId": "Node \"{{id}}\" is used before it is defined.",
"vibe.unsupportedEdgeLabel": "Unsupported edge label \"{{label}}\". Only true/false are allowed for if/else."
}

View File

@@ -1047,5 +1047,7 @@
"versionHistory.nameThisVersion": "命名",
"versionHistory.releaseNotesPlaceholder": "请描述变更",
"versionHistory.restorationTip": "版本回滚后,当前草稿将被覆盖。",
"versionHistory.title": "版本"
"versionHistory.title": "版本",
"vibe.refine": "调整",
"vibe.regenerateReminder": "请检查输入并重新生成。"
}

View File

@@ -19,6 +19,48 @@ export type GenRes = {
error?: string
}
/**
 * Backend guidance returned when a requested capability maps to tools the
 * current workspace has not set up. NOTE(review): field semantics inferred
 * from names — confirm against the backend contract.
 */
export type ToolRecommendation = {
  // The capability the user asked for (free-form backend string).
  requested_capability: string
  // Tools matching the capability that exist but are not configured yet.
  unconfigured_tools: Array<{
    provider_id: string
    tool_name: string
    description: string
  }>
  // Already-configured tools that could serve as substitutes.
  configured_alternatives: Array<{
    provider_id: string
    tool_name: string
    description: string
  }>
  // Human-readable recommendation text from the backend.
  recommendation: string
}
/** Minimal node description returned by the flowchart-generation API. */
export type BackendNodeSpec = {
  id: string
  type: string
  title?: string
  // NOTE(review): config shape varies by node type and is defined by the
  // backend — validate before trusting specific keys.
  config?: Record<string, any>
  position?: { x: number, y: number }
}
/** Edge description returned by the flowchart-generation API. */
export type BackendEdgeSpec = {
  // Source/target are node ids (matching BackendNodeSpec.id).
  source: string
  target: string
  sourceHandle?: string
  targetHandle?: string
}
/** Response payload of the `/flowchart-generate` endpoint. */
export type FlowchartGenRes = {
  // How the backend classified the request; absent presumably means a normal
  // generation — TODO confirm with the backend contract.
  intent?: 'generate' | 'off_topic' | 'error'
  // Source text of the generated flowchart (consumed as mermaid by the panel).
  flowchart: string
  // Optional structured graph accompanying the flowchart text.
  nodes?: BackendNodeSpec[]
  edges?: BackendEdgeSpec[]
  // Human-readable message (e.g. shown for off-topic requests).
  message?: string
  warnings?: string[]
  // Alternative prompts offered when the request was off-topic.
  suggestions?: string[]
  tool_recommendations?: ToolRecommendation[]
  error?: string
}
export type CodeGenRes = {
code: string
language: string[]
@@ -93,6 +135,12 @@ export const generateRule = (body: Record<string, any>) => {
})
}
// Request workflow-flowchart generation from the backend for the given
// instruction payload; resolves with the parsed FlowchartGenRes.
export const generateFlowchart = (body: Record<string, any>) =>
  post<FlowchartGenRes>('/flowchart-generate', { body })
export const fetchModelParams = (providerName: string, modelId: string) => {
return get(`workspaces/current/model-providers/${providerName}/models/parameter-rules`, {
params: {