mirror of
https://github.com/langgenius/dify.git
synced 2026-03-23 16:57:10 +00:00
Compare commits
3 Commits
feat/evalu
...
scdeng/mai
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9f8e1f9cc0 | ||
|
|
dd03e2fe2a | ||
|
|
8b082c13d3 |
@@ -1,6 +1,9 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
@@ -36,6 +39,11 @@ from .exc import (
|
||||
)
|
||||
from .protocols import TemplateRenderer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
VARIABLE_PATTERN = re.compile(r"\{\{#[^#]+#\}\}")
|
||||
MAX_RESOLVED_VALUE_LENGTH = 1024
|
||||
|
||||
|
||||
def fetch_model_schema(*, model_instance: ModelInstance) -> AIModelEntity:
|
||||
model_schema = cast(LargeLanguageModel, model_instance.model_type_instance).get_model_schema(
|
||||
@@ -475,3 +483,61 @@ def _append_file_prompts(
|
||||
prompt_messages[-1] = UserPromptMessage(content=file_prompts + existing_contents)
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
|
||||
def _coerce_resolved_value(raw: str) -> int | float | bool | str:
|
||||
"""Try to restore the original type from a resolved template string.
|
||||
|
||||
Variable references are always resolved to text, but completion params may
|
||||
expect numeric or boolean values (e.g. a variable that holds "0.7" mapped to
|
||||
the ``temperature`` parameter). This helper attempts a JSON parse so that
|
||||
``"0.7"`` → ``0.7``, ``"true"`` → ``True``, etc. Plain strings that are not
|
||||
valid JSON literals are returned as-is.
|
||||
"""
|
||||
stripped = raw.strip()
|
||||
if not stripped:
|
||||
return raw
|
||||
|
||||
try:
|
||||
parsed: object = json.loads(stripped)
|
||||
except (json.JSONDecodeError, ValueError):
|
||||
return raw
|
||||
|
||||
if isinstance(parsed, (int, float, bool)):
|
||||
return parsed
|
||||
return raw
|
||||
|
||||
|
||||
def resolve_completion_params_variables(
|
||||
completion_params: Mapping[str, Any],
|
||||
variable_pool: VariablePool,
|
||||
) -> dict[str, Any]:
|
||||
"""Resolve variable references (``{{#node_id.var#}}``) in string-typed completion params.
|
||||
|
||||
Security notes:
|
||||
- Resolved values are length-capped to ``MAX_RESOLVED_VALUE_LENGTH`` to
|
||||
prevent denial-of-service through excessively large variable payloads.
|
||||
- This follows the same ``VariablePool.convert_template`` pattern used across
|
||||
Dify (Answer Node, HTTP Request Node, Agent Node, etc.). The downstream
|
||||
model plugin receives these values as structured JSON key-value pairs — they
|
||||
are never concatenated into raw HTTP headers or SQL queries.
|
||||
- Numeric/boolean coercion is applied so that variables holding ``"0.7"`` are
|
||||
restored to their native type rather than sent as a bare string.
|
||||
"""
|
||||
resolved: dict[str, Any] = {}
|
||||
for key, value in completion_params.items():
|
||||
if isinstance(value, str) and VARIABLE_PATTERN.search(value):
|
||||
segment_group = variable_pool.convert_template(value)
|
||||
text = segment_group.text
|
||||
if len(text) > MAX_RESOLVED_VALUE_LENGTH:
|
||||
logger.warning(
|
||||
"Resolved value for param '%s' truncated from %d to %d chars",
|
||||
key,
|
||||
len(text),
|
||||
MAX_RESOLVED_VALUE_LENGTH,
|
||||
)
|
||||
text = text[:MAX_RESOLVED_VALUE_LENGTH]
|
||||
resolved[key] = _coerce_resolved_value(text)
|
||||
else:
|
||||
resolved[key] = value
|
||||
return resolved
|
||||
|
||||
@@ -202,6 +202,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
# fetch model config
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
model_name = model_instance.model_name
|
||||
model_provider = model_instance.provider
|
||||
model_stop = model_instance.stop
|
||||
|
||||
@@ -164,6 +164,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
if not isinstance(model_instance.model_type_instance, LargeLanguageModel):
|
||||
raise InvalidModelTypeError("Model is not a Large Language Model")
|
||||
|
||||
|
||||
@@ -114,6 +114,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
variables = {"query": query}
|
||||
# fetch model instance
|
||||
model_instance = self._model_instance
|
||||
# Resolve variable references in string-typed completion params
|
||||
model_instance.parameters = llm_utils.resolve_completion_params_variables(
|
||||
model_instance.parameters, variable_pool
|
||||
)
|
||||
memory = self._memory
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
|
||||
@@ -3,7 +3,11 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from dify_graph.model_runtime.entities import ImagePromptMessageContent, PromptMessageRole, TextPromptMessageContent
|
||||
from dify_graph.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from dify_graph.model_runtime.entities.message_entities import SystemPromptMessage
|
||||
from dify_graph.nodes.llm import llm_utils
|
||||
from dify_graph.nodes.llm.entities import LLMNodeChatModelMessage
|
||||
@@ -11,6 +15,15 @@ from dify_graph.nodes.llm.exc import NoPromptFoundError
|
||||
from dify_graph.runtime import VariablePool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def variable_pool() -> VariablePool:
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["node1", "output"], "resolved_value")
|
||||
pool.add(["node2", "text"], "hello world")
|
||||
pool.add(["start", "user_input"], "dynamic_param")
|
||||
return pool
|
||||
|
||||
|
||||
def _fetch_prompt_messages_with_mocked_content(content):
|
||||
variable_pool = VariablePool.empty()
|
||||
model_instance = mock.MagicMock(spec=ModelInstance)
|
||||
@@ -53,6 +66,159 @@ def _fetch_prompt_messages_with_mocked_content(content):
|
||||
)
|
||||
|
||||
|
||||
class TestTypeCoercionViaResolve:
|
||||
"""Type coercion is tested through the public resolve_completion_params_variables API."""
|
||||
|
||||
def test_numeric_string_coerced_to_float(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], "0.7")
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
|
||||
assert result["p"] == 0.7
|
||||
|
||||
def test_integer_string_coerced_to_int(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], "1024")
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
|
||||
assert result["p"] == 1024
|
||||
|
||||
def test_boolean_string_coerced_to_bool(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], "true")
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
|
||||
assert result["p"] is True
|
||||
|
||||
def test_plain_string_stays_string(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], "json_object")
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
|
||||
assert result["p"] == "json_object"
|
||||
|
||||
def test_json_object_string_stays_string(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], '{"key": "val"}')
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "{{#n.v#}}"}, pool)
|
||||
assert result["p"] == '{"key": "val"}'
|
||||
|
||||
def test_mixed_text_and_variable_stays_string(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["n", "v"], "0.7")
|
||||
result = llm_utils.resolve_completion_params_variables({"p": "val={{#n.v#}}"}, pool)
|
||||
assert result["p"] == "val=0.7"
|
||||
|
||||
|
||||
class TestResolveCompletionParamsVariables:
|
||||
def test_plain_string_values_unchanged(self, variable_pool: VariablePool):
|
||||
params = {"response_format": "json", "custom_param": "static_value"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"response_format": "json", "custom_param": "static_value"}
|
||||
|
||||
def test_numeric_values_unchanged(self, variable_pool: VariablePool):
|
||||
params = {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"temperature": 0.7, "top_p": 0.9, "max_tokens": 1024}
|
||||
|
||||
def test_boolean_values_unchanged(self, variable_pool: VariablePool):
|
||||
params = {"stream": True, "echo": False}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"stream": True, "echo": False}
|
||||
|
||||
def test_list_values_unchanged(self, variable_pool: VariablePool):
|
||||
params = {"stop": ["Human:", "Assistant:"]}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"stop": ["Human:", "Assistant:"]}
|
||||
|
||||
def test_single_variable_reference_resolved(self, variable_pool: VariablePool):
|
||||
params = {"response_format": "{{#node1.output#}}"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"response_format": "resolved_value"}
|
||||
|
||||
def test_multiple_variable_references_resolved(self, variable_pool: VariablePool):
|
||||
params = {
|
||||
"param_a": "{{#node1.output#}}",
|
||||
"param_b": "{{#node2.text#}}",
|
||||
}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"param_a": "resolved_value", "param_b": "hello world"}
|
||||
|
||||
def test_mixed_text_and_variable_resolved(self, variable_pool: VariablePool):
|
||||
params = {"prompt_prefix": "prefix_{{#node1.output#}}_suffix"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"prompt_prefix": "prefix_resolved_value_suffix"}
|
||||
|
||||
def test_mixed_params_types(self, variable_pool: VariablePool):
|
||||
"""Non-string params pass through; string params with variables get resolved."""
|
||||
params = {
|
||||
"temperature": 0.7,
|
||||
"response_format": "{{#node1.output#}}",
|
||||
"custom_string": "no_vars_here",
|
||||
"max_tokens": 512,
|
||||
"stop": ["\n"],
|
||||
}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {
|
||||
"temperature": 0.7,
|
||||
"response_format": "resolved_value",
|
||||
"custom_string": "no_vars_here",
|
||||
"max_tokens": 512,
|
||||
"stop": ["\n"],
|
||||
}
|
||||
|
||||
def test_empty_params(self, variable_pool: VariablePool):
|
||||
result = llm_utils.resolve_completion_params_variables({}, variable_pool)
|
||||
|
||||
assert result == {}
|
||||
|
||||
def test_unresolvable_variable_keeps_selector_text(self):
|
||||
"""When a referenced variable doesn't exist in the pool, convert_template
|
||||
falls back to the raw selector path (e.g. 'nonexistent.var')."""
|
||||
pool = VariablePool.empty()
|
||||
params = {"format": "{{#nonexistent.var#}}"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, pool)
|
||||
|
||||
assert result["format"] == "nonexistent.var"
|
||||
|
||||
def test_multiple_variables_in_single_value(self, variable_pool: VariablePool):
|
||||
params = {"combined": "{{#node1.output#}} and {{#node2.text#}}"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, variable_pool)
|
||||
|
||||
assert result == {"combined": "resolved_value and hello world"}
|
||||
|
||||
def test_original_params_not_mutated(self, variable_pool: VariablePool):
|
||||
original = {"response_format": "{{#node1.output#}}", "temperature": 0.5}
|
||||
original_copy = dict(original)
|
||||
|
||||
_ = llm_utils.resolve_completion_params_variables(original, variable_pool)
|
||||
|
||||
assert original == original_copy
|
||||
|
||||
def test_long_value_truncated(self):
|
||||
pool = VariablePool.empty()
|
||||
pool.add(["node1", "big"], "x" * 2000)
|
||||
params = {"param": "{{#node1.big#}}"}
|
||||
|
||||
result = llm_utils.resolve_completion_params_variables(params, pool)
|
||||
|
||||
assert len(result["param"]) == llm_utils.MAX_RESOLVED_VALUE_LENGTH
|
||||
|
||||
|
||||
def test_fetch_prompt_messages_skips_messages_when_all_contents_are_filtered_out():
|
||||
with pytest.raises(NoPromptFoundError):
|
||||
_fetch_prompt_messages_with_mocked_content(
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
import { fireEvent, render, screen, waitFor } from '@testing-library/react'
|
||||
import ModelParameterModal from '../index'
|
||||
|
||||
let isAPIKeySet = true
|
||||
let parameterRules: Array<Record<string, unknown>> | undefined = [
|
||||
{
|
||||
name: 'temperature',
|
||||
@@ -40,7 +39,7 @@ let activeTextGenerationModelList: Array<Record<string, unknown>> = [
|
||||
|
||||
vi.mock('@/context/provider-context', () => ({
|
||||
useProviderContext: () => ({
|
||||
isAPIKeySet,
|
||||
isAPIKeySet: true,
|
||||
}),
|
||||
}))
|
||||
|
||||
@@ -50,6 +49,7 @@ vi.mock('@/service/use-common', () => ({
|
||||
data: parameterRules,
|
||||
},
|
||||
isLoading: isRulesLoading,
|
||||
isPending: isRulesLoading,
|
||||
}),
|
||||
}))
|
||||
|
||||
@@ -62,12 +62,18 @@ vi.mock('../../hooks', () => ({
|
||||
}))
|
||||
|
||||
vi.mock('../parameter-item', () => ({
|
||||
default: ({ parameterRule, onChange, onSwitch }: {
|
||||
default: ({ parameterRule, onChange, onSwitch, nodesOutputVars, availableNodes }: {
|
||||
parameterRule: { name: string, label: { en_US: string } }
|
||||
onChange: (v: number) => void
|
||||
onSwitch: (checked: boolean, val: unknown) => void
|
||||
nodesOutputVars?: unknown[]
|
||||
availableNodes?: unknown[]
|
||||
}) => (
|
||||
<div data-testid={`param-${parameterRule.name}`}>
|
||||
<div
|
||||
data-testid={`param-${parameterRule.name}`}
|
||||
data-has-nodes-output-vars={!!nodesOutputVars}
|
||||
data-has-available-nodes={!!availableNodes}
|
||||
>
|
||||
{parameterRule.label.en_US}
|
||||
<button onClick={() => onChange(0.9)}>Change</button>
|
||||
<button onClick={() => onSwitch(false, undefined)}>Remove</button>
|
||||
@@ -119,7 +125,6 @@ describe('ModelParameterModal', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
isAPIKeySet = true
|
||||
isRulesLoading = false
|
||||
parameterRules = [
|
||||
{
|
||||
@@ -233,6 +238,26 @@ describe('ModelParameterModal', () => {
|
||||
expect(screen.getByTestId('model-selector')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should pass nodesOutputVars and availableNodes to ParameterItem', () => {
|
||||
const mockNodesOutputVars = [{ nodeId: 'n1', title: 'Node', vars: [] }]
|
||||
const mockAvailableNodes = [{ id: 'n1', data: { title: 'Node', type: 'llm' } }]
|
||||
|
||||
render(
|
||||
<ModelParameterModal
|
||||
{...defaultProps}
|
||||
isInWorkflow
|
||||
nodesOutputVars={mockNodesOutputVars as never}
|
||||
availableNodes={mockAvailableNodes as never}
|
||||
/>,
|
||||
)
|
||||
|
||||
fireEvent.click(screen.getByText('Open Settings'))
|
||||
|
||||
const paramEl = screen.getByTestId('param-temperature')
|
||||
expect(paramEl).toHaveAttribute('data-has-nodes-output-vars', 'true')
|
||||
expect(paramEl).toHaveAttribute('data-has-available-nodes', 'true')
|
||||
})
|
||||
|
||||
it('should support custom triggers, workflow mode, and missing default model values', async () => {
|
||||
render(
|
||||
<ModelParameterModal
|
||||
|
||||
@@ -1,5 +1,10 @@
|
||||
import type { ModelParameterRule } from '../../declarations'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
} from '@/app/components/workflow/types'
|
||||
import { fireEvent, render, screen } from '@testing-library/react'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import ParameterItem from '../parameter-item'
|
||||
|
||||
vi.mock('../../hooks', () => ({
|
||||
@@ -18,6 +23,29 @@ vi.mock('@/app/components/base/tag-input', () => ({
|
||||
),
|
||||
}))
|
||||
|
||||
let promptEditorOnChange: ((text: string) => void) | undefined
|
||||
let capturedWorkflowNodesMap: Record<string, { title: string, type: string }> | undefined
|
||||
|
||||
vi.mock('@/app/components/base/prompt-editor', () => ({
|
||||
default: ({ value, onChange, workflowVariableBlock }: {
|
||||
value: string
|
||||
onChange: (text: string) => void
|
||||
workflowVariableBlock?: {
|
||||
show: boolean
|
||||
variables: NodeOutPutVar[]
|
||||
workflowNodesMap?: Record<string, { title: string, type: string }>
|
||||
}
|
||||
}) => {
|
||||
promptEditorOnChange = onChange
|
||||
capturedWorkflowNodesMap = workflowVariableBlock?.workflowNodesMap
|
||||
return (
|
||||
<div data-testid="prompt-editor" data-value={value} data-has-workflow-vars={!!workflowVariableBlock?.variables}>
|
||||
{value}
|
||||
</div>
|
||||
)
|
||||
},
|
||||
}))
|
||||
|
||||
describe('ParameterItem', () => {
|
||||
const createRule = (overrides: Partial<ModelParameterRule> = {}): ModelParameterRule => ({
|
||||
name: 'temp',
|
||||
@@ -30,9 +58,10 @@ describe('ParameterItem', () => {
|
||||
|
||||
beforeEach(() => {
|
||||
vi.clearAllMocks()
|
||||
promptEditorOnChange = undefined
|
||||
capturedWorkflowNodesMap = undefined
|
||||
})
|
||||
|
||||
// Float tests
|
||||
it('should render float controls and clamp numeric input to max', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 1 })} value={0.7} onChange={onChange} />)
|
||||
@@ -50,7 +79,6 @@ describe('ParameterItem', () => {
|
||||
expect(onChange).toHaveBeenCalledWith(0.1)
|
||||
})
|
||||
|
||||
// Int tests
|
||||
it('should render int controls and clamp numeric input', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'int', min: 0, max: 10 })} value={5} onChange={onChange} />)
|
||||
@@ -75,22 +103,17 @@ describe('ParameterItem', () => {
|
||||
it('should render int input without slider if min or max is missing', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'int', min: 0 })} value={5} />)
|
||||
expect(screen.queryByRole('slider')).not.toBeInTheDocument()
|
||||
// No max -> precision step
|
||||
expect(screen.getByRole('spinbutton')).toHaveAttribute('step', '0')
|
||||
})
|
||||
|
||||
// Slider events (uses generic value mock for slider)
|
||||
it('should handle slide change and clamp values', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 10 })} value={0.7} onChange={onChange} />)
|
||||
|
||||
// Test that the actual slider triggers the onChange logic correctly
|
||||
// The implementation of Slider uses onChange(val) directly via the mock
|
||||
fireEvent.click(screen.getByTestId('slider-btn'))
|
||||
expect(onChange).toHaveBeenCalledWith(2)
|
||||
})
|
||||
|
||||
// Text & String tests
|
||||
it('should render exact string input and propagate text changes', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'string', name: 'prompt' })} value="initial" onChange={onChange} />)
|
||||
@@ -109,21 +132,17 @@ describe('ParameterItem', () => {
|
||||
|
||||
it('should render select for string with options', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'string', options: ['a', 'b'] })} value="a" />)
|
||||
// Select renders the selected value in the trigger
|
||||
expect(screen.getByText('a')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Tag Tests
|
||||
it('should render tag input for tag type', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'tag', tagPlaceholder: { en_US: 'placeholder', zh_Hans: 'placeholder' } })} value={['a']} onChange={onChange} />)
|
||||
expect(screen.getByText('placeholder')).toBeInTheDocument()
|
||||
// Trigger mock tag input
|
||||
fireEvent.click(screen.getByTestId('tag-input'))
|
||||
expect(onChange).toHaveBeenCalledWith(['tag1', 'tag2'])
|
||||
})
|
||||
|
||||
// Boolean tests
|
||||
it('should render boolean radios and update value on click', () => {
|
||||
const onChange = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'boolean', default: false })} value={true} onChange={onChange} />)
|
||||
@@ -131,7 +150,6 @@ describe('ParameterItem', () => {
|
||||
expect(onChange).toHaveBeenCalledWith(false)
|
||||
})
|
||||
|
||||
// Switch tests
|
||||
it('should call onSwitch with current value when optional switch is toggled off', () => {
|
||||
const onSwitch = vi.fn()
|
||||
render(<ParameterItem parameterRule={createRule()} value={0.7} onSwitch={onSwitch} />)
|
||||
@@ -146,7 +164,6 @@ describe('ParameterItem', () => {
|
||||
expect(screen.queryByRole('switch')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
// Default Value Fallbacks (rendering without value)
|
||||
it('should use default values if value is undefined', () => {
|
||||
const { rerender } = render(<ParameterItem parameterRule={createRule({ type: 'float', default: 0.5 })} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveValue(0.5)
|
||||
@@ -158,26 +175,102 @@ describe('ParameterItem', () => {
|
||||
expect(screen.getByText('True')).toBeInTheDocument()
|
||||
expect(screen.getByText('False')).toBeInTheDocument()
|
||||
|
||||
// Without default
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'float' })} />) // min is 0 by default in createRule
|
||||
rerender(<ParameterItem parameterRule={createRule({ type: 'float' })} />)
|
||||
expect(screen.getByRole('spinbutton')).toHaveValue(0)
|
||||
})
|
||||
|
||||
// Input Blur
|
||||
it('should reset input to actual bound value on blur', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'float', min: 0, max: 1 })} />)
|
||||
const input = screen.getByRole('spinbutton')
|
||||
// change local state (which triggers clamp internally to let's say 1.4 -> 1 but leaves input text, though handleInputChange updates local state)
|
||||
// Actually our test fires a change so localValue = 1, then blur sets it
|
||||
fireEvent.change(input, { target: { value: '5' } })
|
||||
fireEvent.blur(input)
|
||||
expect(input).toHaveValue(1)
|
||||
})
|
||||
|
||||
// Unsupported
|
||||
it('should render no input for unsupported parameter type', () => {
|
||||
render(<ParameterItem parameterRule={createRule({ type: 'unsupported' as unknown as string })} value={0.7} />)
|
||||
expect(screen.queryByRole('textbox')).not.toBeInTheDocument()
|
||||
expect(screen.queryByRole('spinbutton')).not.toBeInTheDocument()
|
||||
})
|
||||
|
||||
describe('workflow variable reference', () => {
|
||||
const mockNodesOutputVars: NodeOutPutVar[] = [
|
||||
{ nodeId: 'node1', title: 'LLM Node', vars: [] },
|
||||
]
|
||||
const mockAvailableNodes: Node[] = [
|
||||
{ id: 'node1', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'LLM Node', type: BlockEnum.LLM } } as Node,
|
||||
{ id: 'start', type: 'custom', position: { x: 0, y: 0 }, data: { title: 'Start', type: BlockEnum.Start } } as Node,
|
||||
]
|
||||
|
||||
it('should build workflowNodesMap and render PromptEditor for string type', () => {
|
||||
const onChange = vi.fn()
|
||||
render(
|
||||
<ParameterItem
|
||||
parameterRule={createRule({ type: 'string', name: 'system_prompt' })}
|
||||
value="hello {{#node1.output#}}"
|
||||
onChange={onChange}
|
||||
isInWorkflow
|
||||
nodesOutputVars={mockNodesOutputVars}
|
||||
availableNodes={mockAvailableNodes}
|
||||
/>,
|
||||
)
|
||||
|
||||
const editor = screen.getByTestId('prompt-editor')
|
||||
expect(editor).toBeInTheDocument()
|
||||
expect(editor).toHaveAttribute('data-has-workflow-vars', 'true')
|
||||
expect(capturedWorkflowNodesMap).toBeDefined()
|
||||
expect(capturedWorkflowNodesMap!.node1.title).toBe('LLM Node')
|
||||
expect(capturedWorkflowNodesMap!.sys.title).toBe('workflow.blocks.start')
|
||||
expect(capturedWorkflowNodesMap!.sys.type).toBe(BlockEnum.Start)
|
||||
|
||||
promptEditorOnChange?.('updated text')
|
||||
expect(onChange).toHaveBeenCalledWith('updated text')
|
||||
})
|
||||
|
||||
it('should build workflowNodesMap and render PromptEditor for text type', () => {
|
||||
const onChange = vi.fn()
|
||||
render(
|
||||
<ParameterItem
|
||||
parameterRule={createRule({ type: 'text', name: 'user_prompt' })}
|
||||
value="some long text"
|
||||
onChange={onChange}
|
||||
isInWorkflow
|
||||
nodesOutputVars={mockNodesOutputVars}
|
||||
availableNodes={mockAvailableNodes}
|
||||
/>,
|
||||
)
|
||||
|
||||
const editor = screen.getByTestId('prompt-editor')
|
||||
expect(editor).toBeInTheDocument()
|
||||
expect(editor).toHaveAttribute('data-has-workflow-vars', 'true')
|
||||
expect(capturedWorkflowNodesMap).toBeDefined()
|
||||
|
||||
promptEditorOnChange?.('new long text')
|
||||
expect(onChange).toHaveBeenCalledWith('new long text')
|
||||
})
|
||||
|
||||
it('should fall back to plain input when not in workflow mode for string type', () => {
|
||||
render(
|
||||
<ParameterItem
|
||||
parameterRule={createRule({ type: 'string', name: 'system_prompt' })}
|
||||
value="plain"
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(screen.queryByTestId('prompt-editor')).not.toBeInTheDocument()
|
||||
expect(screen.getByRole('textbox')).toBeInTheDocument()
|
||||
})
|
||||
|
||||
it('should return undefined workflowNodesMap when not in workflow mode', () => {
|
||||
render(
|
||||
<ParameterItem
|
||||
parameterRule={createRule({ type: 'string', name: 'system_prompt' })}
|
||||
value="plain"
|
||||
availableNodes={mockAvailableNodes}
|
||||
/>,
|
||||
)
|
||||
|
||||
expect(capturedWorkflowNodesMap).toBeUndefined()
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -9,6 +9,10 @@ import type {
|
||||
} from '../declarations'
|
||||
import type { ParameterValue } from './parameter-item'
|
||||
import type { TriggerProps } from './trigger'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
} from '@/app/components/workflow/types'
|
||||
import { useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { ArrowNarrowLeft } from '@/app/components/base/icons/src/vender/line/arrows'
|
||||
@@ -45,6 +49,8 @@ export type ModelParameterModalProps = {
|
||||
readonly?: boolean
|
||||
isInWorkflow?: boolean
|
||||
scope?: string
|
||||
nodesOutputVars?: NodeOutPutVar[]
|
||||
availableNodes?: Node[]
|
||||
}
|
||||
|
||||
const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
@@ -61,11 +67,18 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
renderTrigger,
|
||||
readonly,
|
||||
isInWorkflow,
|
||||
nodesOutputVars,
|
||||
availableNodes,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [open, setOpen] = useState(false)
|
||||
const settingsIconRef = useRef<HTMLDivElement>(null)
|
||||
const { data: parameterRulesData, isLoading } = useModelParameterRules(provider, modelId)
|
||||
const {
|
||||
data: parameterRulesData,
|
||||
isPending,
|
||||
isLoading,
|
||||
} = useModelParameterRules(provider, modelId)
|
||||
const isRulesLoading = isPending || isLoading
|
||||
const {
|
||||
currentProvider,
|
||||
currentModel,
|
||||
@@ -191,7 +204,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
}
|
||||
</div>
|
||||
{
|
||||
isLoading
|
||||
isRulesLoading
|
||||
? <div className="py-5"><Loading /></div>
|
||||
: (
|
||||
[
|
||||
@@ -205,6 +218,8 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
onChange={v => handleParamChange(parameter.name, v)}
|
||||
onSwitch={(checked, assignValue) => handleSwitch(parameter.name, checked, assignValue)}
|
||||
isInWorkflow={isInWorkflow}
|
||||
nodesOutputVars={nodesOutputVars}
|
||||
availableNodes={availableNodes}
|
||||
/>
|
||||
))
|
||||
)
|
||||
@@ -213,7 +228,7 @@ const ModelParameterModal: FC<ModelParameterModalProps> = ({
|
||||
)
|
||||
}
|
||||
{
|
||||
!parameterRules.length && isLoading && (
|
||||
!parameterRules.length && isRulesLoading && (
|
||||
<div className="px-4 py-5"><Loading /></div>
|
||||
)
|
||||
}
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
import type { ModelParameterRule } from '../declarations'
|
||||
import { useEffect, useRef, useState } from 'react'
|
||||
import type {
|
||||
Node,
|
||||
NodeOutPutVar,
|
||||
} from '@/app/components/workflow/types'
|
||||
import { useEffect, useMemo, useRef, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import PromptEditor from '@/app/components/base/prompt-editor'
|
||||
import Radio from '@/app/components/base/radio'
|
||||
import Slider from '@/app/components/base/slider'
|
||||
import Switch from '@/app/components/base/switch'
|
||||
import TagInput from '@/app/components/base/tag-input'
|
||||
import { Select, SelectContent, SelectItem, SelectTrigger, SelectValue } from '@/app/components/base/ui/select'
|
||||
import { Tooltip, TooltipContent, TooltipTrigger } from '@/app/components/base/ui/tooltip'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { cn } from '@/utils/classnames'
|
||||
import { useLanguage } from '../hooks'
|
||||
import { isNullOrUndefined } from '../utils'
|
||||
@@ -18,18 +25,43 @@ type ParameterItemProps = {
|
||||
onChange?: (value: ParameterValue) => void
|
||||
onSwitch?: (checked: boolean, assignValue: ParameterValue) => void
|
||||
isInWorkflow?: boolean
|
||||
nodesOutputVars?: NodeOutPutVar[]
|
||||
availableNodes?: Node[]
|
||||
}
|
||||
|
||||
function ParameterItem({
|
||||
parameterRule,
|
||||
value,
|
||||
onChange,
|
||||
onSwitch,
|
||||
isInWorkflow,
|
||||
nodesOutputVars,
|
||||
availableNodes = [],
|
||||
}: ParameterItemProps) {
|
||||
const { t } = useTranslation()
|
||||
const language = useLanguage()
|
||||
const [localValue, setLocalValue] = useState(value)
|
||||
const numberInputRef = useRef<HTMLInputElement>(null)
|
||||
|
||||
const workflowNodesMap = useMemo(() => {
|
||||
if (!isInWorkflow || !availableNodes.length)
|
||||
return undefined
|
||||
|
||||
return availableNodes.reduce<Record<string, Pick<Node['data'], 'title' | 'type'>>>((acc, node) => {
|
||||
acc[node.id] = {
|
||||
title: node.data.title,
|
||||
type: node.data.type,
|
||||
}
|
||||
if (node.data.type === BlockEnum.Start) {
|
||||
acc.sys = {
|
||||
title: t('blocks.start', { ns: 'workflow' }),
|
||||
type: BlockEnum.Start,
|
||||
}
|
||||
}
|
||||
return acc
|
||||
}, {})
|
||||
}, [availableNodes, isInWorkflow, t])
|
||||
|
||||
const getDefaultValue = () => {
|
||||
let defaultValue: ParameterValue
|
||||
|
||||
@@ -196,6 +228,25 @@ function ParameterItem({
|
||||
}
|
||||
|
||||
if (parameterRule.type === 'string' && !parameterRule.options?.length) {
|
||||
if (isInWorkflow && nodesOutputVars) {
|
||||
return (
|
||||
<div className="ml-4 w-[200px] rounded-lg bg-components-input-bg-normal px-2 py-1">
|
||||
<PromptEditor
|
||||
compact
|
||||
className="min-h-[22px] text-[13px]"
|
||||
value={renderValue as string}
|
||||
onChange={(text) => { handleInputChange(text) }}
|
||||
workflowVariableBlock={{
|
||||
show: true,
|
||||
variables: nodesOutputVars,
|
||||
workflowNodesMap,
|
||||
}}
|
||||
editable
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<input
|
||||
className={cn(isInWorkflow ? 'w-[150px]' : 'w-full', 'ml-4 flex h-8 appearance-none items-center rounded-lg bg-components-input-bg-normal px-3 text-components-input-text-filled outline-none system-sm-regular')}
|
||||
@@ -206,6 +257,25 @@ function ParameterItem({
|
||||
}
|
||||
|
||||
if (parameterRule.type === 'text') {
|
||||
if (isInWorkflow && nodesOutputVars) {
|
||||
return (
|
||||
<div className="ml-4 w-full rounded-lg bg-components-input-bg-normal px-2 py-1">
|
||||
<PromptEditor
|
||||
compact
|
||||
className="min-h-[56px] text-[13px]"
|
||||
value={renderValue as string}
|
||||
onChange={(text) => { handleInputChange(text) }}
|
||||
workflowVariableBlock={{
|
||||
show: true,
|
||||
variables: nodesOutputVars,
|
||||
workflowNodesMap,
|
||||
}}
|
||||
editable
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
return (
|
||||
<textarea
|
||||
className="ml-4 h-20 w-full rounded-lg bg-components-input-bg-normal px-1 text-components-input-text-filled system-sm-regular"
|
||||
@@ -215,7 +285,7 @@ function ParameterItem({
|
||||
)
|
||||
}
|
||||
|
||||
if (parameterRule.type === 'string' && !!parameterRule?.options?.length) {
|
||||
if (parameterRule.type === 'string' && !!parameterRule.options?.length) {
|
||||
return (
|
||||
<Select
|
||||
value={renderValue as string}
|
||||
|
||||
@@ -131,6 +131,8 @@ const Panel: FC<NodePanelProps<LLMNodeType>> = ({
|
||||
hideDebugWithMultipleModel
|
||||
debugWithMultipleModel={false}
|
||||
readonly={readOnly}
|
||||
nodesOutputVars={availableVars}
|
||||
availableNodes={availableNodesWithParent}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
|
||||
@@ -75,6 +75,8 @@ const Panel: FC<NodePanelProps<ParameterExtractorNodeType>> = ({
|
||||
hideDebugWithMultipleModel
|
||||
debugWithMultipleModel={false}
|
||||
readonly={readOnly}
|
||||
nodesOutputVars={availableVars}
|
||||
availableNodes={availableNodesWithParent}
|
||||
/>
|
||||
</Field>
|
||||
<Field
|
||||
|
||||
@@ -64,6 +64,8 @@ const Panel: FC<NodePanelProps<QuestionClassifierNodeType>> = ({
|
||||
hideDebugWithMultipleModel
|
||||
debugWithMultipleModel={false}
|
||||
readonly={readOnly}
|
||||
nodesOutputVars={availableVars}
|
||||
availableNodes={availableNodesWithParent}
|
||||
/>
|
||||
</Field>
|
||||
<Field
|
||||
|
||||
Reference in New Issue
Block a user