mirror of
https://github.com/langgenius/dify.git
synced 2026-01-08 07:14:14 +00:00
fix: tool entities
This commit is contained in:
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Optional, Union
|
||||
from typing import Literal, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
@@ -16,24 +16,22 @@ class ToolEntity(BaseModel):
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
class ToolInput(BaseModel):
|
||||
value_type: Literal['variable', 'static']
|
||||
static_value: Optional[Union[int, float, str]]
|
||||
variable_value: Optional[Union[str, list[str]]]
|
||||
parameter_name: str
|
||||
type: Literal['mixed', 'variable', 'constant']
|
||||
value: Union[ToolParameterValue, list[str]]
|
||||
|
||||
@validator('value_type', pre=True, always=True)
|
||||
def check_value_type(cls, value, values):
|
||||
if value == 'variable':
|
||||
# check if template_value is None
|
||||
if values.get('variable_value') is not None:
|
||||
raise ValueError('template_value must be None for value_type variable')
|
||||
elif value == 'static':
|
||||
# check if static_value is None
|
||||
if values.get('static_value') is None:
|
||||
raise ValueError('static_value must be provided for value_type static')
|
||||
@validator('type', pre=True, always=True)
|
||||
def check_type(cls, value, values):
|
||||
typ = value
|
||||
value = values.get('value')
|
||||
if typ == 'mixed' and not isinstance(value, str):
|
||||
raise ValueError('value must be a string')
|
||||
elif typ == 'variable' and not isinstance(value, list):
|
||||
raise ValueError('value must be a list')
|
||||
elif typ == 'constant' and not isinstance(value, ToolParameterValue):
|
||||
raise ValueError('value must be a string, int, float, or bool')
|
||||
return value
|
||||
|
||||
|
||||
"""
|
||||
Tool Node Schema
|
||||
"""
|
||||
tool_parameters: list[ToolInput]
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
|
||||
@@ -88,24 +88,27 @@ class ToolNode(BaseNode):
|
||||
Generate parameters
|
||||
"""
|
||||
result = {}
|
||||
for parameter in node_data.tool_parameters:
|
||||
if parameter.value_type == 'static':
|
||||
result[parameter.parameter_name] = parameter.static_value
|
||||
else:
|
||||
if isinstance(parameter.variable_value, str):
|
||||
parser = VariableTemplateParser(parameter.variable_value)
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
values = {
|
||||
selector.variable: variable_pool.get_variable_value(selector)
|
||||
for selector in variable_selectors
|
||||
}
|
||||
|
||||
# if multiple values, use the parser to format the values into a string
|
||||
result[parameter.parameter_name] = parser.format(values)
|
||||
elif isinstance(parameter.variable_value, list):
|
||||
result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value)
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = variable_pool.get_variable_value(input.value)
|
||||
elif input.type == 'constant':
|
||||
result[parameter_name] = input.value
|
||||
|
||||
return result
|
||||
|
||||
def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
|
||||
"""
|
||||
Format variable template
|
||||
"""
|
||||
inputs = {}
|
||||
template_parser = VariableTemplateParser(template)
|
||||
for selector in template_parser.extract_variable_selectors():
|
||||
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
|
||||
|
||||
return template_parser.format(inputs)
|
||||
|
||||
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
|
||||
"""
|
||||
@@ -184,14 +187,15 @@ class ToolNode(BaseNode):
|
||||
:return:
|
||||
"""
|
||||
result = {}
|
||||
for parameter in node_data.tool_parameters:
|
||||
if parameter.value_type == 'variable':
|
||||
if isinstance(parameter.variable_value, str):
|
||||
parser = VariableTemplateParser(parameter.variable_value)
|
||||
variable_selectors = parser.extract_variable_selectors()
|
||||
for selector in variable_selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif isinstance(parameter.variable_value, list):
|
||||
result[parameter.parameter_name] = parameter.variable_value
|
||||
for parameter_name in node_data.tool_parameters:
|
||||
input = node_data.tool_parameters[parameter_name]
|
||||
if input.type == 'mixed':
|
||||
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == 'variable':
|
||||
result[parameter_name] = input.value
|
||||
elif input.type == 'constant':
|
||||
pass
|
||||
|
||||
return result
|
||||
return result
|
||||
|
||||
Reference in New Issue
Block a user