mirror of
https://github.com/langgenius/dify.git
synced 2026-03-17 21:37:03 +00:00
Compare commits
12 Commits
3-18-dev-w
...
yanli/phas
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
2f81d5dfdf | ||
|
|
7639d8e43f | ||
|
|
1dce81c604 | ||
|
|
f874ca183e | ||
|
|
0d805e624e | ||
|
|
61196180b8 | ||
|
|
79433b0091 | ||
|
|
c4aeaa35d4 | ||
|
|
9f0d79b8b0 | ||
|
|
a717519822 | ||
|
|
a592c53573 | ||
|
|
239e09473e |
@@ -78,7 +78,7 @@ class UserProfile(TypedDict):
|
||||
nickname: NotRequired[str]
|
||||
```
|
||||
|
||||
- For classes, declare member variables at the top of the class body (before `__init__`) so the class shape is obvious at a glance:
|
||||
- For classes, declare all member variables explicitly with types at the top of the class body (before `__init__`), even when the class is not a dataclass or Pydantic model, so the class shape is obvious at a glance:
|
||||
|
||||
```python
|
||||
from datetime import datetime
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Literal, Protocol
|
||||
from typing import Literal, Protocol, cast
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import AliasChoices, Field
|
||||
@@ -12,16 +12,13 @@ class RedisConfigDefaults(Protocol):
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
def _redis_defaults(config: object) -> RedisConfigDefaults:
|
||||
return cast(RedisConfigDefaults, config)
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
class RedisPubSubConfig(BaseSettings):
|
||||
"""
|
||||
Configuration settings for event transport between API and workers.
|
||||
|
||||
@@ -74,7 +71,7 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
defaults = _redis_defaults(self)
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
@@ -91,11 +88,9 @@ class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
netloc = f"{userinfo}{defaults.REDIS_HOST}:{defaults.REDIS_PORT}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
|
||||
@@ -5,7 +5,7 @@ import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -47,7 +47,6 @@ from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -524,6 +523,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
assert message is not None
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
@@ -690,11 +690,20 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
@overload
|
||||
def _refresh_model(session: Session, model: Workflow) -> Workflow: ...
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
@overload
|
||||
def _refresh_model(session: Session, model: Message) -> Message: ...
|
||||
|
||||
|
||||
def _refresh_model(session: Session, model: Workflow | Message) -> Workflow | Message:
|
||||
if isinstance(model, Workflow):
|
||||
detached_workflow = session.get(Workflow, model.id)
|
||||
assert detached_workflow is not None
|
||||
return detached_workflow
|
||||
|
||||
detached_message = session.get(Message, model.id)
|
||||
assert detached_message is not None
|
||||
return detached_message
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import Any, cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -56,7 +56,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -87,7 +87,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, Any, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union
|
||||
from collections.abc import Generator, Iterator, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
|
||||
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
|
||||
|
||||
@classmethod
|
||||
def convert(
|
||||
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
|
||||
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_full_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_full_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_full_response(response)
|
||||
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_full_response(stream_response)
|
||||
|
||||
return _generate_full_response()
|
||||
else:
|
||||
if isinstance(response, AppBlockingResponse):
|
||||
return cls.convert_blocking_simple_response(response)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def _generate_simple_response() -> Generator[dict | str, Any, None]:
|
||||
yield from cls.convert_stream_simple_response(response)
|
||||
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
|
||||
yield from cls.convert_stream_simple_response(stream_response)
|
||||
|
||||
return _generate_simple_response()
|
||||
|
||||
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@@ -224,6 +224,7 @@ class BaseAppGenerator:
|
||||
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
|
||||
if invoke_from == InvokeFrom.DEBUGGER:
|
||||
assert isinstance(account, Account)
|
||||
debug_account = account
|
||||
|
||||
def draft_var_saver_factory(
|
||||
session: Session,
|
||||
@@ -240,7 +241,7 @@ class BaseAppGenerator:
|
||||
node_type=node_type,
|
||||
node_execution_id=node_execution_id,
|
||||
enclosing_node_id=enclosing_node_id,
|
||||
user=account,
|
||||
user=debug_account,
|
||||
)
|
||||
else:
|
||||
|
||||
|
||||
@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
generated_conversation_id = str(conversation.id)
|
||||
generated_message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
conversation_id=generated_conversation_id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
message_id=generated_message_id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
conversation_id=conversation.id,
|
||||
message_id=message.id,
|
||||
conversation_id=generated_conversation_id,
|
||||
message_id=generated_message_id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
@@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity)
|
||||
assert conversation is not None
|
||||
assert message is not None
|
||||
conversation_id = str(conversation.id)
|
||||
message_id = str(message.id)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = MessageBasedAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
user_id=application_generate_entity.user_id,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
conversation_id=conversation.id,
|
||||
conversation_id=conversation_id,
|
||||
app_mode=conversation.mode,
|
||||
message_id=message.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
# new thread with request context
|
||||
@@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
|
||||
flask_app=current_app._get_current_object(), # type: ignore
|
||||
application_generate_entity=application_generate_entity,
|
||||
queue_manager=queue_manager,
|
||||
message_id=message.id,
|
||||
message_id=message_id,
|
||||
)
|
||||
|
||||
worker_thread = threading.Thread(target=worker_with_context)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Generator, Iterator
|
||||
from typing import cast
|
||||
|
||||
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
|
||||
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_full_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream full response.
|
||||
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
|
||||
|
||||
@classmethod
|
||||
def convert_stream_simple_response(
|
||||
cls, stream_response: Generator[AppStreamResponse, None, None]
|
||||
cls, stream_response: Iterator[AppStreamResponse]
|
||||
) -> Generator[dict | str, None, None]:
|
||||
"""
|
||||
Convert stream simple response.
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Protocol, TypeAlias
|
||||
|
||||
from pydantic import ValidationError
|
||||
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.entities.agent_strategy import AgentStrategyInfo
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
InvokeFrom,
|
||||
UserFrom,
|
||||
build_dify_run_context,
|
||||
)
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.entities import GraphInitParams
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
|
||||
from dify_graph.entities.pause_reason import HumanInputRequired
|
||||
from dify_graph.graph import Graph
|
||||
from dify_graph.graph_engine.layers.base import GraphEngineLayer
|
||||
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GraphConfigObject: TypeAlias = dict[str, object]
|
||||
GraphConfigMapping: TypeAlias = Mapping[str, object]
|
||||
|
||||
|
||||
class SingleNodeRunEntity(Protocol):
|
||||
node_id: str
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
def __init__(
|
||||
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
|
||||
|
||||
def _init_graph(
|
||||
self,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: GraphConfigMapping,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
@@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
|
||||
def _prepare_single_node_execution(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
single_iteration_run: SingleNodeRunEntity | None = None,
|
||||
single_loop_run: SingleNodeRunEntity | None = None,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
|
||||
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
|
||||
return graph, variable_pool, graph_runtime_state
|
||||
|
||||
@staticmethod
|
||||
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
|
||||
nodes = graph_config.get("nodes")
|
||||
edges = graph_config.get("edges")
|
||||
if not isinstance(nodes, list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
if not isinstance(edges, list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
validated_nodes: list[GraphConfigMapping] = []
|
||||
for node in nodes:
|
||||
if not isinstance(node, Mapping):
|
||||
raise ValueError("nodes in workflow graph must be mappings")
|
||||
validated_nodes.append(node)
|
||||
|
||||
validated_edges: list[GraphConfigMapping] = []
|
||||
for edge in edges:
|
||||
if not isinstance(edge, Mapping):
|
||||
raise ValueError("edges in workflow graph must be mappings")
|
||||
validated_edges.append(edge)
|
||||
|
||||
return validated_nodes, validated_edges
|
||||
|
||||
@staticmethod
|
||||
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
|
||||
if node_config is None:
|
||||
return None
|
||||
node_data = node_config.get("data")
|
||||
if not isinstance(node_data, Mapping):
|
||||
return None
|
||||
start_node_id = node_data.get("start_node_id")
|
||||
return start_node_id if isinstance(start_node_id, str) else None
|
||||
|
||||
@classmethod
|
||||
def _build_single_node_graph_config(
|
||||
cls,
|
||||
*,
|
||||
graph_config: GraphConfigMapping,
|
||||
node_id: str,
|
||||
node_type_filter_key: str,
|
||||
) -> tuple[GraphConfigObject, NodeConfigDict]:
|
||||
node_configs, edge_configs = cls._get_graph_items(graph_config)
|
||||
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
|
||||
start_node_id = cls._extract_start_node_id(main_node_config)
|
||||
|
||||
filtered_node_configs = [
|
||||
dict(node)
|
||||
for node in node_configs
|
||||
if node.get("id") == node_id
|
||||
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
if not filtered_node_configs:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
filtered_node_ids = {
|
||||
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
|
||||
}
|
||||
filtered_edge_configs = [
|
||||
dict(edge)
|
||||
for edge in edge_configs
|
||||
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
|
||||
]
|
||||
|
||||
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
|
||||
if target_node_config is None:
|
||||
raise ValueError(f"node id {node_id} not found in workflow graph")
|
||||
|
||||
return (
|
||||
{
|
||||
"nodes": filtered_node_configs,
|
||||
"edges": filtered_edge_configs,
|
||||
},
|
||||
NodeConfigDictAdapter.validate_python(target_node_config),
|
||||
)
|
||||
|
||||
def _get_graph_and_variable_pool_for_single_node_run(
|
||||
self,
|
||||
workflow: Workflow,
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
user_inputs: Mapping[str, object],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
@@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
|
||||
if not graph_config:
|
||||
raise ValueError("workflow graph not found")
|
||||
|
||||
graph_config = cast(dict[str, Any], graph_config)
|
||||
|
||||
if "nodes" not in graph_config or "edges" not in graph_config:
|
||||
raise ValueError("nodes or edges not found in workflow graph")
|
||||
|
||||
if not isinstance(graph_config.get("nodes"), list):
|
||||
raise ValueError("nodes in workflow graph must be a list")
|
||||
|
||||
if not isinstance(graph_config.get("edges"), list):
|
||||
raise ValueError("edges in workflow graph must be a list")
|
||||
|
||||
# filter nodes only in the specified node type (iteration or loop)
|
||||
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
|
||||
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
|
||||
node_configs = [
|
||||
node
|
||||
for node in graph_config.get("nodes", [])
|
||||
if node.get("id") == node_id
|
||||
or node.get("data", {}).get(node_type_filter_key, "") == node_id
|
||||
or (start_node_id and node.get("id") == start_node_id)
|
||||
]
|
||||
|
||||
graph_config["nodes"] = node_configs
|
||||
|
||||
node_ids = [node.get("id") for node in node_configs]
|
||||
|
||||
# filter edges only in the specified node type
|
||||
edge_configs = [
|
||||
edge
|
||||
for edge in graph_config.get("edges", [])
|
||||
if (edge.get("source") is None or edge.get("source") in node_ids)
|
||||
and (edge.get("target") is None or edge.get("target") in node_ids)
|
||||
]
|
||||
|
||||
graph_config["edges"] = edge_configs
|
||||
graph_config, target_node_config = self._build_single_node_graph_config(
|
||||
graph_config=graph_config,
|
||||
node_id=node_id,
|
||||
node_type_filter_key=node_type_filter_key,
|
||||
)
|
||||
|
||||
# Create required parameters for Graph.init
|
||||
graph_init_params = GraphInitParams(
|
||||
@@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
|
||||
if not graph:
|
||||
raise ValueError("graph not found in workflow")
|
||||
|
||||
# fetch node config from node id
|
||||
target_node_config = None
|
||||
for node in node_configs:
|
||||
if node.get("id") == node_id:
|
||||
target_node_config = node
|
||||
break
|
||||
|
||||
if not target_node_config:
|
||||
raise ValueError(f"{node_type_label} node id not found in workflow graph")
|
||||
|
||||
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
|
||||
|
||||
# Get node class
|
||||
node_type = target_node_config["data"].type
|
||||
node_version = str(target_node_config["data"].version)
|
||||
|
||||
@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: Mapping
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_iteration_run: SingleIterationRunEntity | None = None
|
||||
|
||||
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
inputs: dict
|
||||
inputs: Mapping[str, object]
|
||||
|
||||
single_loop_run: SingleLoopRunEntity | None = None
|
||||
|
||||
|
||||
@@ -473,9 +473,21 @@ class ProviderConfiguration(BaseModel):
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
else:
|
||||
# some historical data may have a provider record but not be set as valid
|
||||
provider_record.is_valid = True
|
||||
|
||||
if provider_record.credential_id is None:
|
||||
provider_record.credential_id = new_record.id
|
||||
provider_record.updated_at = naive_utc_now()
|
||||
|
||||
provider_model_credentials_cache = ProviderCredentialsCache(
|
||||
tenant_id=self.tenant_id,
|
||||
identity_id=provider_record.id,
|
||||
cache_type=ProviderCredentialsCacheType.PROVIDER,
|
||||
)
|
||||
provider_model_credentials_cache.delete()
|
||||
|
||||
self.switch_preferred_provider_type(provider_type=ProviderType.CUSTOM, session=session)
|
||||
|
||||
session.commit()
|
||||
except Exception:
|
||||
session.rollback()
|
||||
|
||||
@@ -196,6 +196,8 @@ class ProviderManager:
|
||||
|
||||
if preferred_provider_type_record:
|
||||
preferred_provider_type = ProviderType.value_of(preferred_provider_type_record.preferred_provider_type)
|
||||
elif dify_config.EDITION == "CLOUD" and system_configuration.enabled:
|
||||
preferred_provider_type = ProviderType.SYSTEM
|
||||
elif custom_configuration.provider or custom_configuration.models:
|
||||
preferred_provider_type = ProviderType.CUSTOM
|
||||
elif system_configuration.enabled:
|
||||
|
||||
@@ -1040,9 +1040,10 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "constant":
|
||||
parameter_value = tool_input.value
|
||||
|
||||
@@ -1,13 +1,24 @@
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel
|
||||
from enum import IntEnum, StrEnum, auto
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolSelector
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
AgentInputConstantValue: TypeAlias = (
|
||||
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
|
||||
)
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
|
||||
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentNodeData(BaseNodeData):
|
||||
type: NodeType = BuiltinNodeTypes.AGENT
|
||||
@@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
|
||||
tool_node_version: str | None = None
|
||||
|
||||
class AgentInput(BaseModel):
|
||||
value: Union[list[str], list[ToolSelector], Any]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: AgentInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> AgentInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "variable":
|
||||
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type in {"mixed", "constant"}:
|
||||
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown agent input type: {input_type}")
|
||||
|
||||
agent_parameters: dict[str, AgentInput]
|
||||
|
||||
|
||||
@@ -1,16 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, cast
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from pydantic import TypeAdapter, ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
@@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
|
||||
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
|
||||
from .strategy_protocols import ResolvedAgentStrategy
|
||||
|
||||
JsonObject: TypeAlias = dict[str, object]
|
||||
JsonObjectList: TypeAlias = list[JsonObject]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class AgentRuntimeSupport:
|
||||
def build_parameters(
|
||||
@@ -39,12 +48,12 @@ class AgentRuntimeSupport:
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
invoke_from: Any,
|
||||
invoke_from: InvokeFrom,
|
||||
for_log: bool = False,
|
||||
) -> dict[str, Any]:
|
||||
) -> dict[str, object]:
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
@@ -54,9 +63,10 @@ class AgentRuntimeSupport:
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
try:
|
||||
@@ -79,60 +89,38 @@ class AgentRuntimeSupport:
|
||||
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
value = self._normalize_tool_payloads(
|
||||
strategy=strategy,
|
||||
tools=tool_payloads,
|
||||
variable_pool=variable_pool,
|
||||
)
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
provider_type = self._coerce_tool_provider_type(tool.get("type"))
|
||||
setting_params = self._coerce_json_object(tool.get("settings")) or {}
|
||||
parameters = self._coerce_json_object(tool.get("parameters")) or {}
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
|
||||
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
|
||||
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_id=provider_id,
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_name=tool_name,
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
plugin_unique_identifier=plugin_unique_identifier,
|
||||
credential_id=credential_id,
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
extra = self._coerce_json_object(tool.get("extra")) or {}
|
||||
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
@@ -145,8 +133,9 @@ class AgentRuntimeSupport:
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
description_override = self._coerce_optional_string(extra.get("description"))
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
description_override or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
@@ -167,13 +156,13 @@ class AgentRuntimeSupport:
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"credential_id": credential_id,
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
value = _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
@@ -199,17 +188,27 @@ class AgentRuntimeSupport:
|
||||
|
||||
return result
|
||||
|
||||
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
|
||||
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
|
||||
credentials = InvokeCredentials()
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
tools = parameters.get("tools")
|
||||
if not isinstance(tools, list):
|
||||
return credentials
|
||||
|
||||
for raw_tool in tools:
|
||||
tool = self._coerce_json_object(raw_tool)
|
||||
if tool is None:
|
||||
continue
|
||||
if not tool.get("credential_id"):
|
||||
continue
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
except ValidationError:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
credential_id = self._coerce_optional_string(tool.get("credential_id"))
|
||||
if credential_id is None:
|
||||
continue
|
||||
credentials.tool_credentials[identity.provider] = credential_id
|
||||
return credentials
|
||||
|
||||
def fetch_memory(
|
||||
@@ -232,14 +231,14 @@ class AgentRuntimeSupport:
|
||||
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=tenant_id,
|
||||
provider=value.get("provider", ""),
|
||||
provider=str(value.get("provider", "")),
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_name = str(value.get("model", ""))
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM,
|
||||
model=model_name,
|
||||
@@ -249,7 +248,7 @@ class AgentRuntimeSupport:
|
||||
model_instance = ModelManager().get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model_type=ModelType(str(value.get("model_type", ""))),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
@@ -268,9 +267,88 @@ class AgentRuntimeSupport:
|
||||
@staticmethod
|
||||
def _filter_mcp_type_tool(
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: list[dict[str, Any]],
|
||||
) -> list[dict[str, Any]]:
|
||||
tools: JsonObjectList,
|
||||
) -> JsonObjectList:
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _normalize_tool_payloads(
|
||||
self,
|
||||
*,
|
||||
strategy: ResolvedAgentStrategy,
|
||||
tools: JsonObjectList,
|
||||
variable_pool: VariablePool,
|
||||
) -> JsonObjectList:
|
||||
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
|
||||
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
|
||||
for tool in normalized_tools:
|
||||
tool.pop("schemas", None)
|
||||
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
|
||||
tool["settings"] = self._resolve_tool_settings(tool)
|
||||
return normalized_tools
|
||||
|
||||
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
|
||||
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
|
||||
if parameter_configs is None:
|
||||
raw_parameters = self._coerce_json_object(tool.get("parameters"))
|
||||
return raw_parameters or {}
|
||||
|
||||
resolved_parameters: JsonObject = {}
|
||||
for key, parameter_config in parameter_configs.items():
|
||||
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
|
||||
value_param = self._coerce_json_object(parameter_config.get("value"))
|
||||
if value_param and value_param.get("type") == "variable":
|
||||
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
resolved_parameters[key] = variable.value
|
||||
else:
|
||||
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
resolved_parameters[key] = None
|
||||
|
||||
return resolved_parameters
|
||||
|
||||
@staticmethod
|
||||
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
|
||||
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
|
||||
if settings is None:
|
||||
return {}
|
||||
return {key: setting.get("value") for key, setting in settings.items()}
|
||||
|
||||
@staticmethod
|
||||
def _coerce_json_object(value: object) -> JsonObject | None:
|
||||
try:
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(value)
|
||||
except ValidationError:
|
||||
return None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_optional_string(value: object) -> str | None:
|
||||
return value if isinstance(value, str) else None
|
||||
|
||||
@staticmethod
|
||||
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
|
||||
if isinstance(value, ToolProviderType):
|
||||
return value
|
||||
if isinstance(value, str):
|
||||
return ToolProviderType(value)
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
@classmethod
|
||||
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
|
||||
if not isinstance(value, dict):
|
||||
return None
|
||||
|
||||
coerced: dict[str, JsonObject] = {}
|
||||
for key, item in value.items():
|
||||
if not isinstance(key, str):
|
||||
return None
|
||||
json_object = cls._coerce_json_object(item)
|
||||
if json_object is None:
|
||||
return None
|
||||
coerced[key] = json_object
|
||||
return coerced
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
from typing import Any, TypeAlias, cast
|
||||
|
||||
from configs import dify_config
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
@@ -32,6 +32,13 @@ from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SpecialValueScalar: TypeAlias = str | int | float | bool | None
|
||||
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
|
||||
SerializedSpecialValue: TypeAlias = (
|
||||
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
|
||||
)
|
||||
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
|
||||
|
||||
|
||||
class _WorkflowChildEngineBuilder:
|
||||
@staticmethod
|
||||
@@ -276,10 +283,10 @@ class WorkflowEntry:
|
||||
@staticmethod
|
||||
def _create_single_node_graph(
|
||||
node_id: str,
|
||||
node_data: dict[str, Any],
|
||||
node_data: Mapping[str, object],
|
||||
node_width: int = 114,
|
||||
node_height: int = 514,
|
||||
) -> dict[str, Any]:
|
||||
) -> SingleNodeGraphConfig:
|
||||
"""
|
||||
Create a minimal graph structure for testing a single node in isolation.
|
||||
|
||||
@@ -289,14 +296,14 @@ class WorkflowEntry:
|
||||
:param node_height: height for UI layout (default: 100)
|
||||
:return: graph dictionary with start node and target node
|
||||
"""
|
||||
node_config = {
|
||||
node_config: dict[str, object] = {
|
||||
"id": node_id,
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
"type": "custom",
|
||||
"data": node_data,
|
||||
"data": dict(node_data),
|
||||
}
|
||||
start_node_config = {
|
||||
start_node_config: dict[str, object] = {
|
||||
"id": "start",
|
||||
"width": node_width,
|
||||
"height": node_height,
|
||||
@@ -321,7 +328,12 @@ class WorkflowEntry:
|
||||
|
||||
@classmethod
|
||||
def run_free_node(
|
||||
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
|
||||
cls,
|
||||
node_data: Mapping[str, object],
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
user_inputs: Mapping[str, object],
|
||||
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
|
||||
"""
|
||||
Run free node
|
||||
@@ -339,6 +351,8 @@ class WorkflowEntry:
|
||||
graph_dict = cls._create_single_node_graph(node_id, node_data)
|
||||
|
||||
node_type = node_data.get("type", "")
|
||||
if not isinstance(node_type, str):
|
||||
raise ValueError("Node type must be a string")
|
||||
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
|
||||
raise ValueError(f"Node type {node_type} not supported")
|
||||
|
||||
@@ -369,7 +383,7 @@ class WorkflowEntry:
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init workflow run state
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
|
||||
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
|
||||
node_factory = DifyNodeFactory(
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
@@ -405,30 +419,34 @@ class WorkflowEntry:
|
||||
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
|
||||
|
||||
@staticmethod
|
||||
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
|
||||
# NOTE(QuantumGhost): Avoid using this function in new code.
|
||||
# Keep values structured as long as possible and only convert to dict
|
||||
# immediately before serialization (e.g., JSON serialization) to maintain
|
||||
# data integrity and type information.
|
||||
result = WorkflowEntry._handle_special_values(value)
|
||||
return result if isinstance(result, Mapping) or result is None else dict(result)
|
||||
if result is None:
|
||||
return None
|
||||
if isinstance(result, dict):
|
||||
return result
|
||||
raise TypeError("handle_special_values expects a mapping input")
|
||||
|
||||
@staticmethod
|
||||
def _handle_special_values(value: Any):
|
||||
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
if isinstance(value, Mapping):
|
||||
res: dict[str, SerializedSpecialValue] = {}
|
||||
for k, v in value.items():
|
||||
res[k] = WorkflowEntry._handle_special_values(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
res_list: list[SerializedSpecialValue] = []
|
||||
for item in value:
|
||||
res_list.append(WorkflowEntry._handle_special_values(item))
|
||||
return res_list
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
return dict(value.to_dict())
|
||||
return value
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case FileTransferMethod.DATASOURCE_FILE:
|
||||
data = _download_file_content(f.storage_key)
|
||||
case _:
|
||||
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
|
||||
|
||||
return base64.b64encode(data).decode("utf-8")
|
||||
|
||||
|
||||
@@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
|
||||
elif limit_type == LimitType.TIME_LIMIT:
|
||||
elapsed_time = time.time() - self.start_time if self.start_time else 0
|
||||
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
|
||||
else:
|
||||
return
|
||||
|
||||
self.logger.warning("Execution limit exceeded: %s", reason)
|
||||
|
||||
|
||||
@@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_executions = graph_execution.node_executions
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
@@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
|
||||
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
yield self._dispatch(event)
|
||||
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
|
||||
event.id = self.execution_id
|
||||
yield event
|
||||
yield event.model_copy(update={"id": self.execution_id})
|
||||
else:
|
||||
yield event
|
||||
except Exception as e:
|
||||
|
||||
@@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
|
||||
# Keep track of paragraph and table positions
|
||||
content_items: list[tuple[int, str, Table | Paragraph]] = []
|
||||
|
||||
it = iter(doc.element.body)
|
||||
doc_body = getattr(doc.element, "body", None)
|
||||
if doc_body is None:
|
||||
raise TextExtractionError("DOCX body not found")
|
||||
it = iter(doc_body)
|
||||
part = next(it, None)
|
||||
i = 0
|
||||
while part is not None:
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
from typing import Literal, NotRequired
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
|
||||
from dify_graph.nodes.base.entities import VariableSelector
|
||||
|
||||
|
||||
class StructuredOutputConfig(TypedDict):
|
||||
schema: Mapping[str, object]
|
||||
name: NotRequired[str]
|
||||
description: NotRequired[str]
|
||||
|
||||
|
||||
class ModelConfig(BaseModel):
|
||||
provider: str
|
||||
name: str
|
||||
mode: LLMMode
|
||||
completion_params: dict[str, Any] = Field(default_factory=dict)
|
||||
completion_params: dict[str, object] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
@@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: Any):
|
||||
def convert_none_configs(cls, v: object):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
@@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
|
||||
|
||||
@field_validator("jinja2_variables", mode="before")
|
||||
@classmethod
|
||||
def convert_none_jinja2_variables(cls, v: Any):
|
||||
def convert_none_jinja2_variables(cls, v: object):
|
||||
if v is None:
|
||||
return []
|
||||
return v
|
||||
@@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
structured_output: StructuredOutputConfig | None = None
|
||||
# We used 'structured_output_enabled' in the past, but it's not a good name.
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = Field(
|
||||
@@ -90,7 +97,7 @@ class LLMNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("prompt_config", mode="before")
|
||||
@classmethod
|
||||
def convert_none_prompt_config(cls, v: Any):
|
||||
def convert_none_prompt_config(cls, v: object):
|
||||
if v is None:
|
||||
return PromptConfig()
|
||||
return v
|
||||
|
||||
@@ -9,6 +9,7 @@ import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
@@ -74,6 +75,7 @@ from .entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
StructuredOutputConfig,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@@ -88,6 +90,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
|
||||
class LLMNode(Node[LLMNodeData]):
|
||||
@@ -354,7 +357,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop: Sequence[str] | None = None,
|
||||
user_id: str,
|
||||
structured_output_enabled: bool,
|
||||
structured_output: Mapping[str, Any] | None = None,
|
||||
structured_output: StructuredOutputConfig | None = None,
|
||||
file_saver: LLMFileSaver,
|
||||
file_outputs: list[File],
|
||||
node_id: str,
|
||||
@@ -367,8 +370,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
|
||||
|
||||
if structured_output_enabled:
|
||||
if structured_output is None:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
output_schema = LLMNode.fetch_structured_output_schema(
|
||||
structured_output=structured_output or {},
|
||||
structured_output=structured_output,
|
||||
)
|
||||
request_start_time = time.perf_counter()
|
||||
|
||||
@@ -920,6 +925,12 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Extract clean text and reasoning from <think> tags
|
||||
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
|
||||
|
||||
structured_output = (
|
||||
dict(invoke_result.structured_output)
|
||||
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
|
||||
else None
|
||||
)
|
||||
|
||||
event = ModelInvokeCompletedEvent(
|
||||
# Use clean_text for separated mode, full_text for tagged mode
|
||||
text=clean_text if reasoning_format == "separated" else full_text,
|
||||
@@ -928,7 +939,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
# Reasoning content for workflow variables and downstream nodes
|
||||
reasoning_content=reasoning_content,
|
||||
# Pass structured output if enabled
|
||||
structured_output=getattr(invoke_result, "structured_output", None),
|
||||
structured_output=structured_output,
|
||||
)
|
||||
if request_latency is not None:
|
||||
event.usage.latency = round(request_latency, 3)
|
||||
@@ -962,27 +973,18 @@ class LLMNode(Node[LLMNodeData]):
|
||||
@staticmethod
|
||||
def fetch_structured_output_schema(
|
||||
*,
|
||||
structured_output: Mapping[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
structured_output: StructuredOutputConfig,
|
||||
) -> dict[str, object]:
|
||||
"""
|
||||
Fetch the structured output schema from the node data.
|
||||
|
||||
Returns:
|
||||
dict[str, Any]: The structured output schema
|
||||
dict[str, object]: The structured output schema
|
||||
"""
|
||||
if not structured_output:
|
||||
schema = structured_output.get("schema")
|
||||
if not schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
|
||||
if not structured_output_schema:
|
||||
raise LLMNodeError("Please provide a valid structured output schema")
|
||||
|
||||
try:
|
||||
schema = json.loads(structured_output_schema)
|
||||
if not isinstance(schema, dict):
|
||||
raise LLMNodeError("structured_output_schema must be a JSON object")
|
||||
return schema
|
||||
except json.JSONDecodeError:
|
||||
raise LLMNodeError("structured_output_schema is not valid JSON format")
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(schema)
|
||||
|
||||
@staticmethod
|
||||
def _save_multimodal_output_and_convert_result_to_markdown(
|
||||
|
||||
@@ -1,7 +1,10 @@
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Any, Literal
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, field_validator
|
||||
from enum import StrEnum
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
|
||||
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
@@ -9,6 +12,14 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
|
||||
from dify_graph.utils.condition.entities import Condition
|
||||
from dify_graph.variables.types import SegmentType
|
||||
|
||||
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, "LoopValue"] | list["LoopValue"]
|
||||
LoopValueMapping: TypeAlias = dict[str, LoopValue]
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_LOOP_VALUE_ADAPTER: TypeAdapter[LoopValue] = TypeAdapter(LoopValue)
|
||||
_LOOP_VALUE_MAPPING_ADAPTER: TypeAdapter[LoopValueMapping] = TypeAdapter(LoopValueMapping)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
_VALID_VAR_TYPE = frozenset(
|
||||
[
|
||||
SegmentType.STRING,
|
||||
@@ -37,7 +48,29 @@ class LoopVariableData(BaseModel):
|
||||
label: str
|
||||
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
|
||||
value_type: Literal["variable", "constant"]
|
||||
value: Any | list[str] | None = None
|
||||
value: LoopValue | VariableSelector | None = None
|
||||
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
|
||||
value_type = validation_info.data.get("value_type")
|
||||
if value_type == "variable":
|
||||
if value is None:
|
||||
return None
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if value_type == "constant":
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown loop variable value type: {value_type}")
|
||||
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.value_type != "variable":
|
||||
raise ValueError(f"Expected variable loop input, got {self.value_type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
def require_constant_value(self) -> LoopValue:
|
||||
if self.value_type != "constant":
|
||||
raise ValueError(f"Expected constant loop input, got {self.value_type}")
|
||||
return _LOOP_VALUE_ADAPTER.validate_python(self.value)
|
||||
|
||||
|
||||
class LoopNodeData(BaseLoopNodeData):
|
||||
@@ -46,14 +79,14 @@ class LoopNodeData(BaseLoopNodeData):
|
||||
break_conditions: list[Condition] # Conditions to break the loop
|
||||
logical_operator: Literal["and", "or"]
|
||||
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
|
||||
outputs: dict[str, Any] = Field(default_factory=dict)
|
||||
outputs: LoopValueMapping = Field(default_factory=dict)
|
||||
|
||||
@field_validator("outputs", mode="before")
|
||||
@classmethod
|
||||
def validate_outputs(cls, v):
|
||||
if v is None:
|
||||
def validate_outputs(cls, value: object) -> LoopValueMapping:
|
||||
if value is None:
|
||||
return {}
|
||||
return v
|
||||
return _LOOP_VALUE_MAPPING_ADAPTER.validate_python(value)
|
||||
|
||||
|
||||
class LoopStartNodeData(BaseNodeData):
|
||||
@@ -77,8 +110,8 @@ class LoopState(BaseLoopState):
|
||||
Loop State.
|
||||
"""
|
||||
|
||||
outputs: list[Any] = Field(default_factory=list)
|
||||
current_output: Any = None
|
||||
outputs: list[LoopValue] = Field(default_factory=list)
|
||||
current_output: LoopValue | None = None
|
||||
|
||||
class MetaData(BaseLoopState.MetaData):
|
||||
"""
|
||||
@@ -87,7 +120,7 @@ class LoopState(BaseLoopState):
|
||||
|
||||
loop_length: int
|
||||
|
||||
def get_last_output(self) -> Any:
|
||||
def get_last_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get last output.
|
||||
"""
|
||||
@@ -95,7 +128,7 @@ class LoopState(BaseLoopState):
|
||||
return self.outputs[-1]
|
||||
return None
|
||||
|
||||
def get_current_output(self) -> Any:
|
||||
def get_current_output(self) -> LoopValue | None:
|
||||
"""
|
||||
Get current output.
|
||||
"""
|
||||
|
||||
@@ -3,7 +3,7 @@ import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import TYPE_CHECKING, Literal, cast
|
||||
|
||||
from dify_graph.entities.graph_config import NodeConfigDictAdapter
|
||||
from dify_graph.enums import (
|
||||
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
|
||||
)
|
||||
from dify_graph.nodes.base import LLMUsageTrackingMixin
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
|
||||
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
|
||||
from dify_graph.utils.condition.processor import ConditionProcessor
|
||||
from dify_graph.variables import Segment, SegmentType
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
|
||||
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
break_conditions = self.node_data.break_conditions
|
||||
logical_operator = self.node_data.logical_operator
|
||||
|
||||
inputs = {"loop_count": loop_count}
|
||||
inputs: dict[str, object] = {"loop_count": loop_count}
|
||||
|
||||
if not self.node_data.start_node_id:
|
||||
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
|
||||
@@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
root_node_id = self.node_data.start_node_id
|
||||
|
||||
# Initialize loop variables in the original variable pool
|
||||
loop_variable_selectors = {}
|
||||
loop_variable_selectors: dict[str, list[str]] = {}
|
||||
if self.node_data.loop_variables:
|
||||
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
|
||||
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
|
||||
"variable": lambda var: (
|
||||
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
|
||||
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
|
||||
if var.value is not None
|
||||
else None
|
||||
),
|
||||
}
|
||||
for loop_variable in self.node_data.loop_variables:
|
||||
@@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
condition_processor = ConditionProcessor()
|
||||
|
||||
loop_duration_map: dict[str, float] = {}
|
||||
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
|
||||
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
|
||||
loop_usage = LLMUsage.empty_usage()
|
||||
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
|
||||
|
||||
@@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
|
||||
|
||||
# Collect loop variable values after iteration
|
||||
single_loop_variable = {}
|
||||
single_loop_variable: dict[str, LoopValue] = {}
|
||||
for key, selector in loop_variable_selectors.items():
|
||||
segment = self.graph_runtime_state.variable_pool.get(selector)
|
||||
single_loop_variable[key] = segment.value if segment else None
|
||||
@@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
graph_config: Mapping[str, object],
|
||||
node_id: str,
|
||||
node_data: LoopNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
variable_mapping = {}
|
||||
variable_mapping: dict[str, Sequence[str]] = {}
|
||||
|
||||
# Extract loop node IDs statically from graph_config
|
||||
|
||||
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
|
||||
|
||||
# Get node configs from graph_config
|
||||
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
node_configs: dict[str, Mapping[str, object]] = {}
|
||||
if isinstance(raw_nodes, list):
|
||||
for raw_node in raw_nodes:
|
||||
if not isinstance(raw_node, dict):
|
||||
continue
|
||||
raw_node_id = raw_node.get("id")
|
||||
if isinstance(raw_node_id, str):
|
||||
node_configs[raw_node_id] = raw_node
|
||||
for sub_node_id, sub_node_config in node_configs.items():
|
||||
if sub_node_config.get("data", {}).get("loop_id") != node_id:
|
||||
sub_node_data = sub_node_config.get("data")
|
||||
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
|
||||
continue
|
||||
|
||||
# variable selector to variable mapping
|
||||
@@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
|
||||
for loop_variable in node_data.loop_variables or []:
|
||||
if loop_variable.value_type == "variable":
|
||||
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
|
||||
# add loop variable to variable mapping
|
||||
selector = loop_variable.value
|
||||
selector = loop_variable.require_variable_selector()
|
||||
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
|
||||
|
||||
# remove variable out from loop
|
||||
@@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return variable_mapping
|
||||
|
||||
@classmethod
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
|
||||
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
|
||||
"""
|
||||
Extract node IDs that belong to a specific loop from graph configuration.
|
||||
|
||||
@@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
:param loop_node_id: the ID of the loop node
|
||||
:return: set of node IDs that belong to the loop
|
||||
"""
|
||||
loop_node_ids = set()
|
||||
loop_node_ids: set[str] = set()
|
||||
|
||||
# Find all nodes that belong to this loop
|
||||
nodes = graph_config.get("nodes", [])
|
||||
for node in nodes:
|
||||
node_data = node.get("data", {})
|
||||
raw_nodes = graph_config.get("nodes")
|
||||
if not isinstance(raw_nodes, list):
|
||||
return loop_node_ids
|
||||
|
||||
for node in raw_nodes:
|
||||
if not isinstance(node, dict):
|
||||
continue
|
||||
node_data = node.get("data")
|
||||
if not isinstance(node_data, dict):
|
||||
continue
|
||||
if node_data.get("loop_id") == loop_node_id:
|
||||
node_id = node.get("id")
|
||||
if node_id:
|
||||
@@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
|
||||
return loop_node_ids
|
||||
|
||||
@staticmethod
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
|
||||
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
|
||||
"""Get the appropriate segment type for a constant value."""
|
||||
# TODO: Refactor for maintainability:
|
||||
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -6,6 +6,7 @@ from pydantic import (
|
||||
Field,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
@@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
|
||||
|
||||
@field_validator("name", mode="before")
|
||||
@classmethod
|
||||
def validate_name(cls, value) -> str:
|
||||
def validate_name(cls, value: object) -> str:
|
||||
if not value:
|
||||
raise ValueError("Parameter name is required")
|
||||
if value in {"__reason", "__is_success"}:
|
||||
@@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
|
||||
return element_type
|
||||
|
||||
|
||||
class JsonSchemaArrayItems(TypedDict):
|
||||
type: str
|
||||
|
||||
|
||||
class ParameterJsonSchemaProperty(TypedDict, total=False):
|
||||
description: str
|
||||
type: str
|
||||
items: JsonSchemaArrayItems
|
||||
enum: list[str]
|
||||
|
||||
|
||||
class ParameterJsonSchema(TypedDict):
|
||||
type: Literal["object"]
|
||||
properties: dict[str, ParameterJsonSchemaProperty]
|
||||
required: list[str]
|
||||
|
||||
|
||||
class ParameterExtractorNodeData(BaseNodeData):
|
||||
"""
|
||||
Parameter Extractor Node Data.
|
||||
@@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
|
||||
@field_validator("reasoning_mode", mode="before")
|
||||
@classmethod
|
||||
def set_reasoning_mode(cls, v) -> str:
|
||||
return v or "function_call"
|
||||
def set_reasoning_mode(cls, v: object) -> str:
|
||||
return str(v) if v else "function_call"
|
||||
|
||||
def get_parameter_json_schema(self):
|
||||
def get_parameter_json_schema(self) -> ParameterJsonSchema:
|
||||
"""
|
||||
Get parameter json schema.
|
||||
|
||||
:return: parameter json schema
|
||||
"""
|
||||
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
|
||||
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
|
||||
|
||||
for parameter in self.parameters:
|
||||
parameter_schema: dict[str, Any] = {"description": parameter.description}
|
||||
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
|
||||
|
||||
if parameter.type == SegmentType.STRING:
|
||||
parameter_schema["type"] = "string"
|
||||
@@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
|
||||
raise AssertionError("element type should not be None.")
|
||||
parameter_schema["items"] = {"type": element_type.value}
|
||||
else:
|
||||
parameter_schema["type"] = parameter.type
|
||||
parameter_schema["type"] = parameter.type.value
|
||||
|
||||
if parameter.options:
|
||||
parameter_schema["enum"] = parameter.options
|
||||
|
||||
@@ -5,6 +5,8 @@ import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -63,6 +65,7 @@ from .prompts import (
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from dify_graph.entities import GraphInitParams
|
||||
@@ -70,7 +73,7 @@ if TYPE_CHECKING:
|
||||
from dify_graph.runtime import GraphRuntimeState
|
||||
|
||||
|
||||
def extract_json(text):
|
||||
def extract_json(text: str) -> str | None:
|
||||
"""
|
||||
From a given JSON started from '{' or '[' extract the complete JSON object.
|
||||
"""
|
||||
@@ -392,10 +395,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
)
|
||||
|
||||
# generate tool
|
||||
parameter_schema = node_data.get_parameter_json_schema()
|
||||
tool = PromptMessageTool(
|
||||
name=FUNCTION_CALLING_EXTRACTOR_NAME,
|
||||
description="Extract parameters from the natural language text",
|
||||
parameters=node_data.get_parameter_json_schema(),
|
||||
parameters={
|
||||
"type": parameter_schema["type"],
|
||||
"properties": dict(parameter_schema["properties"]),
|
||||
"required": list(parameter_schema["required"]),
|
||||
},
|
||||
)
|
||||
|
||||
return prompt_messages, [tool]
|
||||
@@ -602,19 +610,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
else:
|
||||
return None
|
||||
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
|
||||
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
|
||||
"""
|
||||
Transform result into standard format.
|
||||
"""
|
||||
transformed_result: dict[str, Any] = {}
|
||||
transformed_result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.name in result:
|
||||
param_value = result[parameter.name]
|
||||
# transform value
|
||||
if parameter.type == SegmentType.NUMBER:
|
||||
transformed = self._transform_number(param_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
if isinstance(param_value, (bool, int, float, str)):
|
||||
numeric_value: bool | int | float | str = param_value
|
||||
transformed = self._transform_number(numeric_value)
|
||||
if transformed is not None:
|
||||
transformed_result[parameter.name] = transformed
|
||||
elif parameter.type == SegmentType.BOOLEAN:
|
||||
if isinstance(result[parameter.name], (bool, int)):
|
||||
transformed_result[parameter.name] = bool(result[parameter.name])
|
||||
@@ -661,7 +671,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
|
||||
return transformed_result
|
||||
|
||||
def _extract_complete_json_response(self, result: str) -> dict | None:
|
||||
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract complete json response.
|
||||
"""
|
||||
@@ -672,11 +682,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
|
||||
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
|
||||
"""
|
||||
Extract json from tool call.
|
||||
"""
|
||||
@@ -690,16 +700,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
json_str = extract_json(result[idx:])
|
||||
if json_str:
|
||||
with contextlib.suppress(Exception):
|
||||
return cast(dict, json.loads(json_str))
|
||||
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
|
||||
|
||||
logger.info("extra error: %s", result)
|
||||
return None
|
||||
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData):
|
||||
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
|
||||
"""
|
||||
Generate default result.
|
||||
"""
|
||||
result: dict[str, Any] = {}
|
||||
result: dict[str, object] = {}
|
||||
for parameter in data.parameters:
|
||||
if parameter.type == "number":
|
||||
result[parameter.name] = 0
|
||||
|
||||
@@ -1,12 +1,24 @@
|
||||
from typing import Any, Literal, Union
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, field_validator
|
||||
from typing import Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, TypeAdapter, field_validator
|
||||
from pydantic_core.core_schema import ValidationInfo
|
||||
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
from dify_graph.entities.base_node_data import BaseNodeData
|
||||
from dify_graph.enums import BuiltinNodeTypes, NodeType
|
||||
|
||||
ToolConfigurationValue: TypeAlias = str | int | float | bool
|
||||
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationValue]
|
||||
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
|
||||
VariableSelector: TypeAlias = list[str]
|
||||
|
||||
_TOOL_CONFIGURATIONS_ADAPTER: TypeAdapter[ToolConfigurations] = TypeAdapter(ToolConfigurations)
|
||||
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
|
||||
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
|
||||
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
|
||||
|
||||
|
||||
class ToolEntity(BaseModel):
|
||||
provider_id: str
|
||||
@@ -14,52 +26,41 @@ class ToolEntity(BaseModel):
|
||||
provider_name: str # redundancy
|
||||
tool_name: str
|
||||
tool_label: str # redundancy
|
||||
tool_configurations: dict[str, Any]
|
||||
tool_configurations: ToolConfigurations
|
||||
credential_id: str | None = None
|
||||
plugin_unique_identifier: str | None = None # redundancy
|
||||
|
||||
@field_validator("tool_configurations", mode="before")
|
||||
@classmethod
|
||||
def validate_tool_configurations(cls, value, values: ValidationInfo):
|
||||
if not isinstance(value, dict):
|
||||
raise ValueError("tool_configurations must be a dictionary")
|
||||
|
||||
for key in values.data.get("tool_configurations", {}):
|
||||
value = values.data.get("tool_configurations", {}).get(key)
|
||||
if not isinstance(value, str | int | float | bool):
|
||||
raise ValueError(f"{key} must be a string")
|
||||
|
||||
return value
|
||||
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
|
||||
return _TOOL_CONFIGURATIONS_ADAPTER.validate_python(value)
|
||||
|
||||
|
||||
class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
type: NodeType = BuiltinNodeTypes.TOOL
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
# TODO: check this type
|
||||
value: Union[Any, list[str]]
|
||||
type: Literal["mixed", "variable", "constant"]
|
||||
value: ToolInputConstantValue | VariableSelector
|
||||
|
||||
@field_validator("type", mode="before")
|
||||
@field_validator("value", mode="before")
|
||||
@classmethod
|
||||
def check_type(cls, value, validation_info: ValidationInfo):
|
||||
typ = value
|
||||
value = validation_info.data.get("value")
|
||||
def validate_value(
|
||||
cls, value: object, validation_info: ValidationInfo
|
||||
) -> ToolInputConstantValue | VariableSelector:
|
||||
input_type = validation_info.data.get("type")
|
||||
if input_type == "mixed":
|
||||
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
|
||||
if input_type == "variable":
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
|
||||
if input_type == "constant":
|
||||
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
|
||||
raise ValueError(f"Unknown tool input type: {input_type}")
|
||||
|
||||
if value is None:
|
||||
return typ
|
||||
|
||||
if typ == "mixed" and not isinstance(value, str):
|
||||
raise ValueError("value must be a string")
|
||||
elif typ == "variable":
|
||||
if not isinstance(value, list):
|
||||
raise ValueError("value must be a list")
|
||||
for val in value:
|
||||
if not isinstance(val, str):
|
||||
raise ValueError("value must be a list of strings")
|
||||
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
|
||||
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
|
||||
return typ
|
||||
def require_variable_selector(self) -> VariableSelector:
|
||||
if self.type != "variable":
|
||||
raise ValueError(f"Expected variable tool input, got {self.type}")
|
||||
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
|
||||
|
||||
tool_parameters: dict[str, ToolInput]
|
||||
# The version of the tool parameter.
|
||||
@@ -69,7 +70,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
|
||||
@field_validator("tool_parameters", mode="before")
|
||||
@classmethod
|
||||
def filter_none_tool_inputs(cls, value):
|
||||
def filter_none_tool_inputs(cls, value: object) -> object:
|
||||
if not isinstance(value, dict):
|
||||
return value
|
||||
|
||||
@@ -80,8 +81,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _has_valid_value(tool_input):
|
||||
def _has_valid_value(tool_input: object) -> bool:
|
||||
"""Check if the value is valid"""
|
||||
if isinstance(tool_input, dict):
|
||||
return tool_input.get("value") is not None
|
||||
return getattr(tool_input, "value", None) is not None
|
||||
if isinstance(tool_input, ToolNodeData.ToolInput):
|
||||
return tool_input.value is not None
|
||||
return False
|
||||
|
||||
@@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
variable_selector = tool_input.require_variable_selector()
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {variable_selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
@@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
variable_selector = input.require_variable_selector()
|
||||
selector_key = ".".join(variable_selector)
|
||||
result[f"#{selector_key}#"] = variable_selector
|
||||
case "constant":
|
||||
pass
|
||||
|
||||
|
||||
@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
|
||||
from dify_graph.nodes.base.node import Node
|
||||
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
|
||||
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
|
||||
from dify_graph.variables import SegmentType, VariableBase
|
||||
from dify_graph.variables import Segment, SegmentType, VariableBase
|
||||
|
||||
from .node_data import VariableAssignerData, WriteMode
|
||||
|
||||
@@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
|
||||
if not isinstance(original_variable, VariableBase):
|
||||
raise VariableOperatorNodeError("assigned variable not found")
|
||||
|
||||
income_value: Segment
|
||||
updated_variable: VariableBase
|
||||
match self.node_data.write_mode:
|
||||
case WriteMode.OVER_WRITE:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.value})
|
||||
|
||||
case WriteMode.APPEND:
|
||||
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if not income_value:
|
||||
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
|
||||
if input_value is None:
|
||||
raise VariableOperatorNodeError("input value not found")
|
||||
income_value = input_value
|
||||
updated_value = original_variable.value + [income_value.value]
|
||||
updated_variable = original_variable.model_copy(update={"value": updated_value})
|
||||
|
||||
case WriteMode.CLEAR:
|
||||
income_value = SegmentType.get_zero_value(original_variable.value_type)
|
||||
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
|
||||
case _:
|
||||
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
|
||||
|
||||
# Over write the variable.
|
||||
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)
|
||||
|
||||
@@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
|
||||
@property
|
||||
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
|
||||
"""Return node execution state keyed by node id for resume support."""
|
||||
...
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
...
|
||||
@@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
|
||||
...
|
||||
|
||||
|
||||
class NodeExecutionProtocol(Protocol):
|
||||
"""Structural interface for per-node execution state used during resume."""
|
||||
|
||||
execution_id: str | None
|
||||
|
||||
|
||||
class ResponseStreamCoordinatorProtocol(Protocol):
|
||||
"""Structural interface for response stream coordinator."""
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Any
|
||||
from dify_graph.file.models import File
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
from dify_graph.variables.segments import Segment
|
||||
|
||||
|
||||
class ArrayValidation(StrEnum):
|
||||
@@ -219,7 +219,7 @@ class SegmentType(StrEnum):
|
||||
return _ARRAY_ELEMENT_TYPES_MAPPING.get(self)
|
||||
|
||||
@staticmethod
|
||||
def get_zero_value(t: SegmentType):
|
||||
def get_zero_value(t: SegmentType) -> Segment:
|
||||
# Lazy import to avoid circular dependency
|
||||
from factories import variable_factory
|
||||
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Protocol, cast
|
||||
|
||||
from fastopenapi.routers import FlaskRouter
|
||||
from flask_cors import CORS
|
||||
|
||||
@@ -9,6 +11,10 @@ from extensions.ext_blueprints import AUTHENTICATED_HEADERS, EXPOSED_HEADERS
|
||||
DOCS_PREFIX = "/fastopenapi"
|
||||
|
||||
|
||||
class SupportsIncludeRouter(Protocol):
|
||||
def include_router(self, router: object, *, prefix: str = "") -> None: ...
|
||||
|
||||
|
||||
def init_app(app: DifyApp) -> None:
|
||||
docs_enabled = dify_config.SWAGGER_UI_ENABLED
|
||||
docs_url = f"{DOCS_PREFIX}/docs" if docs_enabled else None
|
||||
@@ -36,7 +42,7 @@ def init_app(app: DifyApp) -> None:
|
||||
_ = remote_files
|
||||
_ = setup
|
||||
|
||||
router.include_router(console_router, prefix="/console/api")
|
||||
cast(SupportsIncludeRouter, router).include_router(console_router, prefix="/console/api")
|
||||
CORS(
|
||||
app,
|
||||
resources={r"/console/api/.*": {"origins": dify_config.CONSOLE_CORS_ALLOW_ORIGINS}},
|
||||
|
||||
@@ -55,7 +55,7 @@ class TypeMismatchError(Exception):
|
||||
|
||||
|
||||
# Define the constant
|
||||
SEGMENT_TO_VARIABLE_MAP = {
|
||||
SEGMENT_TO_VARIABLE_MAP: Mapping[type[Segment], type[VariableBase]] = {
|
||||
ArrayAnySegment: ArrayAnyVariable,
|
||||
ArrayBooleanSegment: ArrayBooleanVariable,
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
@@ -296,13 +296,11 @@ def segment_to_variable(
|
||||
raise UnsupportedSegmentTypeError(f"not supported segment type {segment_type}")
|
||||
|
||||
variable_class = SEGMENT_TO_VARIABLE_MAP[segment_type]
|
||||
return cast(
|
||||
VariableBase,
|
||||
variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
),
|
||||
return variable_class(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
value_type=segment.value_type,
|
||||
value=segment.value,
|
||||
selector=list(selector),
|
||||
)
|
||||
|
||||
@@ -32,6 +32,11 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _stream_with_request_context(response: object) -> Any:
|
||||
"""Bridge Flask's loosely-typed streaming helper without leaking casts into callers."""
|
||||
return cast(Any, stream_with_context)(response)
|
||||
|
||||
|
||||
def escape_like_pattern(pattern: str) -> str:
|
||||
"""
|
||||
Escape special characters in a string for safe use in SQL LIKE patterns.
|
||||
@@ -286,22 +291,32 @@ def generate_text_hash(text: str) -> str:
|
||||
return sha256(hash_text.encode()).hexdigest()
|
||||
|
||||
|
||||
def compact_generate_response(response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
if isinstance(response, dict):
|
||||
def compact_generate_response(
|
||||
response: Mapping[str, Any] | Generator[str, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=json.dumps(jsonable_encoder(response)),
|
||||
status=200,
|
||||
content_type="application/json; charset=utf-8",
|
||||
)
|
||||
else:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator:
|
||||
yield from response
|
||||
def generate() -> Generator[str, None, None]:
|
||||
yield from stream_response
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
def length_prefixed_response(magic_number: int, response: Union[Mapping, Generator, RateLimitGenerator]) -> Response:
|
||||
def length_prefixed_response(
|
||||
magic_number: int,
|
||||
response: Mapping[str, Any] | BaseModel | Generator[str | bytes, None, None] | RateLimitGenerator,
|
||||
) -> Response:
|
||||
"""
|
||||
This function is used to return a response with a length prefix.
|
||||
Magic number is a one byte number that indicates the type of the response.
|
||||
@@ -332,7 +347,7 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
# | Magic Number 1byte | Reserved 1byte | Header Length 2bytes | Data Length 4bytes | Reserved 6bytes | Data
|
||||
return struct.pack("<BBHI", magic_number, 0, header_length, data_length) + b"\x00" * 6 + response
|
||||
|
||||
if isinstance(response, dict):
|
||||
if isinstance(response, Mapping):
|
||||
return Response(
|
||||
response=pack_response_with_length_prefix(json.dumps(jsonable_encoder(response)).encode("utf-8")),
|
||||
status=200,
|
||||
@@ -345,14 +360,20 @@ def length_prefixed_response(magic_number: int, response: Union[Mapping, Generat
|
||||
mimetype="application/json",
|
||||
)
|
||||
|
||||
def generate() -> Generator:
|
||||
for chunk in response:
|
||||
stream_response = response
|
||||
|
||||
def generate() -> Generator[bytes, None, None]:
|
||||
for chunk in stream_response:
|
||||
if isinstance(chunk, str):
|
||||
yield pack_response_with_length_prefix(chunk.encode("utf-8"))
|
||||
else:
|
||||
yield pack_response_with_length_prefix(chunk)
|
||||
|
||||
return Response(stream_with_context(generate()), status=200, mimetype="text/event-stream")
|
||||
return Response(
|
||||
_stream_with_request_context(generate()),
|
||||
status=200,
|
||||
mimetype="text/event-stream",
|
||||
)
|
||||
|
||||
|
||||
class TokenManager:
|
||||
|
||||
@@ -77,12 +77,14 @@ def login_required(func: Callable[P, R]) -> Callable[P, R | ResponseReturnValue]
|
||||
@wraps(func)
|
||||
def decorated_view(*args: P.args, **kwargs: P.kwargs) -> R | ResponseReturnValue:
|
||||
if request.method in EXEMPT_METHODS or dify_config.LOGIN_DISABLED:
|
||||
pass
|
||||
elif current_user is not None and not current_user.is_authenticated:
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
user = _get_user()
|
||||
if user is None or not user.is_authenticated:
|
||||
return current_app.login_manager.unauthorized() # type: ignore
|
||||
# we put csrf validation here for less conflicts
|
||||
# TODO: maybe find a better place for it.
|
||||
check_csrf_token(request, current_user.id)
|
||||
check_csrf_token(request, user.id)
|
||||
return current_app.ensure_sync(func)(*args, **kwargs)
|
||||
|
||||
return decorated_view
|
||||
|
||||
@@ -7,9 +7,10 @@ https://github.com/django/django/blob/main/django/utils/module_loading.py
|
||||
|
||||
import sys
|
||||
from importlib import import_module
|
||||
from typing import Any
|
||||
|
||||
|
||||
def cached_import(module_path: str, class_name: str):
|
||||
def cached_import(module_path: str, class_name: str) -> Any:
|
||||
"""
|
||||
Import a module and return the named attribute/class from it, with caching.
|
||||
|
||||
@@ -20,16 +21,14 @@ def cached_import(module_path: str, class_name: str):
|
||||
Returns:
|
||||
The imported attribute/class
|
||||
"""
|
||||
if not (
|
||||
(module := sys.modules.get(module_path))
|
||||
and (spec := getattr(module, "__spec__", None))
|
||||
and getattr(spec, "_initializing", False) is False
|
||||
):
|
||||
module = sys.modules.get(module_path)
|
||||
spec = getattr(module, "__spec__", None) if module is not None else None
|
||||
if module is None or getattr(spec, "_initializing", False):
|
||||
module = import_module(module_path)
|
||||
return getattr(module, class_name)
|
||||
|
||||
|
||||
def import_string(dotted_path: str):
|
||||
def import_string(dotted_path: str) -> Any:
|
||||
"""
|
||||
Import a dotted module path and return the attribute/class designated by
|
||||
the last name in the path. Raise ImportError if the import failed.
|
||||
|
||||
@@ -1,7 +1,48 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from typing import NotRequired
|
||||
|
||||
import httpx
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
JsonObject = dict[str, object]
|
||||
JsonObjectList = list[JsonObject]
|
||||
|
||||
JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
|
||||
JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
|
||||
|
||||
|
||||
class AccessTokenResponse(TypedDict, total=False):
|
||||
access_token: str
|
||||
|
||||
|
||||
class GitHubEmailRecord(TypedDict, total=False):
|
||||
email: str
|
||||
primary: bool
|
||||
|
||||
|
||||
class GitHubRawUserInfo(TypedDict):
|
||||
id: int | str
|
||||
login: str
|
||||
name: NotRequired[str]
|
||||
email: NotRequired[str]
|
||||
|
||||
|
||||
class GoogleRawUserInfo(TypedDict):
|
||||
sub: str
|
||||
email: str
|
||||
|
||||
|
||||
ACCESS_TOKEN_RESPONSE_ADAPTER = TypeAdapter(AccessTokenResponse)
|
||||
GITHUB_RAW_USER_INFO_ADAPTER = TypeAdapter(GitHubRawUserInfo)
|
||||
GITHUB_EMAIL_RECORDS_ADAPTER = TypeAdapter(list[GitHubEmailRecord])
|
||||
GOOGLE_RAW_USER_INFO_ADAPTER = TypeAdapter(GoogleRawUserInfo)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -11,26 +52,38 @@ class OAuthUserInfo:
|
||||
email: str
|
||||
|
||||
|
||||
def _json_object(response: httpx.Response) -> JsonObject:
|
||||
return JSON_OBJECT_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
def _json_list(response: httpx.Response) -> JsonObjectList:
|
||||
return JSON_OBJECT_LIST_ADAPTER.validate_python(response.json())
|
||||
|
||||
|
||||
class OAuth:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_user_info(self, token: str) -> OAuthUserInfo:
|
||||
raw_info = self.get_raw_user_info(token)
|
||||
return self._transform_user_info(raw_info)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -40,7 +93,7 @@ class GitHubOAuth(OAuth):
|
||||
_USER_INFO_URL = "https://api.github.com/user"
|
||||
_EMAIL_INFO_URL = "https://api.github.com/user/emails"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"redirect_uri": self.redirect_uri,
|
||||
@@ -50,7 +103,7 @@ class GitHubOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -60,7 +113,7 @@ class GitHubOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -68,23 +121,24 @@ class GitHubOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"token {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
user_info = response.json()
|
||||
user_info = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(_json_object(response))
|
||||
|
||||
email_response = httpx.get(self._EMAIL_INFO_URL, headers=headers)
|
||||
email_info = email_response.json()
|
||||
primary_email: dict = next((email for email in email_info if email["primary"] == True), {})
|
||||
email_info = GITHUB_EMAIL_RECORDS_ADAPTER.validate_python(_json_list(email_response))
|
||||
primary_email = next((email for email in email_info if email.get("primary") is True), None)
|
||||
|
||||
return {**user_info, "email": primary_email.get("email", "")}
|
||||
return {**user_info, "email": primary_email.get("email", "") if primary_email else ""}
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
email = raw_info.get("email")
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GITHUB_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
email = payload.get("email")
|
||||
if not email:
|
||||
email = f"{raw_info['id']}+{raw_info['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(raw_info["id"]), name=raw_info["name"], email=email)
|
||||
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
|
||||
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
|
||||
|
||||
|
||||
class GoogleOAuth(OAuth):
|
||||
@@ -92,7 +146,7 @@ class GoogleOAuth(OAuth):
|
||||
_TOKEN_URL = "https://oauth2.googleapis.com/token"
|
||||
_USER_INFO_URL = "https://www.googleapis.com/oauth2/v3/userinfo"
|
||||
|
||||
def get_authorization_url(self, invite_token: str | None = None):
|
||||
def get_authorization_url(self, invite_token: str | None = None) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -103,7 +157,7 @@ class GoogleOAuth(OAuth):
|
||||
params["state"] = invite_token
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> str:
|
||||
data = {
|
||||
"client_id": self.client_id,
|
||||
"client_secret": self.client_secret,
|
||||
@@ -114,7 +168,7 @@ class GoogleOAuth(OAuth):
|
||||
headers = {"Accept": "application/json"}
|
||||
response = httpx.post(self._TOKEN_URL, data=data, headers=headers)
|
||||
|
||||
response_json = response.json()
|
||||
response_json = ACCESS_TOKEN_RESPONSE_ADAPTER.validate_python(_json_object(response))
|
||||
access_token = response_json.get("access_token")
|
||||
|
||||
if not access_token:
|
||||
@@ -122,11 +176,12 @@ class GoogleOAuth(OAuth):
|
||||
|
||||
return access_token
|
||||
|
||||
def get_raw_user_info(self, token: str):
|
||||
def get_raw_user_info(self, token: str) -> JsonObject:
|
||||
headers = {"Authorization": f"Bearer {token}"}
|
||||
response = httpx.get(self._USER_INFO_URL, headers=headers)
|
||||
response.raise_for_status()
|
||||
return response.json()
|
||||
return _json_object(response)
|
||||
|
||||
def _transform_user_info(self, raw_info: dict) -> OAuthUserInfo:
|
||||
return OAuthUserInfo(id=str(raw_info["sub"]), name="", email=raw_info["email"])
|
||||
def _transform_user_info(self, raw_info: JsonObject) -> OAuthUserInfo:
|
||||
payload = GOOGLE_RAW_USER_INFO_ADAPTER.validate_python(raw_info)
|
||||
return OAuthUserInfo(id=str(payload["sub"]), name="", email=payload["email"])
|
||||
|
||||
@@ -1,25 +1,57 @@
|
||||
import sys
|
||||
import urllib.parse
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import httpx
|
||||
from flask_login import current_user
|
||||
from pydantic import TypeAdapter
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.source import DataSourceOauthBinding
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import TypedDict
|
||||
else:
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
|
||||
class NotionPageSummary(TypedDict):
|
||||
page_id: str
|
||||
page_name: str
|
||||
page_icon: dict[str, str] | None
|
||||
parent_id: str
|
||||
type: Literal["page", "database"]
|
||||
|
||||
|
||||
class NotionSourceInfo(TypedDict):
|
||||
workspace_name: str | None
|
||||
workspace_icon: str | None
|
||||
workspace_id: str | None
|
||||
pages: list[NotionPageSummary]
|
||||
total: int
|
||||
|
||||
|
||||
SOURCE_INFO_STORAGE_ADAPTER = TypeAdapter(dict[str, object])
|
||||
NOTION_SOURCE_INFO_ADAPTER = TypeAdapter(NotionSourceInfo)
|
||||
NOTION_PAGE_SUMMARY_ADAPTER = TypeAdapter(NotionPageSummary)
|
||||
|
||||
|
||||
class OAuthDataSource:
|
||||
client_id: str
|
||||
client_secret: str
|
||||
redirect_uri: str
|
||||
|
||||
def __init__(self, client_id: str, client_secret: str, redirect_uri: str):
|
||||
self.client_id = client_id
|
||||
self.client_secret = client_secret
|
||||
self.redirect_uri = redirect_uri
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
raise NotImplementedError()
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
@@ -30,7 +62,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
_NOTION_BLOCK_SEARCH = "https://api.notion.com/v1/blocks"
|
||||
_NOTION_BOT_USER = "https://api.notion.com/v1/users/me"
|
||||
|
||||
def get_authorization_url(self):
|
||||
def get_authorization_url(self) -> str:
|
||||
params = {
|
||||
"client_id": self.client_id,
|
||||
"response_type": "code",
|
||||
@@ -39,7 +71,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
}
|
||||
return f"{self._AUTH_URL}?{urllib.parse.urlencode(params)}"
|
||||
|
||||
def get_access_token(self, code: str):
|
||||
def get_access_token(self, code: str) -> None:
|
||||
data = {"code": code, "grant_type": "authorization_code", "redirect_uri": self.redirect_uri}
|
||||
headers = {"Accept": "application/json"}
|
||||
auth = (self.client_id, self.client_secret)
|
||||
@@ -54,13 +86,12 @@ class NotionOAuth(OAuthDataSource):
|
||||
workspace_id = response_json.get("workspace_id")
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -70,7 +101,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -78,25 +109,24 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def save_internal_access_token(self, access_token: str):
|
||||
def save_internal_access_token(self, access_token: str) -> None:
|
||||
workspace_name = self.notion_workspace_name(access_token)
|
||||
workspace_icon = None
|
||||
workspace_id = current_user.current_tenant_id
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(access_token)
|
||||
source_info = {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
source_info = self._build_source_info(
|
||||
workspace_name=workspace_name,
|
||||
workspace_icon=workspace_icon,
|
||||
workspace_id=workspace_id,
|
||||
pages=pages,
|
||||
)
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -106,7 +136,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
)
|
||||
)
|
||||
if data_source_binding:
|
||||
data_source_binding.source_info = source_info
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
@@ -114,13 +144,13 @@ class NotionOAuth(OAuthDataSource):
|
||||
new_data_source_binding = DataSourceOauthBinding(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
access_token=access_token,
|
||||
source_info=source_info,
|
||||
source_info=SOURCE_INFO_STORAGE_ADAPTER.validate_python(source_info),
|
||||
provider="notion",
|
||||
)
|
||||
db.session.add(new_data_source_binding)
|
||||
db.session.commit()
|
||||
|
||||
def sync_data_source(self, binding_id: str):
|
||||
def sync_data_source(self, binding_id: str) -> None:
|
||||
# save data source binding
|
||||
data_source_binding = db.session.scalar(
|
||||
select(DataSourceOauthBinding).where(
|
||||
@@ -134,23 +164,22 @@ class NotionOAuth(OAuthDataSource):
|
||||
if data_source_binding:
|
||||
# get all authorized pages
|
||||
pages = self.get_authorized_pages(data_source_binding.access_token)
|
||||
source_info = data_source_binding.source_info
|
||||
new_source_info = {
|
||||
"workspace_name": source_info["workspace_name"],
|
||||
"workspace_icon": source_info["workspace_icon"],
|
||||
"workspace_id": source_info["workspace_id"],
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
data_source_binding.source_info = new_source_info
|
||||
source_info = NOTION_SOURCE_INFO_ADAPTER.validate_python(data_source_binding.source_info)
|
||||
new_source_info = self._build_source_info(
|
||||
workspace_name=source_info["workspace_name"],
|
||||
workspace_icon=source_info["workspace_icon"],
|
||||
workspace_id=source_info["workspace_id"],
|
||||
pages=pages,
|
||||
)
|
||||
data_source_binding.source_info = SOURCE_INFO_STORAGE_ADAPTER.validate_python(new_source_info)
|
||||
data_source_binding.disabled = False
|
||||
data_source_binding.updated_at = naive_utc_now()
|
||||
db.session.commit()
|
||||
else:
|
||||
raise ValueError("Data source binding not found")
|
||||
|
||||
def get_authorized_pages(self, access_token: str):
|
||||
pages = []
|
||||
def get_authorized_pages(self, access_token: str) -> list[NotionPageSummary]:
|
||||
pages: list[NotionPageSummary] = []
|
||||
page_results = self.notion_page_search(access_token)
|
||||
database_results = self.notion_database_search(access_token)
|
||||
# get page detail
|
||||
@@ -187,7 +216,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "page",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
# get database detail
|
||||
for database_result in database_results:
|
||||
page_id = database_result["id"]
|
||||
@@ -220,11 +249,11 @@ class NotionOAuth(OAuthDataSource):
|
||||
"parent_id": parent_id,
|
||||
"type": "database",
|
||||
}
|
||||
pages.append(page)
|
||||
pages.append(NOTION_PAGE_SUMMARY_ADAPTER.validate_python(page))
|
||||
return pages
|
||||
|
||||
def notion_page_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_page_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -249,7 +278,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
|
||||
return results
|
||||
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str):
|
||||
def notion_block_parent_page_id(self, access_token: str, block_id: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -265,7 +294,7 @@ class NotionOAuth(OAuthDataSource):
|
||||
return self.notion_block_parent_page_id(access_token, parent[parent_type])
|
||||
return parent[parent_type]
|
||||
|
||||
def notion_workspace_name(self, access_token: str):
|
||||
def notion_workspace_name(self, access_token: str) -> str:
|
||||
headers = {
|
||||
"Authorization": f"Bearer {access_token}",
|
||||
"Notion-Version": "2022-06-28",
|
||||
@@ -279,8 +308,8 @@ class NotionOAuth(OAuthDataSource):
|
||||
return user_info["workspace_name"]
|
||||
return "workspace"
|
||||
|
||||
def notion_database_search(self, access_token: str):
|
||||
results = []
|
||||
def notion_database_search(self, access_token: str) -> list[dict[str, Any]]:
|
||||
results: list[dict[str, Any]] = []
|
||||
next_cursor = None
|
||||
has_more = True
|
||||
|
||||
@@ -303,3 +332,19 @@ class NotionOAuth(OAuthDataSource):
|
||||
next_cursor = response_json.get("next_cursor", None)
|
||||
|
||||
return results
|
||||
|
||||
@staticmethod
|
||||
def _build_source_info(
|
||||
*,
|
||||
workspace_name: str | None,
|
||||
workspace_icon: str | None,
|
||||
workspace_id: str | None,
|
||||
pages: list[NotionPageSummary],
|
||||
) -> NotionSourceInfo:
|
||||
return {
|
||||
"workspace_name": workspace_name,
|
||||
"workspace_icon": workspace_icon,
|
||||
"workspace_id": workspace_id,
|
||||
"pages": pages,
|
||||
"total": len(pages),
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ from .enums import AppTriggerStatus, AppTriggerType, CreatorUserRole, WorkflowTr
|
||||
from .model import Account
|
||||
from .types import EnumText, LongText, StringUUID
|
||||
|
||||
TriggerJsonObject = dict[str, object]
|
||||
TriggerCredentials = dict[str, str]
|
||||
|
||||
|
||||
class WorkflowTriggerLogDict(TypedDict):
|
||||
id: str
|
||||
@@ -89,10 +92,14 @@ class TriggerSubscription(TypeBase):
|
||||
String(255), nullable=False, comment="Provider identifier (e.g., plugin_id/provider_name)"
|
||||
)
|
||||
endpoint_id: Mapped[str] = mapped_column(String(255), nullable=False, comment="Subscription endpoint")
|
||||
parameters: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription parameters JSON")
|
||||
properties: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False, comment="Subscription properties JSON")
|
||||
parameters: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription parameters JSON"
|
||||
)
|
||||
properties: Mapped[TriggerJsonObject] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription properties JSON"
|
||||
)
|
||||
|
||||
credentials: Mapped[dict[str, Any]] = mapped_column(
|
||||
credentials: Mapped[TriggerCredentials] = mapped_column(
|
||||
sa.JSON, nullable=False, comment="Subscription credentials JSON"
|
||||
)
|
||||
credential_type: Mapped[str] = mapped_column(String(50), nullable=False, comment="oauth or api_key")
|
||||
@@ -200,8 +207,8 @@ class TriggerOAuthTenantClient(TypeBase):
|
||||
)
|
||||
|
||||
@property
|
||||
def oauth_params(self) -> Mapping[str, Any]:
|
||||
return cast(Mapping[str, Any], json.loads(self.encrypted_oauth_params or "{}"))
|
||||
def oauth_params(self) -> Mapping[str, object]:
|
||||
return cast(TriggerJsonObject, json.loads(self.encrypted_oauth_params or "{}"))
|
||||
|
||||
|
||||
class WorkflowTriggerLog(TypeBase):
|
||||
|
||||
@@ -19,7 +19,7 @@ from sqlalchemy import (
|
||||
orm,
|
||||
select,
|
||||
)
|
||||
from sqlalchemy.orm import Mapped, declared_attr, mapped_column
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from core.trigger.constants import TRIGGER_INFO_METADATA_KEY, TRIGGER_PLUGIN_NODE_TYPE
|
||||
@@ -33,7 +33,7 @@ from dify_graph.enums import BuiltinNodeTypes, NodeType, WorkflowExecutionStatus
|
||||
from dify_graph.file.constants import maybe_file_object
|
||||
from dify_graph.file.models import File
|
||||
from dify_graph.variables import utils as variable_utils
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, StringVariable
|
||||
from dify_graph.variables.variables import FloatVariable, IntegerVariable, RAGPipelineVariable, StringVariable
|
||||
from extensions.ext_storage import Storage
|
||||
from factories.variable_factory import TypeMismatchError, build_segment_with_type
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
@@ -59,6 +59,9 @@ from .types import EnumText, LongText, StringUUID
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
SerializedWorkflowValue = dict[str, Any]
|
||||
SerializedWorkflowVariables = dict[str, SerializedWorkflowValue]
|
||||
|
||||
|
||||
class WorkflowContentDict(TypedDict):
|
||||
graph: Mapping[str, Any]
|
||||
@@ -405,7 +408,7 @@ class Workflow(Base): # bug
|
||||
|
||||
def rag_pipeline_user_input_form(self) -> list:
|
||||
# get user_input_form from start node
|
||||
variables: list[Any] = self.rag_pipeline_variables
|
||||
variables: list[SerializedWorkflowValue] = self.rag_pipeline_variables
|
||||
|
||||
return variables
|
||||
|
||||
@@ -448,17 +451,13 @@ class Workflow(Base): # bug
|
||||
def environment_variables(
|
||||
self,
|
||||
) -> Sequence[StringVariable | IntegerVariable | FloatVariable | SecretVariable]:
|
||||
# TODO: find some way to init `self._environment_variables` when instance created.
|
||||
if self._environment_variables is None:
|
||||
self._environment_variables = "{}"
|
||||
|
||||
# Use workflow.tenant_id to avoid relying on request user in background threads
|
||||
tenant_id = self.tenant_id
|
||||
|
||||
if not tenant_id:
|
||||
return []
|
||||
|
||||
environment_variables_dict: dict[str, Any] = json.loads(self._environment_variables or "{}")
|
||||
environment_variables_dict = cast(SerializedWorkflowVariables, json.loads(self._environment_variables or "{}"))
|
||||
results = [
|
||||
variable_factory.build_environment_variable_from_mapping(v) for v in environment_variables_dict.values()
|
||||
]
|
||||
@@ -536,11 +535,7 @@ class Workflow(Base): # bug
|
||||
|
||||
@property
|
||||
def conversation_variables(self) -> Sequence[VariableBase]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._conversation_variables is None:
|
||||
self._conversation_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._conversation_variables)
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._conversation_variables or "{}"))
|
||||
results = [variable_factory.build_conversation_variable_from_mapping(v) for v in variables_dict.values()]
|
||||
return results
|
||||
|
||||
@@ -552,19 +547,20 @@ class Workflow(Base): # bug
|
||||
)
|
||||
|
||||
@property
|
||||
def rag_pipeline_variables(self) -> list[dict]:
|
||||
# TODO: find some way to init `self._conversation_variables` when instance created.
|
||||
if self._rag_pipeline_variables is None:
|
||||
self._rag_pipeline_variables = "{}"
|
||||
|
||||
variables_dict: dict[str, Any] = json.loads(self._rag_pipeline_variables)
|
||||
results = list(variables_dict.values())
|
||||
return results
|
||||
def rag_pipeline_variables(self) -> list[SerializedWorkflowValue]:
|
||||
variables_dict = cast(SerializedWorkflowVariables, json.loads(self._rag_pipeline_variables or "{}"))
|
||||
return [RAGPipelineVariable.model_validate(item).model_dump(mode="json") for item in variables_dict.values()]
|
||||
|
||||
@rag_pipeline_variables.setter
|
||||
def rag_pipeline_variables(self, values: list[dict]) -> None:
|
||||
def rag_pipeline_variables(self, values: Sequence[Mapping[str, Any] | RAGPipelineVariable]) -> None:
|
||||
self._rag_pipeline_variables = json.dumps(
|
||||
{item["variable"]: item for item in values},
|
||||
{
|
||||
rag_pipeline_variable.variable: rag_pipeline_variable.model_dump(mode="json")
|
||||
for rag_pipeline_variable in (
|
||||
item if isinstance(item, RAGPipelineVariable) else RAGPipelineVariable.model_validate(item)
|
||||
for item in values
|
||||
)
|
||||
},
|
||||
ensure_ascii=False,
|
||||
)
|
||||
|
||||
@@ -802,44 +798,36 @@ class WorkflowNodeExecutionModel(Base): # This model is expected to have `offlo
|
||||
|
||||
__tablename__ = "workflow_node_executions"
|
||||
|
||||
@declared_attr.directive
|
||||
@classmethod
|
||||
def __table_args__(cls) -> Any:
|
||||
return (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
# The first argument is the index name,
|
||||
# which we leave as `None`` to allow auto-generation by the ORM.
|
||||
None,
|
||||
cls.tenant_id,
|
||||
cls.workflow_id,
|
||||
cls.node_id,
|
||||
# MyPy may flag the following line because it doesn't recognize that
|
||||
# the `declared_attr` decorator passes the receiving class as the first
|
||||
# argument to this method, allowing us to reference class attributes.
|
||||
cls.created_at.desc(),
|
||||
),
|
||||
)
|
||||
__table_args__ = (
|
||||
PrimaryKeyConstraint("id", name="workflow_node_execution_pkey"),
|
||||
Index(
|
||||
"workflow_node_execution_workflow_run_id_idx",
|
||||
"workflow_run_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_node_run_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_id",
|
||||
),
|
||||
Index(
|
||||
"workflow_node_execution_id_idx",
|
||||
"tenant_id",
|
||||
"app_id",
|
||||
"workflow_id",
|
||||
"triggered_from",
|
||||
"node_execution_id",
|
||||
),
|
||||
Index(
|
||||
None,
|
||||
"tenant_id",
|
||||
"workflow_id",
|
||||
"node_id",
|
||||
sa.desc("created_at"),
|
||||
),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
configs/middleware/cache/redis_pubsub_config.py
|
||||
controllers/console/app/annotation.py
|
||||
controllers/console/app/app.py
|
||||
controllers/console/app/app_import.py
|
||||
@@ -14,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
|
||||
controllers/service_api/app/annotation.py
|
||||
controllers/web/workflow_events.py
|
||||
core/agent/fc_agent_runner.py
|
||||
core/app/apps/advanced_chat/app_generator.py
|
||||
core/app/apps/advanced_chat/app_runner.py
|
||||
core/app/apps/advanced_chat/generate_task_pipeline.py
|
||||
core/app/apps/agent_chat/app_generator.py
|
||||
core/app/apps/base_app_generate_response_converter.py
|
||||
core/app/apps/base_app_generator.py
|
||||
core/app/apps/chat/app_generator.py
|
||||
core/app/apps/common/workflow_response_converter.py
|
||||
core/app/apps/completion/app_generator.py
|
||||
core/app/apps/pipeline/pipeline_generator.py
|
||||
core/app/apps/pipeline/pipeline_runner.py
|
||||
core/app/apps/workflow/app_generator.py
|
||||
core/app/apps/workflow/app_runner.py
|
||||
core/app/apps/workflow/generate_task_pipeline.py
|
||||
core/app/apps/workflow_app_runner.py
|
||||
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
|
||||
core/datasource/datasource_manager.py
|
||||
core/external_data_tool/api/api.py
|
||||
@@ -109,37 +93,6 @@ core/tools/workflow_as_tool/provider.py
|
||||
core/trigger/debug/event_selectors.py
|
||||
core/trigger/entities/entities.py
|
||||
core/trigger/provider.py
|
||||
core/workflow/workflow_entry.py
|
||||
dify_graph/entities/workflow_execution.py
|
||||
dify_graph/file/file_manager.py
|
||||
dify_graph/graph_engine/error_handler.py
|
||||
dify_graph/graph_engine/layers/execution_limits.py
|
||||
dify_graph/nodes/agent/agent_node.py
|
||||
dify_graph/nodes/base/node.py
|
||||
dify_graph/nodes/code/code_node.py
|
||||
dify_graph/nodes/datasource/datasource_node.py
|
||||
dify_graph/nodes/document_extractor/node.py
|
||||
dify_graph/nodes/human_input/human_input_node.py
|
||||
dify_graph/nodes/if_else/if_else_node.py
|
||||
dify_graph/nodes/iteration/iteration_node.py
|
||||
dify_graph/nodes/knowledge_index/knowledge_index_node.py
|
||||
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
|
||||
dify_graph/nodes/list_operator/node.py
|
||||
dify_graph/nodes/llm/node.py
|
||||
dify_graph/nodes/loop/loop_node.py
|
||||
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
|
||||
dify_graph/nodes/question_classifier/question_classifier_node.py
|
||||
dify_graph/nodes/start/start_node.py
|
||||
dify_graph/nodes/template_transform/template_transform_node.py
|
||||
dify_graph/nodes/tool/tool_node.py
|
||||
dify_graph/nodes/trigger_plugin/trigger_event_node.py
|
||||
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
|
||||
dify_graph/nodes/trigger_webhook/node.py
|
||||
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
|
||||
dify_graph/nodes/variable_assigner/v1/node.py
|
||||
dify_graph/nodes/variable_assigner/v2/node.py
|
||||
dify_graph/variables/types.py
|
||||
extensions/ext_fastopenapi.py
|
||||
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
|
||||
extensions/otel/instrumentation.py
|
||||
extensions/otel/runtime.py
|
||||
@@ -156,19 +109,7 @@ extensions/storage/oracle_oci_storage.py
|
||||
extensions/storage/supabase_storage.py
|
||||
extensions/storage/tencent_cos_storage.py
|
||||
extensions/storage/volcengine_tos_storage.py
|
||||
factories/variable_factory.py
|
||||
libs/external_api.py
|
||||
libs/gmpy2_pkcs10aep_cipher.py
|
||||
libs/helper.py
|
||||
libs/login.py
|
||||
libs/module_loading.py
|
||||
libs/oauth.py
|
||||
libs/oauth_data_source.py
|
||||
models/trigger.py
|
||||
models/workflow.py
|
||||
repositories/sqlalchemy_api_workflow_node_execution_repository.py
|
||||
repositories/sqlalchemy_api_workflow_run_repository.py
|
||||
repositories/sqlalchemy_execution_extra_content_repository.py
|
||||
schedule/queue_monitor_task.py
|
||||
services/account_service.py
|
||||
services/audio_service.py
|
||||
|
||||
@@ -8,7 +8,7 @@ using SQLAlchemy 2.0 style queries for WorkflowNodeExecutionModel operations.
|
||||
import json
|
||||
from collections.abc import Sequence
|
||||
from datetime import datetime
|
||||
from typing import cast
|
||||
from typing import Protocol, cast
|
||||
|
||||
from sqlalchemy import asc, delete, desc, func, select
|
||||
from sqlalchemy.engine import CursorResult
|
||||
@@ -22,6 +22,20 @@ from repositories.api_workflow_node_execution_repository import (
|
||||
)
|
||||
|
||||
|
||||
class _WorkflowNodeExecutionSnapshotRow(Protocol):
|
||||
id: str
|
||||
node_execution_id: str | None
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
index: int
|
||||
status: WorkflowNodeExecutionStatus
|
||||
elapsed_time: float | None
|
||||
created_at: datetime
|
||||
finished_at: datetime | None
|
||||
execution_metadata: str | None
|
||||
|
||||
|
||||
class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRepository):
|
||||
"""
|
||||
SQLAlchemy implementation of DifyAPIWorkflowNodeExecutionRepository.
|
||||
@@ -40,6 +54,8 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
- Thread-safe database operations using session-per-request pattern
|
||||
"""
|
||||
|
||||
_session_maker: sessionmaker[Session]
|
||||
|
||||
def __init__(self, session_maker: sessionmaker[Session]):
|
||||
"""
|
||||
Initialize the repository with a sessionmaker.
|
||||
@@ -156,12 +172,12 @@ class DifyAPISQLAlchemyWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecut
|
||||
)
|
||||
|
||||
with self._session_maker() as session:
|
||||
rows = session.execute(stmt).all()
|
||||
rows = cast(Sequence[_WorkflowNodeExecutionSnapshotRow], session.execute(stmt).all())
|
||||
|
||||
return [self._row_to_snapshot(row) for row in rows]
|
||||
|
||||
@staticmethod
|
||||
def _row_to_snapshot(row: object) -> WorkflowNodeExecutionSnapshot:
|
||||
def _row_to_snapshot(row: _WorkflowNodeExecutionSnapshotRow) -> WorkflowNodeExecutionSnapshot:
|
||||
metadata: dict[str, object] = {}
|
||||
execution_metadata = getattr(row, "execution_metadata", None)
|
||||
if execution_metadata:
|
||||
|
||||
@@ -30,7 +30,7 @@ from core.plugin.impl.debugging import PluginDebuggingClient
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.provider import Provider, ProviderCredential
|
||||
from models.provider import Provider, ProviderCredential, TenantPreferredModelProvider
|
||||
from models.provider_ids import GenericProviderID
|
||||
from services.enterprise.plugin_manager_service import (
|
||||
PluginManagerService,
|
||||
@@ -534,6 +534,13 @@ class PluginService:
|
||||
plugin_id = plugin.plugin_id
|
||||
logger.info("Deleting credentials for plugin: %s", plugin_id)
|
||||
|
||||
session.execute(
|
||||
delete(TenantPreferredModelProvider).where(
|
||||
TenantPreferredModelProvider.tenant_id == tenant_id,
|
||||
TenantPreferredModelProvider.provider_name.like(f"{plugin_id}/%"),
|
||||
)
|
||||
)
|
||||
|
||||
# Delete provider credentials that match this plugin
|
||||
credential_ids = session.scalars(
|
||||
select(ProviderCredential.id).where(
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any
|
||||
from copy import deepcopy
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -33,8 +33,8 @@ def _make_graph_state():
|
||||
],
|
||||
)
|
||||
def test_run_uses_single_node_execution_branch(
|
||||
single_iteration_run: Any,
|
||||
single_loop_run: Any,
|
||||
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
|
||||
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
|
||||
) -> None:
|
||||
app_config = MagicMock()
|
||||
app_config.app_id = "app"
|
||||
@@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
"break_conditions": [],
|
||||
"logical_operator": "and",
|
||||
},
|
||||
},
|
||||
{
|
||||
"id": "other-node",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
},
|
||||
},
|
||||
],
|
||||
"edges": [
|
||||
{
|
||||
"source": "other-node",
|
||||
"target": "loop-node",
|
||||
}
|
||||
],
|
||||
"edges": [],
|
||||
}
|
||||
original_graph_dict = deepcopy(workflow.graph_dict)
|
||||
|
||||
_, _, graph_runtime_state = _make_graph_state()
|
||||
seen_configs: list[object] = []
|
||||
@@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
seen_configs.append(value)
|
||||
return original_validate_python(value)
|
||||
|
||||
class FakeNodeClass:
|
||||
@staticmethod
|
||||
def extract_variable_selector_to_variable_mapping(**_kwargs):
|
||||
return {}
|
||||
|
||||
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
|
||||
|
||||
with (
|
||||
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
|
||||
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
|
||||
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
|
||||
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
|
||||
):
|
||||
runner._get_graph_and_variable_pool_for_single_node_run(
|
||||
workflow=workflow,
|
||||
@@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
)
|
||||
|
||||
assert seen_configs == [workflow.graph_dict["nodes"][0]]
|
||||
assert workflow.graph_dict == original_graph_dict
|
||||
graph_config = graph_init.call_args.kwargs["graph_config"]
|
||||
assert graph_config is not workflow.graph_dict
|
||||
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
|
||||
assert graph_config["edges"] == []
|
||||
|
||||
@@ -734,7 +734,7 @@ def test_create_provider_credential_creates_provider_record_when_missing() -> No
|
||||
def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(is_valid=False)
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id="existing-cred")
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
@@ -743,6 +743,25 @@ def test_create_provider_credential_marks_existing_provider_as_valid() -> None:
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id == "existing-cred"
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
def test_create_provider_credential_auto_activates_when_no_active_credential() -> None:
|
||||
configuration = _build_provider_configuration()
|
||||
session = Mock()
|
||||
provider_record = SimpleNamespace(id="provider-1", is_valid=False, credential_id=None, updated_at=None)
|
||||
|
||||
with _patched_session(session):
|
||||
with patch.object(ProviderConfiguration, "_check_provider_credential_name_exists", return_value=False):
|
||||
with patch.object(ProviderConfiguration, "validate_provider_credentials", return_value={"api_key": "enc"}):
|
||||
with patch.object(ProviderConfiguration, "_get_provider_record", return_value=provider_record):
|
||||
with patch("core.entities.provider_configuration.ProviderCredentialsCache"):
|
||||
with patch.object(ProviderConfiguration, "switch_preferred_provider_type"):
|
||||
configuration.create_provider_credential({"api_key": "raw"}, "Main")
|
||||
|
||||
assert provider_record.is_valid is True
|
||||
assert provider_record.credential_id is not None
|
||||
session.commit.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -12,7 +12,7 @@ NEXT_PUBLIC_API_PREFIX=http://localhost:5001/console/api
|
||||
# console or api domain.
|
||||
# example: http://udify.app/api
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX=http://localhost:5001/api
|
||||
# Dev-only Hono proxy targets. Set the api prefixes above to https://localhost:5001/... to start the proxy with HTTPS.
|
||||
# Dev-only Hono proxy targets. The frontend keeps requesting http://localhost:5001 directly.
|
||||
HONO_PROXY_HOST=127.0.0.1
|
||||
HONO_PROXY_PORT=5001
|
||||
HONO_CONSOLE_API_PROXY_TARGET=
|
||||
|
||||
@@ -328,7 +328,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
vi.clearAllMocks()
|
||||
})
|
||||
|
||||
const setupHandlers = (overrides: { isTimedOut?: () => boolean } = {}) => {
|
||||
const setupHandlers = (overrides: { isPublicAPI?: boolean, isTimedOut?: () => boolean } = {}) => {
|
||||
let completionRes = ''
|
||||
let currentTaskId: string | null = null
|
||||
let isStopping = false
|
||||
@@ -359,6 +359,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => completionRes,
|
||||
getWorkflowProcessData: () => workflowProcessData,
|
||||
isPublicAPI: overrides.isPublicAPI ?? false,
|
||||
isTimedOut: overrides.isTimedOut ?? (() => false),
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -391,7 +392,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
}
|
||||
|
||||
it('should process workflow success and paused events', () => {
|
||||
const setup = setupHandlers()
|
||||
const setup = setupHandlers({ isPublicAPI: true })
|
||||
const handlers = setup.handlers as Required<Pick<IOtherOptions, 'onWorkflowStarted' | 'onTextChunk' | 'onHumanInputRequired' | 'onHumanInputFormFilled' | 'onHumanInputFormTimeout' | 'onWorkflowPaused' | 'onWorkflowFinished' | 'onNodeStarted' | 'onNodeFinished' | 'onIterationStart' | 'onIterationNext' | 'onIterationFinish' | 'onLoopStart' | 'onLoopNext' | 'onLoopFinish'>>
|
||||
|
||||
act(() => {
|
||||
@@ -546,7 +547,11 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
resultText: 'Hello',
|
||||
status: WorkflowRunningStatus.Succeeded,
|
||||
}))
|
||||
expect(sseGetMock).toHaveBeenCalledWith('/workflow/run-1/events', {}, expect.any(Object))
|
||||
expect(sseGetMock).toHaveBeenCalledWith(
|
||||
'/workflow/run-1/events',
|
||||
{},
|
||||
expect.objectContaining({ isPublicAPI: true }),
|
||||
)
|
||||
expect(setup.messageId()).toBe('run-1')
|
||||
expect(setup.onCompleted).toHaveBeenCalledWith('{"answer":"Hello"}', 3, true)
|
||||
expect(setup.setRespondingFalse).toHaveBeenCalled()
|
||||
@@ -647,6 +652,7 @@ describe('createWorkflowStreamHandlers', () => {
|
||||
const handlers = createWorkflowStreamHandlers({
|
||||
getCompletionRes: () => '',
|
||||
getWorkflowProcessData: () => existingProcess,
|
||||
isPublicAPI: false,
|
||||
isTimedOut: () => false,
|
||||
markEnded: vi.fn(),
|
||||
notify: setup.notify,
|
||||
|
||||
@@ -351,6 +351,7 @@ describe('useResultSender', () => {
|
||||
await waitFor(() => {
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
getCompletionRes: harness.runState.getCompletionRes,
|
||||
isPublicAPI: true,
|
||||
resetRunState: harness.runState.resetRunState,
|
||||
setWorkflowProcessData: harness.runState.setWorkflowProcessData,
|
||||
}))
|
||||
@@ -373,6 +374,30 @@ describe('useResultSender', () => {
|
||||
expect(harness.runState.clearMoreLikeThis).not.toHaveBeenCalled()
|
||||
})
|
||||
|
||||
it('should configure workflow handlers for installed apps as non-public', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
|
||||
const { result } = renderSender({
|
||||
appSourceType: AppSourceTypeEnum.installedApp,
|
||||
isWorkflow: true,
|
||||
runState: harness.runState,
|
||||
})
|
||||
|
||||
await act(async () => {
|
||||
expect(await result.current.handleSend()).toBe(true)
|
||||
})
|
||||
|
||||
expect(createWorkflowStreamHandlersMock).toHaveBeenCalledWith(expect.objectContaining({
|
||||
isPublicAPI: false,
|
||||
}))
|
||||
expect(sendWorkflowMessageMock).toHaveBeenCalledWith(
|
||||
{ inputs: { name: 'Alice' } },
|
||||
expect.any(Object),
|
||||
AppSourceTypeEnum.installedApp,
|
||||
'app-1',
|
||||
)
|
||||
})
|
||||
|
||||
it('should stringify non-Error workflow failures', async () => {
|
||||
const harness = createRunStateHarness()
|
||||
sendWorkflowMessageMock.mockRejectedValue('workflow failed')
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
import type { ResultInputValue } from '../result-request'
|
||||
import type { ResultRunStateController } from './use-result-run-state'
|
||||
import type { PromptConfig } from '@/models/debug'
|
||||
import type { AppSourceType } from '@/service/share'
|
||||
import type { VisionFile, VisionSettings } from '@/types/app'
|
||||
import { useCallback, useEffect, useRef } from 'react'
|
||||
import { TEXT_GENERATION_TIMEOUT_MS } from '@/config'
|
||||
import {
|
||||
AppSourceType,
|
||||
sendCompletionMessage,
|
||||
sendWorkflowMessage,
|
||||
} from '@/service/share'
|
||||
@@ -117,6 +117,7 @@ export const useResultSender = ({
|
||||
const otherOptions = createWorkflowStreamHandlers({
|
||||
getCompletionRes: runState.getCompletionRes,
|
||||
getWorkflowProcessData: runState.getWorkflowProcessData,
|
||||
isPublicAPI: appSourceType === AppSourceType.webApp,
|
||||
isTimedOut: () => isTimeout,
|
||||
markEnded: () => {
|
||||
isEnd = true
|
||||
|
||||
@@ -13,6 +13,7 @@ type Translate = (key: string, options?: Record<string, unknown>) => string
|
||||
type CreateWorkflowStreamHandlersParams = {
|
||||
getCompletionRes: () => string
|
||||
getWorkflowProcessData: () => WorkflowProcess | undefined
|
||||
isPublicAPI: boolean
|
||||
isTimedOut: () => boolean
|
||||
markEnded: () => void
|
||||
notify: Notify
|
||||
@@ -255,6 +256,7 @@ const serializeWorkflowOutputs = (outputs: WorkflowFinishedResponse['data']['out
|
||||
export const createWorkflowStreamHandlers = ({
|
||||
getCompletionRes,
|
||||
getWorkflowProcessData,
|
||||
isPublicAPI,
|
||||
isTimedOut,
|
||||
markEnded,
|
||||
notify,
|
||||
@@ -287,6 +289,7 @@ export const createWorkflowStreamHandlers = ({
|
||||
}
|
||||
|
||||
const otherOptions: IOtherOptions = {
|
||||
isPublicAPI,
|
||||
onWorkflowStarted: ({ workflow_run_id, task_id }) => {
|
||||
const workflowProcessData = getWorkflowProcessData()
|
||||
if (workflowProcessData?.tracing.length) {
|
||||
@@ -378,6 +381,7 @@ export const createWorkflowStreamHandlers = ({
|
||||
},
|
||||
onWorkflowPaused: ({ data }) => {
|
||||
tempMessageId = data.workflow_run_id
|
||||
// WebApp workflows must keep using the public API namespace after pause/resume.
|
||||
void sseGet(`/workflow/${data.workflow_run_id}/events`, {}, otherOptions)
|
||||
setWorkflowProcessData(applyWorkflowPaused(getWorkflowProcessData()))
|
||||
},
|
||||
|
||||
@@ -210,7 +210,6 @@
|
||||
"@types/sortablejs": "1.15.9",
|
||||
"@typescript-eslint/parser": "8.57.0",
|
||||
"@typescript/native-preview": "7.0.0-dev.20260312.1",
|
||||
"@vitejs/plugin-basic-ssl": "2.2.0",
|
||||
"@vitejs/plugin-react": "6.0.0",
|
||||
"@vitejs/plugin-rsc": "0.5.21",
|
||||
"@vitest/coverage-v8": "4.1.0",
|
||||
|
||||
@@ -34,16 +34,7 @@ const toUpstreamCookieName = (cookieName: string) => {
|
||||
return `__Host-${cookieName}`
|
||||
}
|
||||
|
||||
const toLocalCookieName = (cookieName: string, options: LocalCookieRewriteOptions) => {
|
||||
if (options.localSecure)
|
||||
return cookieName
|
||||
|
||||
return cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
|
||||
}
|
||||
|
||||
type LocalCookieRewriteOptions = {
|
||||
localSecure: boolean
|
||||
}
|
||||
const toLocalCookieName = (cookieName: string) => cookieName.replace(SECURE_COOKIE_PREFIX_PATTERN, '')
|
||||
|
||||
export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
|
||||
if (!cookieHeader)
|
||||
@@ -64,10 +55,7 @@ export const rewriteCookieHeaderForUpstream = (cookieHeader?: string) => {
|
||||
.join('; ')
|
||||
}
|
||||
|
||||
const rewriteSetCookieValueForLocal = (
|
||||
setCookieValue: string,
|
||||
options: LocalCookieRewriteOptions,
|
||||
) => {
|
||||
const rewriteSetCookieValueForLocal = (setCookieValue: string) => {
|
||||
const [rawCookiePair, ...rawAttributes] = setCookieValue.split(';')
|
||||
const separatorIndex = rawCookiePair.indexOf('=')
|
||||
|
||||
@@ -80,11 +68,11 @@ const rewriteSetCookieValueForLocal = (
|
||||
.map(attribute => attribute.trim())
|
||||
.filter(attribute =>
|
||||
!COOKIE_DOMAIN_PATTERN.test(attribute)
|
||||
&& (options.localSecure || !COOKIE_SECURE_PATTERN.test(attribute))
|
||||
&& (options.localSecure || !COOKIE_PARTITIONED_PATTERN.test(attribute)),
|
||||
&& !COOKIE_SECURE_PATTERN.test(attribute)
|
||||
&& !COOKIE_PARTITIONED_PATTERN.test(attribute),
|
||||
)
|
||||
.map((attribute) => {
|
||||
if (!options.localSecure && SAME_SITE_NONE_PATTERN.test(attribute))
|
||||
if (SAME_SITE_NONE_PATTERN.test(attribute))
|
||||
return 'SameSite=Lax'
|
||||
|
||||
if (COOKIE_PATH_PATTERN.test(attribute))
|
||||
@@ -93,13 +81,10 @@ const rewriteSetCookieValueForLocal = (
|
||||
return attribute
|
||||
})
|
||||
|
||||
return [`${toLocalCookieName(cookieName, options)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
|
||||
return [`${toLocalCookieName(cookieName)}=${cookieValue}`, ...rewrittenAttributes].join('; ')
|
||||
}
|
||||
|
||||
export const rewriteSetCookieHeadersForLocal = (
|
||||
setCookieHeaders: string | string[] | undefined,
|
||||
options: LocalCookieRewriteOptions,
|
||||
): string[] | undefined => {
|
||||
export const rewriteSetCookieHeadersForLocal = (setCookieHeaders?: string | string[]): string[] | undefined => {
|
||||
if (!setCookieHeaders)
|
||||
return undefined
|
||||
|
||||
@@ -107,7 +92,7 @@ export const rewriteSetCookieHeadersForLocal = (
|
||||
? setCookieHeaders
|
||||
: [setCookieHeaders]
|
||||
|
||||
return normalizedHeaders.map(setCookieValue => rewriteSetCookieValueForLocal(setCookieValue, options))
|
||||
return normalizedHeaders.map(rewriteSetCookieValueForLocal)
|
||||
}
|
||||
|
||||
export { DEFAULT_PROXY_TARGET }
|
||||
|
||||
@@ -1,21 +0,0 @@
|
||||
export type DevProxyProtocolEnv = Partial<Record<
|
||||
| 'NEXT_PUBLIC_API_PREFIX'
|
||||
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
|
||||
string
|
||||
>>
|
||||
|
||||
const isHttpsUrl = (value?: string) => {
|
||||
if (!value)
|
||||
return false
|
||||
|
||||
try {
|
||||
return new URL(value).protocol === 'https:'
|
||||
}
|
||||
catch {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
export const shouldUseHttpsForDevProxy = (env: DevProxyProtocolEnv = {}) => {
|
||||
return isHttpsUrl(env.NEXT_PUBLIC_API_PREFIX) || isHttpsUrl(env.NEXT_PUBLIC_PUBLIC_API_PREFIX)
|
||||
}
|
||||
@@ -1,5 +1,5 @@
|
||||
import { beforeEach, describe, expect, it, vi } from 'vitest'
|
||||
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from './server'
|
||||
import { buildUpstreamUrl, createDevProxyApp, isAllowedDevOrigin, resolveDevProxyTargets } from './server'
|
||||
|
||||
describe('dev proxy server', () => {
|
||||
beforeEach(() => {
|
||||
@@ -19,21 +19,6 @@ describe('dev proxy server', () => {
|
||||
expect(targets.publicApiTarget).toBe('https://public.example.com')
|
||||
})
|
||||
|
||||
// Scenario: the local dev proxy should switch to https when api prefixes are configured with https.
|
||||
it('should enable https for the local dev proxy when api prefixes use https', () => {
|
||||
// Assert
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_API_PREFIX: 'https://localhost:5001/console/api',
|
||||
})).toBe(true)
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'https://localhost:5001/api',
|
||||
})).toBe(true)
|
||||
expect(shouldUseHttpsForDevProxy({
|
||||
NEXT_PUBLIC_API_PREFIX: 'http://localhost:5001/console/api',
|
||||
NEXT_PUBLIC_PUBLIC_API_PREFIX: 'http://localhost:5001/api',
|
||||
})).toBe(false)
|
||||
})
|
||||
|
||||
// Scenario: target paths should not be duplicated when the incoming route already includes them.
|
||||
it('should preserve prefixed targets when building upstream URLs', () => {
|
||||
// Act
|
||||
@@ -47,7 +32,6 @@ describe('dev proxy server', () => {
|
||||
it('should only allow local development origins', () => {
|
||||
// Assert
|
||||
expect(isAllowedDevOrigin('http://localhost:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('https://localhost:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('http://127.0.0.1:3000')).toBe(true)
|
||||
expect(isAllowedDevOrigin('https://example.com')).toBe(false)
|
||||
})
|
||||
@@ -102,39 +86,6 @@ describe('dev proxy server', () => {
|
||||
])
|
||||
})
|
||||
|
||||
// Scenario: secure local proxy responses should keep secure cross-site cookie attributes intact.
|
||||
it('should preserve secure cookie attributes when the local proxy is https', async () => {
|
||||
// Arrange
|
||||
const fetchImpl = vi.fn<typeof fetch>().mockResolvedValue(new Response('ok', {
|
||||
status: 200,
|
||||
headers: [
|
||||
['set-cookie', '__Host-access_token=abc; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None; Partitioned'],
|
||||
['set-cookie', '__Host-csrf_token=csrf; Path=/console/api; Domain=cloud.dify.ai; Secure; SameSite=None'],
|
||||
],
|
||||
}))
|
||||
const app = createDevProxyApp({
|
||||
consoleApiTarget: 'https://cloud.dify.ai',
|
||||
publicApiTarget: 'https://public.dify.ai',
|
||||
fetchImpl,
|
||||
})
|
||||
|
||||
// Act
|
||||
const response = await app.request('https://127.0.0.1:5001/console/api/apps?page=1', {
|
||||
headers: {
|
||||
Origin: 'https://localhost:3000',
|
||||
Cookie: 'access_token=abc',
|
||||
},
|
||||
})
|
||||
|
||||
// Assert
|
||||
expect(response.headers.getSetCookie()).toEqual([
|
||||
'__Host-access_token=abc; Path=/; Secure; SameSite=None; Partitioned',
|
||||
'__Host-csrf_token=csrf; Path=/; Secure; SameSite=None',
|
||||
])
|
||||
expect(response.headers.get('access-control-allow-origin')).toBe('https://localhost:3000')
|
||||
expect(response.headers.get('access-control-allow-credentials')).toBe('true')
|
||||
})
|
||||
|
||||
// Scenario: preflight requests should advertise allowed headers for credentialed cross-origin calls.
|
||||
it('should answer CORS preflight requests', async () => {
|
||||
// Arrange
|
||||
|
||||
@@ -2,16 +2,10 @@ import type { Context, Hono } from 'hono'
|
||||
import { Hono as HonoApp } from 'hono'
|
||||
import { DEFAULT_PROXY_TARGET, rewriteCookieHeaderForUpstream, rewriteSetCookieHeadersForLocal } from './cookies'
|
||||
|
||||
export { shouldUseHttpsForDevProxy } from './protocol'
|
||||
|
||||
type DevProxyEnv = Partial<Record<
|
||||
| 'HONO_CONSOLE_API_PROXY_TARGET'
|
||||
| 'HONO_PUBLIC_API_PROXY_TARGET',
|
||||
string
|
||||
> & Record<
|
||||
| 'NEXT_PUBLIC_API_PREFIX'
|
||||
| 'NEXT_PUBLIC_PUBLIC_API_PREFIX',
|
||||
string | undefined
|
||||
>>
|
||||
|
||||
export type DevProxyTargets = {
|
||||
@@ -99,15 +93,11 @@ const createProxyRequestHeaders = (request: Request, targetUrl: URL) => {
|
||||
return headers
|
||||
}
|
||||
|
||||
const createUpstreamResponseHeaders = (
|
||||
response: Response,
|
||||
requestOrigin: string | null | undefined,
|
||||
localSecure: boolean,
|
||||
) => {
|
||||
const createUpstreamResponseHeaders = (response: Response, requestOrigin?: string | null) => {
|
||||
const headers = new Headers(response.headers)
|
||||
RESPONSE_HEADERS_TO_DROP.forEach(header => headers.delete(header))
|
||||
|
||||
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie(), { localSecure })
|
||||
const rewrittenSetCookies = rewriteSetCookieHeadersForLocal(response.headers.getSetCookie())
|
||||
rewrittenSetCookies?.forEach((cookie) => {
|
||||
headers.append('set-cookie', cookie)
|
||||
})
|
||||
@@ -136,11 +126,7 @@ const proxyRequest = async (
|
||||
}
|
||||
|
||||
const upstreamResponse = await fetchImpl(targetUrl, requestInit)
|
||||
const responseHeaders = createUpstreamResponseHeaders(
|
||||
upstreamResponse,
|
||||
context.req.header('origin'),
|
||||
requestUrl.protocol === 'https:',
|
||||
)
|
||||
const responseHeaders = createUpstreamResponseHeaders(upstreamResponse, context.req.header('origin'))
|
||||
|
||||
return new Response(upstreamResponse.body, {
|
||||
status: upstreamResponse.status,
|
||||
|
||||
13
web/pnpm-lock.yaml
generated
13
web/pnpm-lock.yaml
generated
@@ -512,9 +512,6 @@ importers:
|
||||
'@typescript/native-preview':
|
||||
specifier: 7.0.0-dev.20260312.1
|
||||
version: 7.0.0-dev.20260312.1
|
||||
'@vitejs/plugin-basic-ssl':
|
||||
specifier: 2.2.0
|
||||
version: 2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
|
||||
'@vitejs/plugin-react':
|
||||
specifier: 6.0.0
|
||||
version: 6.0.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))
|
||||
@@ -3606,12 +3603,6 @@ packages:
|
||||
resolution: {integrity: sha512-hBcWIOppZV14bi+eAmCZj8Elj8hVSUZJTpf1lgGBhVD85pervzQ1poM/qYfFUlPraYSZYP+ASg6To5BwYmUSGQ==}
|
||||
engines: {node: '>=16'}
|
||||
|
||||
'@vitejs/plugin-basic-ssl@2.2.0':
|
||||
resolution: {integrity: sha512-nmyQ1HGRkfUxjsv3jw0+hMhEdZdrtkvMTdkzRUaRWfiO6PCWw2V2Pz3gldCq96Tn9S8htcgdTxw/gmbLLEbfYw==}
|
||||
engines: {node: ^18.0.0 || ^20.0.0 || >=22.0.0}
|
||||
peerDependencies:
|
||||
vite: ^6.0.0 || ^7.0.0 || ^8.0.0
|
||||
|
||||
'@vitejs/plugin-react@5.2.0':
|
||||
resolution: {integrity: sha512-YmKkfhOAi3wsB1PhJq5Scj3GXMn3WvtQ/JC0xoopuHoXSdmtdStOpFrYaT1kie2YgFBcIe64ROzMYRjCrYOdYw==}
|
||||
engines: {node: ^20.19.0 || >=22.12.0}
|
||||
@@ -11039,10 +11030,6 @@ snapshots:
|
||||
'@resvg/resvg-wasm': 2.4.0
|
||||
satori: 0.16.0
|
||||
|
||||
'@vitejs/plugin-basic-ssl@2.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
|
||||
dependencies:
|
||||
vite: '@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2)'
|
||||
|
||||
'@vitejs/plugin-react@5.2.0(@voidzero-dev/vite-plus-core@0.1.11(@types/node@25.5.0)(esbuild@0.27.2)(jiti@1.21.7)(sass@1.98.0)(terser@5.46.0)(tsx@4.21.0)(typescript@5.9.3)(yaml@2.8.2))':
|
||||
dependencies:
|
||||
'@babel/core': 7.29.0
|
||||
|
||||
@@ -1,10 +1,8 @@
|
||||
import { createSecureServer } from 'node:http2'
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
import { serve } from '@hono/node-server'
|
||||
import { getCertificate } from '@vitejs/plugin-basic-ssl'
|
||||
import { loadEnv } from 'vite'
|
||||
import { createDevProxyApp, resolveDevProxyTargets, shouldUseHttpsForDevProxy } from '../plugins/dev-proxy/server'
|
||||
import { createDevProxyApp, resolveDevProxyTargets } from '../plugins/dev-proxy/server'
|
||||
|
||||
const projectRoot = path.resolve(path.dirname(fileURLToPath(import.meta.url)), '..')
|
||||
const mode = process.env.MODE || process.env.NODE_ENV || 'development'
|
||||
@@ -13,33 +11,11 @@ const env = loadEnv(mode, projectRoot, '')
|
||||
const host = env.HONO_PROXY_HOST || '127.0.0.1'
|
||||
const port = Number(env.HONO_PROXY_PORT || 5001)
|
||||
const app = createDevProxyApp(resolveDevProxyTargets(env))
|
||||
const useHttps = shouldUseHttpsForDevProxy(env)
|
||||
|
||||
if (useHttps) {
|
||||
const certificate = await getCertificate(
|
||||
path.join(projectRoot, 'node_modules/.vite/basic-ssl'),
|
||||
'localhost',
|
||||
Array.from(new Set(['localhost', '127.0.0.1', host])),
|
||||
)
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
})
|
||||
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
createServer: createSecureServer,
|
||||
serverOptions: {
|
||||
allowHTTP1: true,
|
||||
cert: certificate,
|
||||
key: certificate,
|
||||
},
|
||||
})
|
||||
}
|
||||
else {
|
||||
serve({
|
||||
fetch: app.fetch,
|
||||
hostname: host,
|
||||
port,
|
||||
})
|
||||
}
|
||||
|
||||
console.log(`[dev-hono-proxy] listening on ${useHttps ? 'https' : 'http'}://${host}:${port}`)
|
||||
console.log(`[dev-hono-proxy] listening on http://${host}:${port}`)
|
||||
|
||||
@@ -1,12 +1,9 @@
|
||||
import path from 'node:path'
|
||||
import { fileURLToPath } from 'node:url'
|
||||
import basicSsl from '@vitejs/plugin-basic-ssl'
|
||||
import react from '@vitejs/plugin-react'
|
||||
import vinext from 'vinext'
|
||||
import { loadEnv } from 'vite'
|
||||
import Inspect from 'vite-plugin-inspect'
|
||||
import { defineConfig } from 'vite-plus'
|
||||
import { shouldUseHttpsForDevProxy } from './plugins/dev-proxy/protocol'
|
||||
import { createCodeInspectorPlugin, createForceInspectorClientInjectionPlugin } from './plugins/vite/code-inspector'
|
||||
import { customI18nHmrPlugin } from './plugins/vite/custom-i18n-hmr'
|
||||
import { nextStaticImageTestPlugin } from './plugins/vite/next-static-image-test'
|
||||
@@ -24,8 +21,6 @@ export default defineConfig(({ mode }) => {
|
||||
const isTest = mode === 'test'
|
||||
const isStorybook = process.env.STORYBOOK === 'true'
|
||||
|| process.argv.some(arg => arg.toLowerCase().includes('storybook'))
|
||||
const env = loadEnv(mode, projectRoot, '')
|
||||
const useHttpsForDevServer = shouldUseHttpsForDevProxy(env)
|
||||
const isAppComponentsCoverage = coverageScope === 'app-components'
|
||||
const excludedComponentCoverageFiles = isAppComponentsCoverage
|
||||
? collectComponentCoverageExcludedFiles(path.join(projectRoot, 'app/components'), { pathPrefix: 'app/components' })
|
||||
@@ -62,7 +57,6 @@ export default defineConfig(({ mode }) => {
|
||||
react(),
|
||||
vinext({ react: false }),
|
||||
customI18nHmrPlugin({ injectTarget: browserInitializerInjectTarget }),
|
||||
...(useHttpsForDevServer ? [basicSsl()] : []),
|
||||
// reactGrabOpenFilePlugin({
|
||||
// injectTarget: browserInitializerInjectTarget,
|
||||
// projectRoot,
|
||||
|
||||
Reference in New Issue
Block a user