mirror of
https://github.com/langgenius/dify.git
synced 2026-01-06 06:26:00 +00:00
refactor: tool engine
This commit is contained in:
@@ -123,7 +123,6 @@ class AppApi(Resource):
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
|
||||
@@ -58,7 +58,6 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
manager = ToolParameterConfigurationManager(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
@@ -96,7 +95,6 @@ class ModelConfigResource(Resource):
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
agent_tool=agent_tool_entity,
|
||||
agent_callback=None
|
||||
)
|
||||
except Exception as e:
|
||||
continue
|
||||
|
||||
@@ -10,12 +10,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AgentChatAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
ModelConfigWithCredentialsEntity,
|
||||
)
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.file.message_file_parser import FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@@ -32,7 +30,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolInvokeMessage,
|
||||
ToolInvokeMessageBinary,
|
||||
ToolParameter,
|
||||
ToolRuntimeVariablePool,
|
||||
)
|
||||
@@ -40,7 +37,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageAgentThought, MessageFile
|
||||
from models.model import Message, MessageAgentThought
|
||||
from models.tools import ToolConversationVariables
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -156,7 +153,6 @@ class BaseAgentRunner(AppRunner):
|
||||
tool_entity = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self.tenant_id,
|
||||
agent_tool=tool,
|
||||
agent_callback=self.agent_callback
|
||||
)
|
||||
tool_entity.load_variables(self.variables_pool)
|
||||
|
||||
@@ -270,87 +266,6 @@ class BaseAgentRunner(AppRunner):
|
||||
prompt_tool.parameters['required'].append(parameter.name)
|
||||
|
||||
return prompt_tool
|
||||
|
||||
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and 'mime_type' in response.meta:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:param messages: messages
|
||||
:return: message files, should save as variable
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in messages:
|
||||
file_type = 'bin'
|
||||
if 'image' in message.mimetype:
|
||||
file_type = 'image'
|
||||
elif 'video' in message.mimetype:
|
||||
file_type = 'video'
|
||||
elif 'audio' in message.mimetype:
|
||||
file_type = 'audio'
|
||||
elif 'text' in message.mimetype:
|
||||
file_type = 'text'
|
||||
elif 'pdf' in message.mimetype:
|
||||
file_type = 'pdf'
|
||||
elif 'zip' in message.mimetype:
|
||||
file_type = 'archive'
|
||||
# ...
|
||||
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
|
||||
message_file = MessageFile(
|
||||
message_id=self.message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
||||
belongs_to='assistant',
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=self.user_id,
|
||||
)
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append((
|
||||
message_file,
|
||||
message.save_as
|
||||
))
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
||||
|
||||
def create_agent_thought(self, message_id: str, message: str,
|
||||
tool_name: str, tool_input: str, messages_ids: list[str]
|
||||
@@ -500,8 +415,12 @@ class BaseAgentRunner(AppRunner):
|
||||
try:
|
||||
tool_inputs = json.loads(agent_thought.tool_input)
|
||||
except Exception as e:
|
||||
logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input))
|
||||
tool_inputs = { agent_thought.tool: agent_thought.tool_input }
|
||||
tool_inputs = { tool: {} for tool in tools }
|
||||
try:
|
||||
tool_responses = json.loads(agent_thought.observation)
|
||||
except Exception as e:
|
||||
tool_responses = { tool: agent_thought.observation for tool in tools }
|
||||
|
||||
for tool in tools:
|
||||
# generate a uuid for tool call
|
||||
tool_call_id = str(uuid.uuid4())
|
||||
@@ -514,7 +433,7 @@ class BaseAgentRunner(AppRunner):
|
||||
)
|
||||
))
|
||||
tool_call_response.append(ToolPromptMessage(
|
||||
content=agent_thought.observation,
|
||||
content=tool_responses.get(tool, agent_thought.observation),
|
||||
name=tool,
|
||||
tool_call_id=tool_call_id,
|
||||
))
|
||||
|
||||
@@ -17,15 +17,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.tools.errors import (
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message
|
||||
|
||||
|
||||
@@ -267,60 +259,47 @@ class CotAgentRunner(BaseAgentRunner):
|
||||
agent_thought_id=agent_thought.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
else:
|
||||
if isinstance(tool_call_args, str):
|
||||
try:
|
||||
tool_call_args = json.loads(tool_call_args)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
|
||||
# invoke tool
|
||||
error_response = None
|
||||
try:
|
||||
if isinstance(tool_call_args, str):
|
||||
try:
|
||||
tool_call_args = json.loads(tool_call_args)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
tool_invoke_response, message_files = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback
|
||||
)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
|
||||
|
||||
tool_response = tool_instance.invoke(
|
||||
user_id=self.user_id,
|
||||
tool_parameters=tool_call_args
|
||||
)
|
||||
# transform tool response to llm friendly response
|
||||
tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=tool_response,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=self.message.conversation_id
|
||||
)
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = self.extract_tool_response_binary(tool_response)
|
||||
# create message file
|
||||
message_files = self.create_message_files(binary_files)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name,
|
||||
value=message_file.id,
|
||||
name=save_as)
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
message_file_ids = [message_file.id for message_file, _ in message_files]
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool_call_name}"
|
||||
except (
|
||||
ToolParameterValidationError
|
||||
) as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name,
|
||||
value=message_file.id,
|
||||
name=save_as)
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
if error_response:
|
||||
observation = error_response
|
||||
else:
|
||||
observation = self._convert_tool_response_to_str(tool_response)
|
||||
message_file_ids = [message_file.id for message_file, _ in message_files]
|
||||
|
||||
observation = tool_invoke_response
|
||||
|
||||
# save scratchpad
|
||||
scratchpad.observation = observation
|
||||
|
||||
@@ -15,15 +15,7 @@ from core.model_runtime.entities.message_entities import (
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.tools.errors import (
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from models.model import Conversation, Message, MessageAgentThought
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -261,70 +253,37 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": f"there is not a tool named {tool_call_name}"
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
else:
|
||||
# invoke tool
|
||||
error_response = None
|
||||
try:
|
||||
tool_invoke_message = tool_instance.invoke(
|
||||
user_id=self.user_id,
|
||||
tool_parameters=tool_call_args,
|
||||
)
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=tool_invoke_message,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
conversation_id=self.message.conversation_id
|
||||
)
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = self.extract_tool_response_binary(tool_invoke_message)
|
||||
# create message file
|
||||
message_files = self.create_message_files(binary_files)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool_call_name}"
|
||||
except (
|
||||
ToolParameterValidationError
|
||||
) as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
|
||||
if error_response:
|
||||
observation = error_response
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": error_response
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
else:
|
||||
observation = self._convert_tool_response_to_str(tool_invoke_message)
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": observation
|
||||
}
|
||||
tool_responses.append(tool_response)
|
||||
tool_invoke_response, message_files = ToolEngine.agent_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_call_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=self.message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
)
|
||||
# publish files
|
||||
for message_file, save_as in message_files:
|
||||
if save_as:
|
||||
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
|
||||
|
||||
# publish message file
|
||||
self.queue_manager.publish(QueueMessageFileEvent(
|
||||
message_file_id=message_file.id
|
||||
), PublishFrom.APPLICATION_MANAGER)
|
||||
# add message file ids
|
||||
message_file_ids.append(message_file.id)
|
||||
|
||||
observation = tool_invoke_response
|
||||
tool_response = {
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_call_name": tool_call_name,
|
||||
"tool_response": observation
|
||||
}
|
||||
|
||||
tool_responses.append(tool_response)
|
||||
prompt_messages = self.organize_prompt_messages(
|
||||
prompt_template=prompt_template,
|
||||
query=None,
|
||||
@@ -341,7 +300,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_response['tool_response'],
|
||||
observation=json.dumps({
|
||||
tool_response['tool_call_name']: tool_response['tool_response']
|
||||
for tool_response in tool_responses
|
||||
}),
|
||||
answer=None,
|
||||
messages_ids=message_file_ids
|
||||
)
|
||||
|
||||
@@ -4,7 +4,6 @@ from typing import Any, Optional, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolDescription,
|
||||
ToolIdentity,
|
||||
@@ -22,8 +21,6 @@ class Tool(BaseModel, ABC):
|
||||
parameters: Optional[list[ToolParameter]] = None
|
||||
description: ToolDescription = None
|
||||
is_team_authorization: bool = False
|
||||
agent_callback: Optional[DifyAgentCallbackHandler] = None
|
||||
use_callback: bool = False
|
||||
|
||||
class Runtime(BaseModel):
|
||||
"""
|
||||
@@ -45,15 +42,10 @@ class Tool(BaseModel, ABC):
|
||||
def __init__(self, **data: Any):
|
||||
super().__init__(**data)
|
||||
|
||||
if not self.agent_callback:
|
||||
self.use_callback = False
|
||||
else:
|
||||
self.use_callback = True
|
||||
|
||||
class VARIABLE_KEY(Enum):
|
||||
IMAGE = 'image'
|
||||
|
||||
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
|
||||
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
|
||||
"""
|
||||
fork a new tool with meta data
|
||||
|
||||
@@ -65,7 +57,6 @@ class Tool(BaseModel, ABC):
|
||||
parameters=self.parameters.copy() if self.parameters else None,
|
||||
description=self.description.copy() if self.description else None,
|
||||
runtime=Tool.Runtime(**meta),
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
|
||||
def load_variables(self, variables: ToolRuntimeVariablePool):
|
||||
@@ -174,50 +165,19 @@ class Tool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
|
||||
# check if tool_parameters is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
|
||||
# update tool_parameters
|
||||
if self.runtime.runtime_parameters:
|
||||
tool_parameters.update(self.runtime.runtime_parameters)
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_start(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
try:
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
except Exception as e:
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_error(e)
|
||||
raise e
|
||||
result = self._invoke(
|
||||
user_id=user_id,
|
||||
tool_parameters=tool_parameters,
|
||||
)
|
||||
|
||||
if not isinstance(result, list):
|
||||
result = [result]
|
||||
|
||||
# hit callback
|
||||
if self.use_callback:
|
||||
self.agent_callback.on_tool_end(
|
||||
tool_name=self.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=self._convert_tool_response_to_str(result)
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:
|
||||
|
||||
218
api/core/tools/tool_engine.py
Normal file
218
api/core/tools/tool_engine.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from typing import Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.file.file_obj import FileTransferMethod
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter
|
||||
from core.tools.errors import (
|
||||
ToolInvokeError,
|
||||
ToolNotFoundError,
|
||||
ToolNotSupportedError,
|
||||
ToolParameterValidationError,
|
||||
ToolProviderCredentialValidationError,
|
||||
ToolProviderNotFoundError,
|
||||
)
|
||||
from core.tools.tool.tool import Tool
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
|
||||
class ToolEngine:
|
||||
"""
|
||||
Tool runtime engine take care of the tool executions.
|
||||
"""
|
||||
@staticmethod
|
||||
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
|
||||
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
|
||||
agent_tool_callback: DifyAgentCallbackHandler) \
|
||||
-> tuple[str, list[tuple[MessageFile, bool]]]:
|
||||
"""
|
||||
Agent invokes the tool with the given arguments.
|
||||
"""
|
||||
# check if arguments is a string
|
||||
if isinstance(tool_parameters, str):
|
||||
# check if this tool has only one parameter
|
||||
parameters = [
|
||||
parameter for parameter in tool.parameters
|
||||
if parameter.form == ToolParameter.ToolParameterForm.LLM
|
||||
]
|
||||
if parameters and len(parameters) == 1:
|
||||
tool_parameters = {
|
||||
parameters[0].name: tool_parameters
|
||||
}
|
||||
else:
|
||||
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
|
||||
|
||||
# invoke the tool
|
||||
try:
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_start(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters
|
||||
)
|
||||
|
||||
response = tool.invoke(user_id, tool_parameters)
|
||||
|
||||
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=response,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=message.conversation_id
|
||||
)
|
||||
|
||||
# extract binary data from tool invoke message
|
||||
binary_files = ToolEngine._extract_tool_response_binary(response)
|
||||
# create message file
|
||||
message_files = ToolEngine._create_message_files(
|
||||
tool_messages=binary_files,
|
||||
agent_message=message,
|
||||
invoke_from=invoke_from,
|
||||
user_id=user_id
|
||||
)
|
||||
|
||||
plain_text = ToolEngine._convert_tool_response_to_str(response)
|
||||
|
||||
# hit the callback handler
|
||||
agent_tool_callback.on_tool_end(
|
||||
tool_name=tool.identity.name,
|
||||
tool_inputs=tool_parameters,
|
||||
tool_outputs=plain_text
|
||||
)
|
||||
|
||||
# transform tool invoke message to get LLM friendly message
|
||||
return plain_text, message_files
|
||||
except ToolProviderCredentialValidationError as e:
|
||||
error_response = "Please check your tool provider credentials"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (
|
||||
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
|
||||
) as e:
|
||||
error_response = f"there is not a tool named {tool.identity.name}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except (
|
||||
ToolParameterValidationError
|
||||
) as e:
|
||||
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except ToolInvokeError as e:
|
||||
error_response = f"tool invoke error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
except Exception as e:
|
||||
error_response = f"unknown error: {e}"
|
||||
agent_tool_callback.on_tool_error(e)
|
||||
|
||||
return error_response, []
|
||||
|
||||
@staticmethod
|
||||
def workflow_invoke(tool: Tool, tool_parameters: dict,
|
||||
user_id: str, workflow_id: str) -> dict:
|
||||
"""
|
||||
Workflow invokes the tool with the given arguments.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
|
||||
"""
|
||||
Handle tool response
|
||||
"""
|
||||
result = ''
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
result += response.message
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
result += f"result link: {response.message}. please tell user to check it."
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
|
||||
else:
|
||||
result += f"tool response: {response.message}."
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
|
||||
"""
|
||||
Extract tool response binary
|
||||
"""
|
||||
result = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
|
||||
response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream'),
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# check if there is a mime type in meta
|
||||
if response.meta and 'mime_type' in response.meta:
|
||||
result.append(ToolInvokeMessageBinary(
|
||||
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
|
||||
url=response.message,
|
||||
save_as=response.save_as,
|
||||
))
|
||||
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _create_message_files(
|
||||
tool_messages: list[ToolInvokeMessageBinary],
|
||||
agent_message: Message,
|
||||
invoke_from: InvokeFrom,
|
||||
user_id: str
|
||||
) -> list[tuple[MessageFile, bool]]:
|
||||
"""
|
||||
Create message file
|
||||
|
||||
:param messages: messages
|
||||
:return: message files, should save as variable
|
||||
"""
|
||||
result = []
|
||||
|
||||
for message in tool_messages:
|
||||
file_type = 'bin'
|
||||
if 'image' in message.mimetype:
|
||||
file_type = 'image'
|
||||
elif 'video' in message.mimetype:
|
||||
file_type = 'video'
|
||||
elif 'audio' in message.mimetype:
|
||||
file_type = 'audio'
|
||||
elif 'text' in message.mimetype:
|
||||
file_type = 'text'
|
||||
elif 'pdf' in message.mimetype:
|
||||
file_type = 'pdf'
|
||||
elif 'zip' in message.mimetype:
|
||||
file_type = 'archive'
|
||||
# ...
|
||||
|
||||
message_file = MessageFile(
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE.value,
|
||||
belongs_to='assistant',
|
||||
url=message.url,
|
||||
upload_file_id=None,
|
||||
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
|
||||
created_by=user_id,
|
||||
)
|
||||
|
||||
db.session.add(message_file)
|
||||
db.session.commit()
|
||||
db.session.refresh(message_file)
|
||||
|
||||
result.append((
|
||||
message_file,
|
||||
message.save_as
|
||||
))
|
||||
|
||||
db.session.close()
|
||||
|
||||
return result
|
||||
@@ -5,7 +5,6 @@ from os import listdir, path
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.common_entities import I18nObject
|
||||
@@ -139,8 +138,7 @@ class ToolManager:
|
||||
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
|
||||
|
||||
@staticmethod
|
||||
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
|
||||
agent_callback: DifyAgentCallbackHandler = None) \
|
||||
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
|
||||
-> Union[BuiltinTool, ApiTool]:
|
||||
"""
|
||||
get the tool runtime
|
||||
@@ -160,7 +158,7 @@ class ToolManager:
|
||||
return builtin_tool.fork_tool_runtime(meta={
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': {},
|
||||
}, agent_callback=agent_callback)
|
||||
})
|
||||
|
||||
# get credentials
|
||||
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
|
||||
@@ -182,7 +180,7 @@ class ToolManager:
|
||||
'tenant_id': tenant_id,
|
||||
'credentials': decrypted_credentials,
|
||||
'runtime_parameters': {}
|
||||
}, agent_callback=agent_callback)
|
||||
})
|
||||
|
||||
elif provider_type == 'api':
|
||||
if tenant_id is None:
|
||||
@@ -259,14 +257,13 @@ class ToolManager:
|
||||
return parameter_value
|
||||
|
||||
@staticmethod
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
|
||||
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
|
||||
"""
|
||||
get the agent tool runtime
|
||||
"""
|
||||
tool_entity = ToolManager.get_tool_runtime(
|
||||
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
@@ -289,7 +286,7 @@ class ToolManager:
|
||||
return tool_entity
|
||||
|
||||
@staticmethod
|
||||
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler):
|
||||
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity):
|
||||
"""
|
||||
get the workflow tool runtime
|
||||
"""
|
||||
@@ -298,7 +295,6 @@ class ToolManager:
|
||||
provider_name=workflow_tool.provider_id,
|
||||
tool_name=workflow_tool.tool_name,
|
||||
tenant_id=tenant_id,
|
||||
agent_callback=agent_callback
|
||||
)
|
||||
runtime_parameters = {}
|
||||
parameters = tool_entity.get_all_runtime_parameters()
|
||||
@@ -364,12 +360,16 @@ class ToolManager:
|
||||
continue
|
||||
|
||||
# init provider
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
try:
|
||||
provider_class = load_single_subclass_from_source(
|
||||
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
|
||||
script_path=path.join(path.dirname(path.realpath(__file__)),
|
||||
'provider', 'builtin', provider, f'{provider}.py'),
|
||||
parent_type=BuiltinToolProviderController)
|
||||
builtin_providers.append(provider_class())
|
||||
except Exception as e:
|
||||
logger.error(f'load builtin provider {provider} error: {e}')
|
||||
continue
|
||||
|
||||
# cache the builtin providers
|
||||
for provider in builtin_providers:
|
||||
|
||||
Reference in New Issue
Block a user