refactor: tool engine

This commit is contained in:
Yeuoly
2024-03-28 18:36:58 +08:00
parent 82a82fff35
commit 51404f9035
8 changed files with 318 additions and 283 deletions

View File

@@ -123,7 +123,6 @@ class AppApi(Resource):
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,

View File

@@ -58,7 +58,6 @@ class ModelConfigResource(Resource):
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
manager = ToolParameterConfigurationManager(
tenant_id=current_user.current_tenant_id,
@@ -96,7 +95,6 @@ class ModelConfigResource(Resource):
tool_runtime = ToolManager.get_agent_tool_runtime(
tenant_id=current_user.current_tenant_id,
agent_tool=agent_tool_entity,
agent_callback=None
)
except Exception as e:
continue

View File

@@ -10,12 +10,10 @@ from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.base_app_runner import AppRunner
from core.app.entities.app_invoke_entities import (
AgentChatAppGenerateEntity,
InvokeFrom,
ModelConfigWithCredentialsEntity,
)
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
from core.file.message_file_parser import FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMUsage
@@ -32,7 +30,6 @@ from core.model_runtime.model_providers.__base.large_language_model import Large
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.entities.tool_entities import (
ToolInvokeMessage,
ToolInvokeMessageBinary,
ToolParameter,
ToolRuntimeVariablePool,
)
@@ -40,7 +37,7 @@ from core.tools.tool.dataset_retriever_tool import DatasetRetrieverTool
from core.tools.tool.tool import Tool
from core.tools.tool_manager import ToolManager
from extensions.ext_database import db
from models.model import Message, MessageAgentThought, MessageFile
from models.model import Message, MessageAgentThought
from models.tools import ToolConversationVariables
logger = logging.getLogger(__name__)
@@ -156,7 +153,6 @@ class BaseAgentRunner(AppRunner):
tool_entity = ToolManager.get_agent_tool_runtime(
tenant_id=self.tenant_id,
agent_tool=tool,
agent_callback=self.agent_callback
)
tool_entity.load_variables(self.variables_pool)
@@ -270,87 +266,6 @@ class BaseAgentRunner(AppRunner):
prompt_tool.parameters['required'].append(parameter.name)
return prompt_tool
def extract_tool_response_binary(self, tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message,
save_as=response.save_as,
))
return result
def create_message_files(self, messages: list[ToolInvokeMessageBinary]) -> list[tuple[MessageFile, bool]]:
"""
Create message file
:param messages: messages
:return: message files, should save as variable
"""
result = []
for message in messages:
file_type = 'bin'
if 'image' in message.mimetype:
file_type = 'image'
elif 'video' in message.mimetype:
file_type = 'video'
elif 'audio' in message.mimetype:
file_type = 'audio'
elif 'text' in message.mimetype:
file_type = 'text'
elif 'pdf' in message.mimetype:
file_type = 'pdf'
elif 'zip' in message.mimetype:
file_type = 'archive'
# ...
invoke_from = self.application_generate_entity.invoke_from
message_file = MessageFile(
message_id=self.message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
belongs_to='assistant',
url=message.url,
upload_file_id=None,
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=self.user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.close()
return result
def create_agent_thought(self, message_id: str, message: str,
tool_name: str, tool_input: str, messages_ids: list[str]
@@ -500,8 +415,12 @@ class BaseAgentRunner(AppRunner):
try:
tool_inputs = json.loads(agent_thought.tool_input)
except Exception as e:
logging.warning("tool execution error: {}, tool_input: {}.".format(str(e), agent_thought.tool_input))
tool_inputs = { agent_thought.tool: agent_thought.tool_input }
tool_inputs = { tool: {} for tool in tools }
try:
tool_responses = json.loads(agent_thought.observation)
except Exception as e:
tool_responses = { tool: agent_thought.observation for tool in tools }
for tool in tools:
# generate a uuid for tool call
tool_call_id = str(uuid.uuid4())
@@ -514,7 +433,7 @@ class BaseAgentRunner(AppRunner):
)
))
tool_call_response.append(ToolPromptMessage(
content=agent_thought.observation,
content=tool_responses.get(tool, agent_thought.observation),
name=tool,
tool_call_id=tool_call_id,
))

