mirror of
https://github.com/langgenius/dify.git
synced 2025-12-21 15:02:26 +00:00
Compare commits
1 Commits
test/log-r
...
fix/user-c
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
62369a9ee8 |
@@ -5,6 +5,7 @@ from typing import Any
|
|||||||
|
|
||||||
from sqlalchemy import select
|
from sqlalchemy import select
|
||||||
|
|
||||||
|
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||||
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
from core.file import FILE_MODEL_IDENTITY, File, FileTransferMethod
|
||||||
from core.tools.__base.tool import Tool
|
from core.tools.__base.tool import Tool
|
||||||
from core.tools.__base.tool_runtime import ToolRuntime
|
from core.tools.__base.tool_runtime import ToolRuntime
|
||||||
@@ -18,7 +19,8 @@ from core.tools.errors import ToolInvokeError
|
|||||||
from extensions.ext_database import db
|
from extensions.ext_database import db
|
||||||
from factories.file_factory import build_from_mapping
|
from factories.file_factory import build_from_mapping
|
||||||
from libs.login import current_user
|
from libs.login import current_user
|
||||||
from models.model import App
|
from models.account import Account
|
||||||
|
from models.model import App, EndUser
|
||||||
from models.workflow import Workflow
|
from models.workflow import Workflow
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
@@ -79,11 +81,13 @@ class WorkflowTool(Tool):
|
|||||||
generator = WorkflowAppGenerator()
|
generator = WorkflowAppGenerator()
|
||||||
assert self.runtime is not None
|
assert self.runtime is not None
|
||||||
assert self.runtime.invoke_from is not None
|
assert self.runtime.invoke_from is not None
|
||||||
assert current_user is not None
|
user = self._resolve_user(user_id)
|
||||||
|
if user is None:
|
||||||
|
raise ToolInvokeError("workflow tool invoke missing user context")
|
||||||
result = generator.generate(
|
result = generator.generate(
|
||||||
app_model=app,
|
app_model=app,
|
||||||
workflow=workflow,
|
workflow=workflow,
|
||||||
user=current_user,
|
user=user,
|
||||||
args={"inputs": tool_parameters, "files": files},
|
args={"inputs": tool_parameters, "files": files},
|
||||||
invoke_from=self.runtime.invoke_from,
|
invoke_from=self.runtime.invoke_from,
|
||||||
streaming=False,
|
streaming=False,
|
||||||
@@ -227,3 +231,26 @@ class WorkflowTool(Tool):
|
|||||||
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
elif transfer_method == FileTransferMethod.LOCAL_FILE:
|
||||||
file_dict["upload_file_id"] = file_dict.get("related_id")
|
file_dict["upload_file_id"] = file_dict.get("related_id")
|
||||||
return file_dict
|
return file_dict
|
||||||
|
|
||||||
|
def _resolve_user(self, user_id: str) -> Account | EndUser | None:
|
||||||
|
runtime = self.runtime
|
||||||
|
try:
|
||||||
|
user_candidate = current_user
|
||||||
|
except RuntimeError:
|
||||||
|
user_candidate = None
|
||||||
|
|
||||||
|
if user_candidate is not None and getattr(user_candidate, "is_authenticated", False):
|
||||||
|
return user_candidate
|
||||||
|
|
||||||
|
if not user_id or runtime is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
invoke_from = runtime.invoke_from
|
||||||
|
if invoke_from in {InvokeFrom.SERVICE_API, InvokeFrom.WEB_APP, InvokeFrom.PUBLISHED}:
|
||||||
|
end_user = (
|
||||||
|
db.session.query(EndUser).where(EndUser.id == user_id, EndUser.tenant_id == runtime.tenant_id).first()
|
||||||
|
)
|
||||||
|
if end_user:
|
||||||
|
return end_user
|
||||||
|
|
||||||
|
return db.session.query(Account).where(Account.id == user_id).first()
|
||||||
|
|||||||
@@ -40,9 +40,64 @@ def test_workflow_tool_should_raise_tool_invoke_error_when_result_has_error_fiel
|
|||||||
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
lambda *args, **kwargs: {"data": {"error": "oops"}},
|
||||||
)
|
)
|
||||||
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
monkeypatch.setattr("libs.login.current_user", lambda *args, **kwargs: None)
|
||||||
|
monkeypatch.setattr(
|
||||||
|
WorkflowTool,
|
||||||
|
"_resolve_user",
|
||||||
|
lambda self, _user_id: type("DummyUser", (), {"id": _user_id, "is_authenticated": True})(),
|
||||||
|
raising=False,
|
||||||
|
)
|
||||||
|
|
||||||
with pytest.raises(ToolInvokeError) as exc_info:
|
with pytest.raises(ToolInvokeError) as exc_info:
|
||||||
# WorkflowTool always returns a generator, so we need to iterate to
|
# WorkflowTool always returns a generator, so we need to iterate to
|
||||||
# actually `run` the tool.
|
# actually `run` the tool.
|
||||||
list(tool.invoke("test_user", {}))
|
list(tool.invoke("test_user", {}))
|
||||||
assert exc_info.value.args == ("oops",)
|
assert exc_info.value.args == ("oops",)
|
||||||
|
|
||||||
|
|
||||||
|
def test_workflow_tool_falls_back_to_user_resolver_when_no_current_user(monkeypatch: pytest.MonkeyPatch):
|
||||||
|
entity = ToolEntity(
|
||||||
|
identity=ToolIdentity(author="tester", name="work", label=I18nObject(en_US="work"), provider="prv"),
|
||||||
|
parameters=[],
|
||||||
|
description=None,
|
||||||
|
has_runtime_parameters=False,
|
||||||
|
)
|
||||||
|
runtime = ToolRuntime(tenant_id="tenant-id", invoke_from=InvokeFrom.SERVICE_API)
|
||||||
|
tool = WorkflowTool(
|
||||||
|
workflow_app_id="app-id",
|
||||||
|
workflow_as_tool_id="tool-id",
|
||||||
|
version="1",
|
||||||
|
workflow_entities={},
|
||||||
|
workflow_call_depth=0,
|
||||||
|
entity=entity,
|
||||||
|
runtime=runtime,
|
||||||
|
)
|
||||||
|
|
||||||
|
# keep tool internals simple for the test
|
||||||
|
monkeypatch.setattr(tool, "_get_app", lambda *_args, **_kwargs: object())
|
||||||
|
monkeypatch.setattr(tool, "_get_workflow", lambda *_args, **_kwargs: object())
|
||||||
|
monkeypatch.setattr(tool, "_transform_args", lambda tool_parameters, **_: (tool_parameters, []))
|
||||||
|
|
||||||
|
captured: dict[str, str] = {}
|
||||||
|
|
||||||
|
class DummyUser:
|
||||||
|
id = "dummy-user"
|
||||||
|
is_authenticated = True
|
||||||
|
|
||||||
|
dummy_user = DummyUser()
|
||||||
|
|
||||||
|
def fake_resolver(self, user_id: str):
|
||||||
|
captured["user_id"] = user_id
|
||||||
|
return dummy_user
|
||||||
|
|
||||||
|
def fake_generate(self, *, user, **_kwargs):
|
||||||
|
assert user is dummy_user
|
||||||
|
return {"data": {"outputs": {}}}
|
||||||
|
|
||||||
|
monkeypatch.setattr("core.app.apps.workflow.app_generator.WorkflowAppGenerator.generate", fake_generate)
|
||||||
|
monkeypatch.setattr("core.tools.workflow_as_tool.tool.current_user", None)
|
||||||
|
monkeypatch.setattr(WorkflowTool, "_resolve_user", fake_resolver, raising=False)
|
||||||
|
|
||||||
|
result = list(tool.invoke("user-123", {}))
|
||||||
|
|
||||||
|
assert captured["user_id"] == "user-123"
|
||||||
|
assert len(result) == 2 # text + json outputs
|
||||||
|
|||||||
Reference in New Issue
Block a user