diff --git a/api/.importlinter b/api/.importlinter index c30007aafb..37dbfb15ec 100644 --- a/api/.importlinter +++ b/api/.importlinter @@ -106,7 +106,6 @@ ignore_imports = core.workflow.nodes.agent.agent_node -> core.provider_manager core.workflow.nodes.agent.agent_node -> core.tools.tool_manager core.workflow.nodes.document_extractor.node -> core.helper.ssrf_proxy - core.workflow.nodes.http_request.node -> core.tools.tool_file_manager core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory core.workflow.nodes.llm.llm_utils -> configs @@ -147,8 +146,6 @@ ignore_imports = core.workflow.nodes.code.code_node -> core.helper.code_executor.javascript.javascript_code_provider core.workflow.nodes.code.code_node -> core.helper.code_executor.python3.python3_code_provider core.workflow.nodes.code.entities -> core.helper.code_executor.code_executor - core.workflow.nodes.http_request.executor -> core.helper.ssrf_proxy - core.workflow.nodes.http_request.node -> core.helper.ssrf_proxy core.workflow.nodes.llm.file_saver -> core.helper.ssrf_proxy core.workflow.nodes.llm.node -> core.helper.code_executor core.workflow.nodes.template_transform.template_renderer -> core.helper.code_executor.code_executor diff --git a/api/core/workflow/nodes/http_request/executor.py b/api/core/workflow/nodes/http_request/executor.py index cd6007e720..de14c8c517 100644 --- a/api/core/workflow/nodes/http_request/executor.py +++ b/api/core/workflow/nodes/http_request/executor.py @@ -10,9 +10,7 @@ from urllib.parse import urlencode, urlparse import httpx from json_repair import repair_json -from core.helper.ssrf_proxy import ssrf_proxy from core.workflow.file.enums import FileTransferMethod -from core.workflow.file.file_manager import file_manager as default_file_manager from core.workflow.runtime import VariablePool from core.workflow.variables.segments import ArrayFileSegment, FileSegment @@ -81,8 +79,8 @@ class Executor: http_request_config: HttpRequestNodeConfig, max_retries: int | None = None, ssl_verify: bool | None = None, - http_client: HttpClientProtocol | None = None, - file_manager: FileManagerProtocol | None = None, + http_client: HttpClientProtocol, + file_manager: FileManagerProtocol, ): self._http_request_config = http_request_config # If authorization API key is present, convert the API key using the variable pool @@ -116,8 +114,8 @@ class Executor: self.max_retries = ( max_retries if max_retries is not None else self._http_request_config.ssrf_default_max_retries ) - self._http_client = http_client or ssrf_proxy - self._file_manager = file_manager or default_file_manager + self._http_client = http_client + self._file_manager = file_manager # init template self.variable_pool = variable_pool diff --git a/api/core/workflow/nodes/http_request/node.py b/api/core/workflow/nodes/http_request/node.py index 89eebb181c..11458db758 100644 --- a/api/core/workflow/nodes/http_request/node.py +++ b/api/core/workflow/nodes/http_request/node.py @@ -3,17 +3,14 @@ import mimetypes from collections.abc import Callable, Mapping, Sequence from typing import TYPE_CHECKING, Any -from core.helper.ssrf_proxy import ssrf_proxy -from core.tools.tool_file_manager import ToolFileManager from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus from core.workflow.file import File, FileTransferMethod -from core.workflow.file.file_manager import file_manager as default_file_manager from core.workflow.node_events import NodeRunResult from core.workflow.nodes.base import variable_template_parser from core.workflow.nodes.base.entities import VariableSelector from core.workflow.nodes.base.node import Node from core.workflow.nodes.http_request.executor import Executor -from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol +from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol, ToolFileManagerProtocol from core.workflow.variables.segments import ArrayFileSegment from factories import file_factory @@ -45,9 +42,9 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_runtime_state: "GraphRuntimeState", *, http_request_config: HttpRequestNodeConfig, - http_client: HttpClientProtocol | None = None, - tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager, - file_manager: FileManagerProtocol | None = None, + http_client: HttpClientProtocol, + tool_file_manager_factory: Callable[[], ToolFileManagerProtocol], + file_manager: FileManagerProtocol, ) -> None: super().__init__( id=id, @@ -55,10 +52,11 @@ class HttpRequestNode(Node[HttpRequestNodeData]): graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, ) + self._http_request_config = http_request_config - self._http_client = http_client or ssrf_proxy + self._http_client = http_client self._tool_file_manager_factory = tool_file_manager_factory - self._file_manager = file_manager or default_file_manager + self._file_manager = file_manager @classmethod def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]: diff --git a/api/core/workflow/nodes/protocols.py b/api/core/workflow/nodes/protocols.py index a1f3e20835..fda524d701 100644 --- a/api/core/workflow/nodes/protocols.py +++ b/api/core/workflow/nodes/protocols.py @@ -27,3 +27,16 @@ class HttpClientProtocol(Protocol): class FileManagerProtocol(Protocol): def download(self, f: File, /) -> bytes: ... + + +class ToolFileManagerProtocol(Protocol): + def create_file_by_raw( + self, + *, + user_id: str, + tenant_id: str, + conversation_id: str | None, + file_binary: bytes, + mimetype: str, + filename: str | None = None, + ) -> Any: ... diff --git a/api/tests/integration_tests/workflow/nodes/test_http.py b/api/tests/integration_tests/workflow/nodes/test_http.py index 0473d9832a..e0f2363799 100644 --- a/api/tests/integration_tests/workflow/nodes/test_http.py +++ b/api/tests/integration_tests/workflow/nodes/test_http.py @@ -7,8 +7,11 @@ import pytest from configs import dify_config from core.app.entities.app_invoke_entities import InvokeFrom from core.app.workflow.node_factory import DifyNodeFactory +from core.helper.ssrf_proxy import ssrf_proxy +from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.file.file_manager import file_manager from core.workflow.graph import Graph from core.workflow.nodes.http_request import HttpRequestNode, HttpRequestNodeConfig from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -76,6 +79,9 @@ def init_http_node(config: dict): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, ) return node @@ -229,6 +235,8 @@ def test_custom_auth_with_empty_api_key_raises_error(setup_http_mock): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=10), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -716,6 +724,9 @@ def test_nested_object_variable_selector(setup_http_mock): graph_init_params=init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, ) result = node._run() diff --git a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py index 186f8a8425..c4fc5ccc1f 100644 --- a/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py +++ b/api/tests/unit_tests/core/workflow/graph_engine/test_mock_factory.py @@ -123,6 +123,9 @@ class MockNodeFactory(DifyNodeFactory): graph_runtime_state=self.graph_runtime_state, mock_config=self.mock_config, http_request_config=self._http_request_config, + http_client=self._http_request_http_client, + tool_file_manager_factory=self._http_request_tool_file_manager_factory, + file_manager=self._http_request_file_manager, ) elif node_type in {NodeType.LLM, NodeType.QUESTION_CLASSIFIER, NodeType.PARAMETER_EXTRACTOR}: mock_instance = mock_class( diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py index 65f4de8c1d..67da890eb2 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_executor.py @@ -1,6 +1,8 @@ import pytest from configs import dify_config +from core.helper.ssrf_proxy import ssrf_proxy +from core.workflow.file.file_manager import file_manager from core.workflow.nodes.http_request import ( BodyData, HttpRequestNodeAuthorization, @@ -59,6 +61,8 @@ def test_executor_with_json_body_and_number_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -113,6 +117,8 @@ def test_executor_with_json_body_and_object_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -169,6 +175,8 @@ def test_executor_with_json_body_and_nested_object_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -213,6 +221,8 @@ def test_extract_selectors_from_template_with_newline(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) assert executor.params == [("test", "line1\nline2")] @@ -258,6 +268,8 @@ def test_executor_with_form_data(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Check the executor's data @@ -309,6 +321,8 @@ def test_init_headers(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=VariablePool(system_variables=SystemVariable.default()), + http_client=ssrf_proxy, + file_manager=file_manager, ) executor = create_executor("aa\n cc:") @@ -344,6 +358,8 @@ def test_init_params(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=VariablePool(system_variables=SystemVariable.default()), + http_client=ssrf_proxy, + file_manager=file_manager, ) # Test basic key-value pairs @@ -394,6 +410,8 @@ def test_empty_api_key_raises_error_bearer(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -419,6 +437,8 @@ def test_empty_api_key_raises_error_basic(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -444,6 +464,8 @@ def test_empty_api_key_raises_error_custom(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -469,6 +491,8 @@ def test_whitespace_only_api_key_raises_error(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) @@ -493,6 +517,8 @@ def test_valid_api_key_works(): timeout=timeout, http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # Should not raise an error @@ -541,6 +567,8 @@ def test_executor_with_json_body_and_unquoted_uuid_variable(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # The UUID should be preserved in full, not truncated @@ -586,6 +614,8 @@ def test_executor_with_json_body_and_unquoted_uuid_with_newlines(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) # The UUID should be preserved in full @@ -625,6 +655,8 @@ def test_executor_with_json_body_preserves_numbers_and_strings(): timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30), http_request_config=HTTP_REQUEST_CONFIG, variable_pool=variable_pool, + http_client=ssrf_proxy, + file_manager=file_manager, ) assert executor.json["count"] == 42 diff --git a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py index 472718188f..cad0466809 100644 --- a/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py +++ b/api/tests/unit_tests/core/workflow/nodes/http_request/test_http_request_node.py @@ -5,8 +5,11 @@ import httpx import pytest from core.app.entities.app_invoke_entities import InvokeFrom +from core.helper.ssrf_proxy import ssrf_proxy +from core.tools.tool_file_manager import ToolFileManager from core.workflow.entities import GraphInitParams from core.workflow.enums import WorkflowNodeExecutionStatus +from core.workflow.file.file_manager import file_manager from core.workflow.nodes.http_request import HTTP_REQUEST_CONFIG_FILTER_KEY, HttpRequestNode, HttpRequestNodeConfig from core.workflow.nodes.http_request.entities import HttpRequestNodeTimeout, Response from core.workflow.runtime import GraphRuntimeState, VariablePool @@ -116,6 +119,9 @@ def _build_http_node( graph_init_params=graph_init_params, graph_runtime_state=graph_runtime_state, http_request_config=HTTP_REQUEST_CONFIG, + http_client=ssrf_proxy, + tool_file_manager_factory=ToolFileManager, + file_manager=file_manager, )