View File

@@ -17,15 +17,7 @@ from core.model_runtime.entities.message_entities import (
UserPromptMessage,
)
from core.model_runtime.utils.encoders import jsonable_encoder
from core.tools.errors import (
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message
@@ -267,60 +259,47 @@ class CotAgentRunner(BaseAgentRunner):
agent_thought_id=agent_thought.id
), PublishFrom.APPLICATION_MANAGER)
else:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
# invoke tool
error_response = None
try:
if isinstance(tool_call_args, str):
try:
tool_call_args = json.loads(tool_call_args)
except json.JSONDecodeError:
pass
tool_invoke_response, message_files = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback
)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
tool_response = tool_instance.invoke(
user_id=self.user_id,
tool_parameters=tool_call_args
)
# transform tool response to llm friendly response
tool_response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=tool_response,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id
)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_response)
# create message file
message_files = self.create_message_files(binary_files)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
message_file_ids = [message_file.id for message_file, _ in message_files]
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParameterValidationError
) as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
error_response = f"unknown error: {e}"
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name,
value=message_file.id,
name=save_as)
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
if error_response:
observation = error_response
else:
observation = self._convert_tool_response_to_str(tool_response)
message_file_ids = [message_file.id for message_file, _ in message_files]
observation = tool_invoke_response
# save scratchpad
scratchpad.observation = observation

View File

@@ -15,15 +15,7 @@ from core.model_runtime.entities.message_entities import (
ToolPromptMessage,
UserPromptMessage,
)
from core.tools.errors import (
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from core.tools.tool_engine import ToolEngine
from models.model import Conversation, Message, MessageAgentThought
logger = logging.getLogger(__name__)
@@ -261,70 +253,37 @@ class FunctionCallAgentRunner(BaseAgentRunner):
"tool_call_name": tool_call_name,
"tool_response": f"there is not a tool named {tool_call_name}"
}
tool_responses.append(tool_response)
else:
# invoke tool
error_response = None
try:
tool_invoke_message = tool_instance.invoke(
user_id=self.user_id,
tool_parameters=tool_call_args,
)
# transform tool invoke message to get LLM friendly message
tool_invoke_message = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=tool_invoke_message,
user_id=self.user_id,
tenant_id=self.tenant_id,
conversation_id=self.message.conversation_id
)
# extract binary data from tool invoke message
binary_files = self.extract_tool_response_binary(tool_invoke_message)
# create message file
message_files = self.create_message_files(binary_files)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool_call_name}"
except (
ToolParameterValidationError
) as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
except Exception as e:
error_response = f"unknown error: {e}"
if error_response:
observation = error_response
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": error_response
}
tool_responses.append(tool_response)
else:
observation = self._convert_tool_response_to_str(tool_invoke_message)
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": observation
}
tool_responses.append(tool_response)
tool_invoke_response, message_files = ToolEngine.agent_invoke(
tool=tool_instance,
tool_parameters=tool_call_args,
user_id=self.user_id,
tenant_id=self.tenant_id,
message=self.message,
invoke_from=self.application_generate_entity.invoke_from,
agent_tool_callback=self.agent_callback,
)
# publish files
for message_file, save_as in message_files:
if save_as:
self.variables_pool.set_file(tool_name=tool_call_name, value=message_file.id, name=save_as)
# publish message file
self.queue_manager.publish(QueueMessageFileEvent(
message_file_id=message_file.id
), PublishFrom.APPLICATION_MANAGER)
# add message file ids
message_file_ids.append(message_file.id)
observation = tool_invoke_response
tool_response = {
"tool_call_id": tool_call_id,
"tool_call_name": tool_call_name,
"tool_response": observation
}
tool_responses.append(tool_response)
prompt_messages = self.organize_prompt_messages(
prompt_template=prompt_template,
query=None,
@@ -341,7 +300,10 @@ class FunctionCallAgentRunner(BaseAgentRunner):
tool_name=None,
tool_input=None,
thought=None,
observation=tool_response['tool_response'],
observation=json.dumps({
tool_response['tool_call_name']: tool_response['tool_response']
for tool_response in tool_responses
}),
answer=None,
messages_ids=message_file_ids
)

