fix: tool entities

This commit is contained in:
Yeuoly
2024-04-01 16:43:10 +08:00
parent df9e2e478f
commit 5b81234db8
2 changed files with 45 additions and 43 deletions

View File

@@ -1,4 +1,4 @@
from typing import Literal, Optional, Union
from typing import Literal, Union
from pydantic import BaseModel, validator
@@ -16,24 +16,22 @@ class ToolEntity(BaseModel):
class ToolNodeData(BaseNodeData, ToolEntity):
class ToolInput(BaseModel):
value_type: Literal['variable', 'static']
static_value: Optional[Union[int, float, str]]
variable_value: Optional[Union[str, list[str]]]
parameter_name: str
type: Literal['mixed', 'variable', 'constant']
value: Union[ToolParameterValue, list[str]]
@validator('value_type', pre=True, always=True)
def check_value_type(cls, value, values):
if value == 'variable':
# check if template_value is None
if values.get('variable_value') is not None:
raise ValueError('template_value must be None for value_type variable')
elif value == 'static':
# check if static_value is None
if values.get('static_value') is None:
raise ValueError('static_value must be provided for value_type static')
@validator('type', pre=True, always=True)
def check_type(cls, value, values):
typ = value
value = values.get('value')
if typ == 'mixed' and not isinstance(value, str):
raise ValueError('value must be a string')
elif typ == 'variable' and not isinstance(value, list):
raise ValueError('value must be a list')
elif typ == 'constant' and not isinstance(value, ToolParameterValue):
raise ValueError('value must be a string, int, float, or bool')
return value
"""
Tool Node Schema
"""
tool_parameters: list[ToolInput]
tool_parameters: dict[str, ToolInput]

View File

@@ -88,24 +88,27 @@ class ToolNode(BaseNode):
Generate parameters
"""
result = {}
for parameter in node_data.tool_parameters:
if parameter.value_type == 'static':
result[parameter.parameter_name] = parameter.static_value
else:
if isinstance(parameter.variable_value, str):
parser = VariableTemplateParser(parameter.variable_value)
variable_selectors = parser.extract_variable_selectors()
values = {
selector.variable: variable_pool.get_variable_value(selector)
for selector in variable_selectors
}
# if multiple values, use the parser to format the values into a string
result[parameter.parameter_name] = parser.format(values)
elif isinstance(parameter.variable_value, list):
result[parameter.parameter_name] = variable_pool.get_variable_value(parameter.variable_value)
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
result[parameter_name] = self._format_variable_template(input.value, variable_pool)
elif input.type == 'variable':
result[parameter_name] = variable_pool.get_variable_value(input.value)
elif input.type == 'constant':
result[parameter_name] = input.value
return result
def _format_variable_template(self, template: str, variable_pool: VariablePool) -> str:
"""
Format variable template
"""
inputs = {}
template_parser = VariableTemplateParser(template)
for selector in template_parser.extract_variable_selectors():
inputs[selector.variable] = variable_pool.get_variable_value(selector.value_selector)
return template_parser.format(inputs)
def _convert_tool_messages(self, messages: list[ToolInvokeMessage]) -> tuple[str, list[FileVar]]:
"""
@@ -184,14 +187,15 @@ class ToolNode(BaseNode):
:return:
"""
result = {}
for parameter in node_data.tool_parameters:
if parameter.value_type == 'variable':
if isinstance(parameter.variable_value, str):
parser = VariableTemplateParser(parameter.variable_value)
variable_selectors = parser.extract_variable_selectors()
for selector in variable_selectors:
result[selector.variable] = selector.value_selector
elif isinstance(parameter.variable_value, list):
result[parameter.parameter_name] = parameter.variable_value
for parameter_name in node_data.tool_parameters:
input = node_data.tool_parameters[parameter_name]
if input.type == 'mixed':
selectors = VariableTemplateParser(input.value).extract_variable_selectors()
for selector in selectors:
result[selector.variable] = selector.value_selector
elif input.type == 'variable':
result[parameter_name] = input.value
elif input.type == 'constant':
pass
return result
return result