View File

@@ -4,7 +4,6 @@ from typing import Any, Optional, Union
from pydantic import BaseModel
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.tools.entities.tool_entities import (
ToolDescription,
ToolIdentity,
@@ -22,8 +21,6 @@ class Tool(BaseModel, ABC):
parameters: Optional[list[ToolParameter]] = None
description: ToolDescription = None
is_team_authorization: bool = False
agent_callback: Optional[DifyAgentCallbackHandler] = None
use_callback: bool = False
class Runtime(BaseModel):
"""
@@ -45,15 +42,10 @@ class Tool(BaseModel, ABC):
def __init__(self, **data: Any):
super().__init__(**data)
if not self.agent_callback:
self.use_callback = False
else:
self.use_callback = True
class VARIABLE_KEY(Enum):
IMAGE = 'image'
def fork_tool_runtime(self, meta: dict[str, Any], agent_callback: DifyAgentCallbackHandler = None) -> 'Tool':
def fork_tool_runtime(self, meta: dict[str, Any]) -> 'Tool':
"""
fork a new tool with meta data
@@ -65,7 +57,6 @@ class Tool(BaseModel, ABC):
parameters=self.parameters.copy() if self.parameters else None,
description=self.description.copy() if self.description else None,
runtime=Tool.Runtime(**meta),
agent_callback=agent_callback
)
def load_variables(self, variables: ToolRuntimeVariablePool):
@@ -174,50 +165,19 @@ class Tool(BaseModel, ABC):
return result
def invoke(self, user_id: str, tool_parameters: Union[dict[str, Any], str]) -> list[ToolInvokeMessage]:
# check if tool_parameters is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [parameter for parameter in self.parameters if parameter.form == ToolParameter.ToolParameterForm.LLM]
if parameters and len(parameters) == 1:
tool_parameters = {
parameters[0].name: tool_parameters
}
else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
def invoke(self, user_id: str, tool_parameters: dict[str, Any]) -> list[ToolInvokeMessage]:
# update tool_parameters
if self.runtime.runtime_parameters:
tool_parameters.update(self.runtime.runtime_parameters)
# hit callback
if self.use_callback:
self.agent_callback.on_tool_start(
tool_name=self.identity.name,
tool_inputs=tool_parameters
)
try:
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
except Exception as e:
if self.use_callback:
self.agent_callback.on_tool_error(e)
raise e
result = self._invoke(
user_id=user_id,
tool_parameters=tool_parameters,
)
if not isinstance(result, list):
result = [result]
# hit callback
if self.use_callback:
self.agent_callback.on_tool_end(
tool_name=self.identity.name,
tool_inputs=tool_parameters,
tool_outputs=self._convert_tool_response_to_str(result)
)
return result
def _convert_tool_response_to_str(self, tool_response: list[ToolInvokeMessage]) -> str:

View File

@@ -0,0 +1,218 @@
from typing import Union
from core.app.entities.app_invoke_entities import InvokeFrom
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.file.file_obj import FileTransferMethod
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMessageBinary, ToolParameter
from core.tools.errors import (
ToolInvokeError,
ToolNotFoundError,
ToolNotSupportedError,
ToolParameterValidationError,
ToolProviderCredentialValidationError,
ToolProviderNotFoundError,
)
from core.tools.tool.tool import Tool
from core.tools.utils.message_transformer import ToolFileMessageTransformer
from extensions.ext_database import db
from models.model import Message, MessageFile
class ToolEngine:
"""
Tool runtime engine take care of the tool executions.
"""
@staticmethod
def agent_invoke(tool: Tool, tool_parameters: Union[str, dict],
user_id: str, tenant_id: str, message: Message, invoke_from: InvokeFrom,
agent_tool_callback: DifyAgentCallbackHandler) \
-> tuple[str, list[tuple[MessageFile, bool]]]:
"""
Agent invokes the tool with the given arguments.
"""
# check if arguments is a string
if isinstance(tool_parameters, str):
# check if this tool has only one parameter
parameters = [
parameter for parameter in tool.parameters
if parameter.form == ToolParameter.ToolParameterForm.LLM
]
if parameters and len(parameters) == 1:
tool_parameters = {
parameters[0].name: tool_parameters
}
else:
raise ValueError(f"tool_parameters should be a dict, but got a string: {tool_parameters}")
# invoke the tool
try:
# hit the callback handler
agent_tool_callback.on_tool_start(
tool_name=tool.identity.name,
tool_inputs=tool_parameters
)
response = tool.invoke(user_id, tool_parameters)
response = ToolFileMessageTransformer.transform_tool_invoke_messages(
messages=response,
user_id=user_id,
tenant_id=tenant_id,
conversation_id=message.conversation_id
)
# extract binary data from tool invoke message
binary_files = ToolEngine._extract_tool_response_binary(response)
# create message file
message_files = ToolEngine._create_message_files(
tool_messages=binary_files,
agent_message=message,
invoke_from=invoke_from,
user_id=user_id
)
plain_text = ToolEngine._convert_tool_response_to_str(response)
# hit the callback handler
agent_tool_callback.on_tool_end(
tool_name=tool.identity.name,
tool_inputs=tool_parameters,
tool_outputs=plain_text
)
# transform tool invoke message to get LLM friendly message
return plain_text, message_files
except ToolProviderCredentialValidationError as e:
error_response = "Please check your tool provider credentials"
agent_tool_callback.on_tool_error(e)
except (
ToolNotFoundError, ToolNotSupportedError, ToolProviderNotFoundError
) as e:
error_response = f"there is not a tool named {tool.identity.name}"
agent_tool_callback.on_tool_error(e)
except (
ToolParameterValidationError
) as e:
error_response = f"tool parameters validation error: {e}, please check your tool parameters"
agent_tool_callback.on_tool_error(e)
except ToolInvokeError as e:
error_response = f"tool invoke error: {e}"
agent_tool_callback.on_tool_error(e)
except Exception as e:
error_response = f"unknown error: {e}"
agent_tool_callback.on_tool_error(e)
return error_response, []
@staticmethod
def workflow_invoke(tool: Tool, tool_parameters: dict,
user_id: str, workflow_id: str) -> dict:
"""
Workflow invokes the tool with the given arguments.
"""
@staticmethod
def _convert_tool_response_to_str(tool_response: list[ToolInvokeMessage]) -> str:
"""
Handle tool response
"""
result = ''
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.TEXT:
result += response.message
elif response.type == ToolInvokeMessage.MessageType.LINK:
result += f"result link: {response.message}. please tell user to check it."
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result += "image has been created and sent to user already, you do not need to create it, just tell the user to check it now."
else:
result += f"tool response: {response.message}."
return result
@staticmethod
def _extract_tool_response_binary(tool_response: list[ToolInvokeMessage]) -> list[ToolInvokeMessageBinary]:
"""
Extract tool response binary
"""
result = []
for response in tool_response:
if response.type == ToolInvokeMessage.MessageType.IMAGE_LINK or \
response.type == ToolInvokeMessage.MessageType.IMAGE:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.BLOB:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream'),
url=response.message,
save_as=response.save_as,
))
elif response.type == ToolInvokeMessage.MessageType.LINK:
# check if there is a mime type in meta
if response.meta and 'mime_type' in response.meta:
result.append(ToolInvokeMessageBinary(
mimetype=response.meta.get('mime_type', 'octet/stream') if response.meta else 'octet/stream',
url=response.message,
save_as=response.save_as,
))
return result
@staticmethod
def _create_message_files(
tool_messages: list[ToolInvokeMessageBinary],
agent_message: Message,
invoke_from: InvokeFrom,
user_id: str
) -> list[tuple[MessageFile, bool]]:
"""
Create message file
:param messages: messages
:return: message files, should save as variable
"""
result = []
for message in tool_messages:
file_type = 'bin'
if 'image' in message.mimetype:
file_type = 'image'
elif 'video' in message.mimetype:
file_type = 'video'
elif 'audio' in message.mimetype:
file_type = 'audio'
elif 'text' in message.mimetype:
file_type = 'text'
elif 'pdf' in message.mimetype:
file_type = 'pdf'
elif 'zip' in message.mimetype:
file_type = 'archive'
# ...
message_file = MessageFile(
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE.value,
belongs_to='assistant',
url=message.url,
upload_file_id=None,
created_by_role=('account'if invoke_from in [InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER] else 'end_user'),
created_by=user_id,
)
db.session.add(message_file)
db.session.commit()
db.session.refresh(message_file)
result.append((
message_file,
message.save_as
))
db.session.close()
return result

View File

@@ -5,7 +5,6 @@ from os import listdir, path
from typing import Any, Union
from core.agent.entities import AgentToolEntity
from core.callback_handler.agent_tool_callback_handler import DifyAgentCallbackHandler
from core.model_runtime.entities.message_entities import PromptMessage
from core.provider_manager import ProviderManager
from core.tools.entities.common_entities import I18nObject
@@ -139,8 +138,7 @@ class ToolManager:
raise ToolProviderNotFoundError(f'provider type {provider_type} not found')
@staticmethod
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str,
agent_callback: DifyAgentCallbackHandler = None) \
def get_tool_runtime(provider_type: str, provider_name: str, tool_name: str, tenant_id: str) \
-> Union[BuiltinTool, ApiTool]:
"""
get the tool runtime
@@ -160,7 +158,7 @@ class ToolManager:
return builtin_tool.fork_tool_runtime(meta={
'tenant_id': tenant_id,
'credentials': {},
}, agent_callback=agent_callback)
})
# get credentials
builtin_provider: BuiltinToolProvider = db.session.query(BuiltinToolProvider).filter(
@@ -182,7 +180,7 @@ class ToolManager:
'tenant_id': tenant_id,
'credentials': decrypted_credentials,
'runtime_parameters': {}
}, agent_callback=agent_callback)
})
elif provider_type == 'api':
if tenant_id is None:
@@ -259,14 +257,13 @@ class ToolManager:
return parameter_value
@staticmethod
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity, agent_callback: DifyAgentCallbackHandler) -> Tool:
def get_agent_tool_runtime(tenant_id: str, agent_tool: AgentToolEntity) -> Tool:
"""
get the agent tool runtime
"""
tool_entity = ToolManager.get_tool_runtime(
provider_type=agent_tool.provider_type, provider_name=agent_tool.provider_id, tool_name=agent_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
@@ -289,7 +286,7 @@ class ToolManager:
return tool_entity
@staticmethod
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity, agent_callback: DifyAgentCallbackHandler):
def get_workflow_tool_runtime(tenant_id: str, workflow_tool: ToolEntity):
"""
get the workflow tool runtime
"""
@@ -298,7 +295,6 @@ class ToolManager:
provider_name=workflow_tool.provider_id,
tool_name=workflow_tool.tool_name,
tenant_id=tenant_id,
agent_callback=agent_callback
)
runtime_parameters = {}
parameters = tool_entity.get_all_runtime_parameters()
@@ -364,12 +360,16 @@ class ToolManager:
continue
# init provider
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())
try:
provider_class = load_single_subclass_from_source(
module_name=f'core.tools.provider.builtin.{provider}.{provider}',
script_path=path.join(path.dirname(path.realpath(__file__)),
'provider', 'builtin', provider, f'{provider}.py'),
parent_type=BuiltinToolProviderController)
builtin_providers.append(provider_class())
except Exception as e:
logger.error(f'load builtin provider {provider} error: {e}')
continue
# cache the builtin providers
for provider in builtin_providers: