Compare commits

..

16 Commits

Author SHA1 Message Date
Yanli 盐粒
710ac3b90a fix(api): preserve typed loop array constants 2026-03-18 22:05:16 +08:00
Yanli 盐粒
8548498f25 fix(api): restore advanced chat refresh_model contract 2026-03-18 19:41:00 +08:00
Yanli 盐粒
d014f0b91a fix(api): address typing review feedback 2026-03-18 19:16:48 +08:00
Yanli 盐粒
cc5aac268a fix(api): support tool typed dicts on py311 2026-03-18 18:59:49 +08:00
Yanli 盐粒
4c1d27431b fix(api): restore workflow node compatibility 2026-03-18 18:43:35 +08:00
Yanli 盐粒
9a86f280eb fix(api): avoid recursive loop type adapters 2026-03-18 18:20:43 +08:00
Yanli 盐粒
c5920fb28a Merge remote-tracking branch 'origin/main' into yanli/phase3-code-scope 2026-03-18 17:52:03 +08:00
Yanli 盐粒
2f81d5dfdf fix(api): restore typedict py311 compatibility 2026-03-17 20:30:18 +08:00
Yanli 盐粒
7639d8e43f fix(api): reuse advanced chat refresh session 2026-03-17 20:18:21 +08:00
Yanli 盐粒
1dce81c604 refactor(api): type single node workflow helpers 2026-03-17 20:16:14 +08:00
Yanli 盐粒
f874ca183e chore(api): remove phase 3 pyrefly excludes 2026-03-17 20:04:55 +08:00
Yanli 盐粒
0d805e624e Type phase 3 loop values 2026-03-17 19:39:54 +08:00
Yanli 盐粒
61196180b8 Type phase 3 tool inputs 2026-03-17 19:31:00 +08:00
Yanli 盐粒
79433b0091 Refine phase 3 typing boundaries 2026-03-17 19:13:12 +08:00
Yanli 盐粒
c4aeaa35d4 Type phase 3 schema contracts 2026-03-17 18:56:22 +08:00
Yanli 盐粒
9f0d79b8b0 Tighten phase 3 runtime typing 2026-03-17 18:49:14 +08:00
65 changed files with 954 additions and 945 deletions

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -47,7 +47,6 @@ from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -522,8 +521,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# release database connection, because the following new thread operations may take a long time
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
workflow = _refresh_model(session=session, model=workflow)
message = _refresh_model(session=session, model=message)
assert message is not None
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
@@ -690,11 +690,21 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
raise e
_T = TypeVar("_T", bound=Base)
@overload
def _refresh_model(*, session: Session | None = None, model: Workflow) -> Workflow: ...
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model
@overload
def _refresh_model(*, session: Session | None = None, model: Message) -> Message: ...
def _refresh_model(*, session: Session | None = None, model: Any) -> Any:
if session is not None:
detached_model = session.get(type(model), model.id)
assert detached_model is not None
return detached_model
with Session(bind=db.engine, expire_on_commit=False) as refresh_session:
detached_model = refresh_session.get(type(model), model.id)
assert detached_model is not None
return detached_model

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import Any, cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -56,8 +56,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
:param stream_response: stream response
@@ -87,8 +87,8 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
) -> Generator[dict | str, Any, None]:
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.
:param stream_response: stream response

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,7 +1,7 @@
import logging
from abc import ABC, abstractmethod
from collections.abc import Generator, Mapping
from typing import Any, Union
from collections.abc import Generator, Iterator, Mapping
from typing import Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.task_entities import AppBlockingResponse, AppStreamResponse
@@ -16,24 +16,26 @@ class AppGenerateResponseConverter(ABC):
@classmethod
def convert(
cls, response: Union[AppBlockingResponse, Generator[AppStreamResponse, Any, None]], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
cls, response: AppBlockingResponse | Iterator[AppStreamResponse], invoke_from: InvokeFrom
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], None, None]:
if invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.SERVICE_API}:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_full_response(response)
else:
stream_response = response
def _generate_full_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_full_response(response)
def _generate_full_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_full_response(stream_response)
return _generate_full_response()
else:
if isinstance(response, AppBlockingResponse):
return cls.convert_blocking_simple_response(response)
else:
stream_response = response
def _generate_simple_response() -> Generator[dict | str, Any, None]:
yield from cls.convert_stream_simple_response(response)
def _generate_simple_response() -> Generator[dict[str, Any] | str, None, None]:
yield from cls.convert_stream_simple_response(stream_response)
return _generate_simple_response()
@@ -50,14 +52,14 @@ class AppGenerateResponseConverter(ABC):
@classmethod
@abstractmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError
@classmethod
@abstractmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
raise NotImplementedError

View File

@@ -224,6 +224,7 @@ class BaseAppGenerator:
def _get_draft_var_saver_factory(invoke_from: InvokeFrom, account: Account | EndUser) -> DraftVariableSaverFactory:
if invoke_from == InvokeFrom.DEBUGGER:
assert isinstance(account, Account)
debug_account = account
def draft_var_saver_factory(
session: Session,
@@ -240,7 +241,7 @@ class BaseAppGenerator:
node_type=node_type,
node_execution_id=node_execution_id,
enclosing_node_id=enclosing_node_id,
user=account,
user=debug_account,
)
else:

View File

@@ -166,15 +166,19 @@ class ChatAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
assert conversation is not None
assert message is not None
generated_conversation_id = str(conversation.id)
generated_message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=generated_conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=generated_message_id,
)
# new thread with request context
@@ -184,8 +188,8 @@ class ChatAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
conversation_id=conversation.id,
message_id=message.id,
conversation_id=generated_conversation_id,
message_id=generated_message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -55,7 +55,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -86,7 +86,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -149,6 +149,8 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
# init queue manager
queue_manager = MessageBasedAppQueueManager(
@@ -312,15 +314,19 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity)
assert conversation is not None
assert message is not None
conversation_id = str(conversation.id)
message_id = str(message.id)
# init queue manager
queue_manager = MessageBasedAppQueueManager(
task_id=application_generate_entity.task_id,
user_id=application_generate_entity.user_id,
invoke_from=application_generate_entity.invoke_from,
conversation_id=conversation.id,
conversation_id=conversation_id,
app_mode=conversation.mode,
message_id=message.id,
message_id=message_id,
)
# new thread with request context
@@ -330,7 +336,7 @@ class CompletionAppGenerator(MessageBasedAppGenerator):
flask_app=current_app._get_current_object(), # type: ignore
application_generate_entity=application_generate_entity,
queue_manager=queue_manager,
message_id=message.id,
message_id=message_id,
)
worker_thread = threading.Thread(target=worker_with_context)

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -54,7 +54,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -84,7 +84,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,4 +1,4 @@
from collections.abc import Generator
from collections.abc import Generator, Iterator
from typing import cast
from core.app.apps.base_app_generate_response_converter import AppGenerateResponseConverter
@@ -36,7 +36,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_full_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream full response.
@@ -65,7 +65,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
@classmethod
def convert_stream_simple_response(
cls, stream_response: Generator[AppStreamResponse, None, None]
cls, stream_response: Iterator[AppStreamResponse]
) -> Generator[dict | str, None, None]:
"""
Convert stream simple response.

View File

@@ -1,13 +1,17 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
from typing import Protocol, TypeAlias
from pydantic import ValidationError
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
from core.app.entities.agent_strategy import AgentStrategyInfo
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom, build_dify_run_context
from core.app.entities.app_invoke_entities import (
InvokeFrom,
UserFrom,
build_dify_run_context,
)
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
@@ -36,7 +40,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.node_factory import DifyNodeFactory, get_default_root_node_id, resolve_workflow_node_class
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.entities import GraphInitParams
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.entities.graph_config import NodeConfigDict, NodeConfigDictAdapter
from dify_graph.entities.pause_reason import HumanInputRequired
from dify_graph.graph import Graph
from dify_graph.graph_engine.layers.base import GraphEngineLayer
@@ -75,6 +79,14 @@ from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
GraphConfigObject: TypeAlias = dict[str, object]
GraphConfigMapping: TypeAlias = Mapping[str, object]
class SingleNodeRunEntity(Protocol):
node_id: str
inputs: Mapping[str, object]
class WorkflowBasedAppRunner:
def __init__(
@@ -98,7 +110,7 @@ class WorkflowBasedAppRunner:
def _init_graph(
self,
graph_config: Mapping[str, Any],
graph_config: GraphConfigMapping,
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
@@ -154,8 +166,8 @@ class WorkflowBasedAppRunner:
def _prepare_single_node_execution(
self,
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
single_iteration_run: SingleNodeRunEntity | None = None,
single_loop_run: SingleNodeRunEntity | None = None,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@@ -208,11 +220,88 @@ class WorkflowBasedAppRunner:
# This ensures all nodes in the graph reference the same GraphRuntimeState instance
return graph, variable_pool, graph_runtime_state
@staticmethod
def _get_graph_items(graph_config: GraphConfigMapping) -> tuple[list[GraphConfigMapping], list[GraphConfigMapping]]:
nodes = graph_config.get("nodes")
edges = graph_config.get("edges")
if not isinstance(nodes, list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(edges, list):
raise ValueError("edges in workflow graph must be a list")
validated_nodes: list[GraphConfigMapping] = []
for node in nodes:
if not isinstance(node, Mapping):
raise ValueError("nodes in workflow graph must be mappings")
validated_nodes.append(node)
validated_edges: list[GraphConfigMapping] = []
for edge in edges:
if not isinstance(edge, Mapping):
raise ValueError("edges in workflow graph must be mappings")
validated_edges.append(edge)
return validated_nodes, validated_edges
@staticmethod
def _extract_start_node_id(node_config: GraphConfigMapping | None) -> str | None:
if node_config is None:
return None
node_data = node_config.get("data")
if not isinstance(node_data, Mapping):
return None
start_node_id = node_data.get("start_node_id")
return start_node_id if isinstance(start_node_id, str) else None
@classmethod
def _build_single_node_graph_config(
cls,
*,
graph_config: GraphConfigMapping,
node_id: str,
node_type_filter_key: str,
) -> tuple[GraphConfigObject, NodeConfigDict]:
node_configs, edge_configs = cls._get_graph_items(graph_config)
main_node_config = next((node for node in node_configs if node.get("id") == node_id), None)
start_node_id = cls._extract_start_node_id(main_node_config)
filtered_node_configs = [
dict(node)
for node in node_configs
if node.get("id") == node_id
or (isinstance(node_data := node.get("data"), Mapping) and node_data.get(node_type_filter_key) == node_id)
or (start_node_id and node.get("id") == start_node_id)
]
if not filtered_node_configs:
raise ValueError(f"node id {node_id} not found in workflow graph")
filtered_node_ids = {
str(node_id_value) for node in filtered_node_configs if isinstance((node_id_value := node.get("id")), str)
}
filtered_edge_configs = [
dict(edge)
for edge in edge_configs
if (edge.get("source") is None or edge.get("source") in filtered_node_ids)
and (edge.get("target") is None or edge.get("target") in filtered_node_ids)
]
target_node_config = next((node for node in filtered_node_configs if node.get("id") == node_id), None)
if target_node_config is None:
raise ValueError(f"node id {node_id} not found in workflow graph")
return (
{
"nodes": filtered_node_configs,
"edges": filtered_edge_configs,
},
NodeConfigDictAdapter.validate_python(target_node_config),
)
def _get_graph_and_variable_pool_for_single_node_run(
self,
workflow: Workflow,
node_id: str,
user_inputs: dict[str, Any],
user_inputs: Mapping[str, object],
graph_runtime_state: GraphRuntimeState,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
@@ -236,41 +325,14 @@ class WorkflowBasedAppRunner:
if not graph_config:
raise ValueError("workflow graph not found")
graph_config = cast(dict[str, Any], graph_config)
if "nodes" not in graph_config or "edges" not in graph_config:
raise ValueError("nodes or edges not found in workflow graph")
if not isinstance(graph_config.get("nodes"), list):
raise ValueError("nodes in workflow graph must be a list")
if not isinstance(graph_config.get("edges"), list):
raise ValueError("edges in workflow graph must be a list")
# filter nodes only in the specified node type (iteration or loop)
main_node_config = next((n for n in graph_config.get("nodes", []) if n.get("id") == node_id), None)
start_node_id = main_node_config.get("data", {}).get("start_node_id") if main_node_config else None
node_configs = [
node
for node in graph_config.get("nodes", [])
if node.get("id") == node_id
or node.get("data", {}).get(node_type_filter_key, "") == node_id
or (start_node_id and node.get("id") == start_node_id)
]
graph_config["nodes"] = node_configs
node_ids = [node.get("id") for node in node_configs]
# filter edges only in the specified node type
edge_configs = [
edge
for edge in graph_config.get("edges", [])
if (edge.get("source") is None or edge.get("source") in node_ids)
and (edge.get("target") is None or edge.get("target") in node_ids)
]
graph_config["edges"] = edge_configs
graph_config, target_node_config = self._build_single_node_graph_config(
graph_config=graph_config,
node_id=node_id,
node_type_filter_key=node_type_filter_key,
)
# Create required parameters for Graph.init
graph_init_params = GraphInitParams(
@@ -299,18 +361,6 @@ class WorkflowBasedAppRunner:
if not graph:
raise ValueError("graph not found in workflow")
# fetch node config from node id
target_node_config = None
for node in node_configs:
if node.get("id") == node_id:
target_node_config = node
break
if not target_node_config:
raise ValueError(f"{node_type_label} node id not found in workflow graph")
target_node_config = NodeConfigDictAdapter.validate_python(target_node_config)
# Get node class
node_type = target_node_config["data"].type
node_version = str(target_node_config["data"].version)

View File

@@ -213,7 +213,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -223,7 +223,7 @@ class AdvancedChatAppGenerateEntity(ConversationAppGenerateEntity):
"""
node_id: str
inputs: Mapping
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None
@@ -243,7 +243,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_iteration_run: SingleIterationRunEntity | None = None
@@ -253,7 +253,7 @@ class WorkflowAppGenerateEntity(AppGenerateEntity):
"""
node_id: str
inputs: dict
inputs: Mapping[str, object]
single_loop_run: SingleLoopRunEntity | None = None

View File

@@ -88,7 +88,7 @@ class LindormVectorStore(BaseVector):
batch_size: int = 64,
timeout: int = 60,
**kwargs,
) -> list[str]:
):
logger.info("Total documents to add: %s", len(documents))
uuids = self._get_uuids(documents)
@@ -130,11 +130,8 @@ class LindormVectorStore(BaseVector):
Field.METADATA_KEY: documents[i].metadata,
}
if self._using_ugc:
routing = self._routing
if routing is None:
raise ValueError("UGC index should init vector with valid 'routing_value' parameter value")
action_header["index"]["routing"] = routing
action_values[ROUTING_FIELD] = routing
action_header["index"]["routing"] = self._routing
action_values[ROUTING_FIELD] = self._routing
actions.append(action_header)
actions.append(action_values)
@@ -150,8 +147,6 @@ class LindormVectorStore(BaseVector):
logger.exception("Failed to process batch %s", batch_num + 1)
raise
return uuids
def get_ids_by_metadata_field(self, key: str, value: str):
query: dict[str, Any] = {
"query": {"bool": {"must": [{"term": {f"{Field.METADATA_KEY}.{key}.keyword": value}}]}}
@@ -383,21 +378,18 @@ class LindormVectorStoreFactory(AbstractVectorFactory):
raise ValueError("LINDORM_USING_UGC is not set")
routing_value = None
if dataset.index_struct:
index_struct_dict = dataset.index_struct_dict
if index_struct_dict is None:
raise ValueError("dataset.index_struct_dict is missing")
# if an existed record's index_struct_dict doesn't contain using_ugc field,
# it actually stores in the normal index format
stored_in_ugc: bool = index_struct_dict.get("using_ugc", False)
stored_in_ugc: bool = dataset.index_struct_dict.get("using_ugc", False)
using_ugc = stored_in_ugc
if stored_in_ugc:
dimension = index_struct_dict["dimension"]
index_type = index_struct_dict["index_type"]
distance_type = index_struct_dict["distance_type"]
routing_value = index_struct_dict["vector_store"]["class_prefix"]
dimension = dataset.index_struct_dict["dimension"]
index_type = dataset.index_struct_dict["index_type"]
distance_type = dataset.index_struct_dict["distance_type"]
routing_value = dataset.index_struct_dict["vector_store"]["class_prefix"]
index_name = f"{UGC_INDEX_PREFIX}_{dimension}_{index_type}_{distance_type}".lower()
else:
index_name = index_struct_dict["vector_store"]["class_prefix"].lower()
index_name = dataset.index_struct_dict["vector_store"]["class_prefix"].lower()
else:
embedding_vector = embeddings.embed_query("hello word")
dimension = len(embedding_vector)

View File

@@ -7,9 +7,7 @@ from core.rag.models.document import Document
class BaseVector(ABC):
_collection_name: str
def __init__(self, collection_name: str) -> None:
def __init__(self, collection_name: str):
self._collection_name = collection_name
@abstractmethod
@@ -32,7 +30,7 @@ class BaseVector(ABC):
def delete_by_ids(self, ids: list[str]) -> None:
raise NotImplementedError
def get_ids_by_metadata_field(self, key: str, value: str) -> list[str] | None:
def get_ids_by_metadata_field(self, key: str, value: str):
raise NotImplementedError
@abstractmethod
@@ -65,5 +63,5 @@ class BaseVector(ABC):
return [text.metadata["doc_id"] for text in texts if text.metadata and "doc_id" in text.metadata]
@property
def collection_name(self) -> str:
def collection_name(self):
return self._collection_name

View File

@@ -2,8 +2,7 @@ import base64
import logging
import time
from abc import ABC, abstractmethod
from collections.abc import Sequence
from typing import Any, TypedDict
from typing import Any
from sqlalchemy import select
@@ -14,7 +13,7 @@ from core.rag.datasource.vdb.vector_type import VectorType
from core.rag.embedding.cached_embedding import CacheEmbedding
from core.rag.embedding.embedding_base import Embeddings
from core.rag.index_processor.constant.doc_type import DocType
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from core.rag.models.document import Document
from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from extensions.ext_redis import redis_client
@@ -25,34 +24,19 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
class VectorStoreIndexConfig(TypedDict):
class_prefix: str
class VectorIndexStructDict(TypedDict):
type: VectorType
vector_store: VectorStoreIndexConfig
VectorDocumentInput = Document | ChildDocument | AttachmentDocument
class AbstractVectorFactory(ABC):
@abstractmethod
def init_vector(self, dataset: Dataset, attributes: list[str], embeddings: Embeddings) -> BaseVector:
def init_vector(self, dataset: Dataset, attributes: list, embeddings: Embeddings) -> BaseVector:
raise NotImplementedError
@staticmethod
def gen_index_struct_dict(vector_type: VectorType, collection_name: str) -> VectorIndexStructDict:
index_struct_dict: VectorIndexStructDict = {
"type": vector_type,
"vector_store": {"class_prefix": collection_name},
}
def gen_index_struct_dict(vector_type: VectorType, collection_name: str):
index_struct_dict = {"type": vector_type, "vector_store": {"class_prefix": collection_name}}
return index_struct_dict
class Vector:
def __init__(self, dataset: Dataset, attributes: list[str] | None = None) -> None:
def __init__(self, dataset: Dataset, attributes: list | None = None):
if attributes is None:
attributes = ["doc_id", "dataset_id", "document_id", "doc_hash", "doc_type"]
self._dataset = dataset
@@ -214,12 +198,12 @@ class Vector:
case _:
raise ValueError(f"Vector store {vector_type} is not supported.")
def create(self, texts: Sequence[Document | ChildDocument] | None = None, **kwargs: Any) -> None:
def create(self, texts: list | None = None, **kwargs):
if texts:
start = time.time()
logger.info("start embedding %s texts %s", len(texts), start)
batch_size = 1000
total_batches = (len(texts) + batch_size - 1) // batch_size
total_batches = len(texts) + batch_size - 1
for i in range(0, len(texts), batch_size):
batch = texts[i : i + batch_size]
batch_start = time.time()
@@ -228,33 +212,29 @@ class Vector:
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(
texts=self._normalize_documents(batch), embeddings=batch_embeddings, **kwargs
)
self._vector_processor.create(texts=batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s texts took %s s", len(texts), time.time() - start)
def create_multimodal(self, file_documents: list[AttachmentDocument] | None = None, **kwargs: Any) -> None:
def create_multimodal(self, file_documents: list | None = None, **kwargs):
if file_documents:
start = time.time()
logger.info("start embedding %s files %s", len(file_documents), start)
batch_size = 1000
total_batches = (len(file_documents) + batch_size - 1) // batch_size
total_batches = len(file_documents) + batch_size - 1
for i in range(0, len(file_documents), batch_size):
batch = file_documents[i : i + batch_size]
batch_start = time.time()
logger.info("Processing batch %s/%s (%s files)", i // batch_size + 1, total_batches, len(batch))
# Batch query all upload files to avoid N+1 queries
attachment_ids = [doc.metadata["doc_id"] for doc in batch if doc.metadata is not None]
attachment_ids = [doc.metadata["doc_id"] for doc in batch]
stmt = select(UploadFile).where(UploadFile.id.in_(attachment_ids))
upload_files = db.session.scalars(stmt).all()
upload_file_map = {str(f.id): f for f in upload_files}
file_base64_list: list[dict[str, str]] = []
real_batch: list[AttachmentDocument] = []
file_base64_list = []
real_batch = []
for document in batch:
if document.metadata is None:
continue
attachment_id = document.metadata["doc_id"]
doc_type = document.metadata["doc_type"]
upload_file = upload_file_map.get(attachment_id)
@@ -269,20 +249,14 @@ class Vector:
}
)
real_batch.append(document)
if not real_batch:
continue
batch_embeddings = self._embeddings.embed_multimodal_documents(file_base64_list)
logger.info(
"Embedding batch %s/%s took %s s", i // batch_size + 1, total_batches, time.time() - batch_start
)
self._vector_processor.create(
texts=self._normalize_documents(real_batch),
embeddings=batch_embeddings,
**kwargs,
)
self._vector_processor.create(texts=real_batch, embeddings=batch_embeddings, **kwargs)
logger.info("Embedding %s files took %s s", len(file_documents), time.time() - start)
def add_texts(self, documents: list[Document], **kwargs: Any) -> None:
def add_texts(self, documents: list[Document], **kwargs):
if kwargs.get("duplicate_check", False):
documents = self._filter_duplicate_texts(documents)
@@ -292,10 +266,10 @@ class Vector:
def text_exists(self, id: str) -> bool:
return self._vector_processor.text_exists(id)
def delete_by_ids(self, ids: list[str]) -> None:
def delete_by_ids(self, ids: list[str]):
self._vector_processor.delete_by_ids(ids)
def delete_by_metadata_field(self, key: str, value: str) -> None:
def delete_by_metadata_field(self, key: str, value: str):
self._vector_processor.delete_by_metadata_field(key, value)
def search_by_vector(self, query: str, **kwargs: Any) -> list[Document]:
@@ -321,7 +295,7 @@ class Vector:
def search_by_full_text(self, query: str, **kwargs: Any) -> list[Document]:
return self._vector_processor.search_by_full_text(query, **kwargs)
def delete(self) -> None:
def delete(self):
self._vector_processor.delete()
# delete collection redis cache
if self._vector_processor.collection_name:
@@ -351,26 +325,7 @@ class Vector:
return texts
@staticmethod
def _normalize_documents(documents: Sequence[VectorDocumentInput]) -> list[Document]:
normalized_documents: list[Document] = []
for document in documents:
if isinstance(document, Document):
normalized_documents.append(document)
continue
normalized_documents.append(
Document(
page_content=document.page_content,
vector=document.vector,
metadata=document.metadata,
provider=(document.provider or "dify") if isinstance(document, AttachmentDocument) else "dify",
)
)
return normalized_documents
def __getattr__(self, name: str) -> Any:
def __getattr__(self, name):
if self._vector_processor is not None:
method = getattr(self._vector_processor, name)
if callable(method):

View File

@@ -1,11 +1,7 @@
from typing import Literal
from pydantic import BaseModel, ConfigDict
from pydantic import BaseModel, ConfigDict, field_validator
from core.rag.extractor.entity.datasource_type import DatasourceType
from models.dataset import Document
from models.model import UploadFile
from services.auth.auth_type import AuthType
class NotionInfo(BaseModel):
@@ -16,7 +12,7 @@ class NotionInfo(BaseModel):
credential_id: str | None = None
notion_workspace_id: str | None = ""
notion_obj_id: str
notion_page_type: Literal["database", "page"]
notion_page_type: str
document: Document | None = None
tenant_id: str
model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -29,27 +25,20 @@ class WebsiteInfo(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
provider: AuthType
provider: str
job_id: str
url: str
mode: Literal["crawl", "crawl_return_urls", "scrape"]
mode: str
tenant_id: str
only_main_content: bool = False
@field_validator("mode", mode="before")
@classmethod
def _normalize_legacy_mode(cls, value: str) -> str:
if value == "single":
return "crawl"
return value
class ExtractSetting(BaseModel):
"""
Model class for provider response.
"""
datasource_type: DatasourceType
datasource_type: str
upload_file: UploadFile | None = None
notion_info: NotionInfo | None = None
website_info: WebsiteInfo | None = None

View File

@@ -1,8 +1,7 @@
import os
import re
import tempfile
from pathlib import Path
from typing import TypeAlias
from typing import Union
from urllib.parse import unquote
from configs import dify_config
@@ -32,27 +31,19 @@ from core.rag.extractor.word_extractor import WordExtractor
from core.rag.models.document import Document
from extensions.ext_storage import storage
from models.model import UploadFile
from services.auth.auth_type import AuthType
SUPPORT_URL_CONTENT_TYPES = ["application/pdf", "text/plain", "application/json"]
USER_AGENT = (
"Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36 (KHTML, like Gecko) Chrome/91.0.4472.124"
" Safari/537.36"
)
ExtractProcessorOutput: TypeAlias = list[Document] | str
class ExtractProcessor:
@staticmethod
def _build_temp_file_path(temp_dir: str, suffix: str) -> str:
file_descriptor, file_path = tempfile.mkstemp(dir=temp_dir, suffix=suffix)
os.close(file_descriptor)
return file_path
@classmethod
def load_from_upload_file(
cls, upload_file: UploadFile, return_text: bool = False, is_automatic: bool = False
) -> ExtractProcessorOutput:
) -> Union[list[Document], str]:
extract_setting = ExtractSetting(
datasource_type=DatasourceType.FILE, upload_file=upload_file, document_model="text_model"
)
@@ -63,7 +54,7 @@ class ExtractProcessor:
return cls.extract(extract_setting, is_automatic)
@classmethod
def load_from_url(cls, url: str, return_text: bool = False) -> ExtractProcessorOutput:
def load_from_url(cls, url: str, return_text: bool = False) -> Union[list[Document], str]:
response = ssrf_proxy.get(url, headers={"User-Agent": USER_AGENT})
with tempfile.TemporaryDirectory() as temp_dir:
@@ -74,16 +65,17 @@ class ExtractProcessor:
suffix = "." + response.headers.get("Content-Type").split("/")[-1]
else:
content_disposition = response.headers.get("Content-Disposition")
if content_disposition:
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
file_path = cls._build_temp_file_path(temp_dir, suffix)
filename_match = re.search(r'filename="([^"]+)"', content_disposition)
if filename_match:
filename = unquote(filename_match.group(1))
match = re.search(r"\.(\w+)$", filename)
if match:
suffix = "." + match.group(1)
else:
suffix = ""
# https://stackoverflow.com/questions/26541416/generate-temporary-file-names-without-creating-actual-file-in-python#comment90414256_26541521
# Generate a temporary filename under the created temp_dir and ensure the directory exists
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
Path(file_path).write_bytes(response.content)
extract_setting = ExtractSetting(datasource_type=DatasourceType.FILE, document_model="text_model")
if return_text:
@@ -102,13 +94,13 @@ class ExtractProcessor:
cls, extract_setting: ExtractSetting, is_automatic: bool = False, file_path: str | None = None
) -> list[Document]:
if extract_setting.datasource_type == DatasourceType.FILE:
upload_file = extract_setting.upload_file
with tempfile.TemporaryDirectory() as temp_dir:
if not file_path:
assert extract_setting.upload_file is not None, "upload_file is required"
upload_file = extract_setting.upload_file
upload_file: UploadFile = extract_setting.upload_file
suffix = Path(upload_file.key).suffix
file_path = cls._build_temp_file_path(temp_dir, suffix)
# FIXME mypy: Cannot determine type of 'tempfile._get_candidate_names' better not use it here
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
input_file = Path(file_path)
file_extension = input_file.suffix.lower()
@@ -121,11 +113,7 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = (
UnstructuredMarkdownExtractor(file_path, unstructured_api_url, unstructured_api_key)
@@ -135,11 +123,7 @@ class ExtractProcessor:
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".doc":
extractor = UnstructuredWordExtractor(file_path, unstructured_api_url, unstructured_api_key)
elif file_extension == ".csv":
@@ -165,21 +149,13 @@ class ExtractProcessor:
if file_extension in {".xlsx", ".xls"}:
extractor = ExcelExtractor(file_path)
elif file_extension == ".pdf":
extractor = PdfExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
extractor = PdfExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension in {".md", ".markdown", ".mdx"}:
extractor = MarkdownExtractor(file_path, autodetect_encoding=True)
elif file_extension in {".htm", ".html"}:
extractor = HtmlExtractor(file_path)
elif file_extension == ".docx":
extractor = WordExtractor(
file_path,
upload_file.tenant_id if upload_file else None,
upload_file.created_by if upload_file else None,
)
extractor = WordExtractor(file_path, upload_file.tenant_id, upload_file.created_by)
elif file_extension == ".csv":
extractor = CSVExtractor(file_path, autodetect_encoding=True)
elif file_extension == ".epub":
@@ -201,7 +177,7 @@ class ExtractProcessor:
return extractor.extract()
elif extract_setting.datasource_type == DatasourceType.WEBSITE:
assert extract_setting.website_info is not None, "website_info is required"
if extract_setting.website_info.provider == AuthType.FIRECRAWL:
if extract_setting.website_info.provider == "firecrawl":
extractor = FirecrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@@ -210,7 +186,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == AuthType.WATERCRAWL:
elif extract_setting.website_info.provider == "watercrawl":
extractor = WaterCrawlWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,
@@ -219,7 +195,7 @@ class ExtractProcessor:
only_main_content=extract_setting.website_info.only_main_content,
)
return extractor.extract()
elif extract_setting.website_info.provider == AuthType.JINA:
elif extract_setting.website_info.provider == "jinareader":
extractor = JinaReaderWebExtractor(
url=extract_setting.website_info.url,
job_id=extract_setting.website_info.job_id,

View File

@@ -2,12 +2,10 @@
from abc import ABC, abstractmethod
from core.rag.models.document import Document
class BaseExtractor(ABC):
"""Interface for extract files."""
@abstractmethod
def extract(self) -> list[Document]:
def extract(self):
raise NotImplementedError

View File

@@ -30,7 +30,7 @@ def detect_file_encodings(file_path: str, timeout: int = 5, sample_size: int = 1
For large files, reading only a sample is sufficient and prevents timeout.
"""
def read_and_detect(filename: str) -> list[FileEncoding]:
def read_and_detect(filename: str):
rst = charset_normalizer.from_path(filename)
best = rst.best()
if best is None:

View File

@@ -28,8 +28,8 @@ class PdfExtractor(BaseExtractor):
Args:
file_path: Path to the PDF file.
tenant_id: Workspace ID used for extracted image persistence when available.
user_id: ID of the user performing the extraction when available.
tenant_id: Workspace ID.
user_id: ID of the user performing the extraction.
file_cache_key: Optional cache key for the extracted text.
"""
@@ -47,13 +47,7 @@ class PdfExtractor(BaseExtractor):
]
MAX_MAGIC_LEN = max(len(m) for m, _, _ in IMAGE_FORMATS)
def __init__(
self,
file_path: str,
tenant_id: str | None,
user_id: str | None,
file_cache_key: str | None = None,
):
def __init__(self, file_path: str, tenant_id: str, user_id: str, file_cache_key: str | None = None):
"""Initialize PdfExtractor."""
self._file_path = file_path
self._tenant_id = tenant_id
@@ -122,9 +116,6 @@ class PdfExtractor(BaseExtractor):
upload_files = []
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
if self._tenant_id is None or self._user_id is None:
return ""
try:
image_objects = page.get_objects(filter=(pdfium_c.FPDF_PAGEOBJ_IMAGE,))
for obj in image_objects:

View File

@@ -9,8 +9,6 @@ import os
import re
import tempfile
import uuid
from collections.abc import Iterable
from typing import cast
from urllib.parse import urlparse
from docx import Document as DocxDocument
@@ -37,7 +35,7 @@ class WordExtractor(BaseExtractor):
file_path: Path to the file to load.
"""
def __init__(self, file_path: str, tenant_id: str | None, user_id: str | None):
def __init__(self, file_path: str, tenant_id: str, user_id: str):
"""Initialize with file path."""
self.file_path = file_path
self.tenant_id = tenant_id
@@ -88,12 +86,9 @@ class WordExtractor(BaseExtractor):
def _extract_images_from_docx(self, doc):
image_count = 0
image_map: dict[object, str] = {}
image_map = {}
base_url = dify_config.INTERNAL_FILES_URL or dify_config.FILES_URL
if self.tenant_id is None or self.user_id is None:
return image_map
for r_id, rel in doc.part.rels.items():
if "image" in rel.target_ref:
image_count += 1
@@ -269,7 +264,7 @@ class WordExtractor(BaseExtractor):
def parse_docx(self, docx_path):
doc = DocxDocument(docx_path)
content: list[str] = []
content = []
image_map = self._extract_images_from_docx(doc)
@@ -367,10 +362,10 @@ class WordExtractor(BaseExtractor):
if link_text:
target_buffer.append(link_text)
paragraph_content: list[str] = []
paragraph_content = []
# State for legacy HYPERLINK fields
hyperlink_field_url: str | None = None
hyperlink_field_text_parts: list[str] = []
hyperlink_field_url = None
hyperlink_field_text_parts: list = []
is_collecting_field_text = False
# Iterate through paragraph elements in document order
for child in paragraph._element:
@@ -427,8 +422,7 @@ class WordExtractor(BaseExtractor):
paragraphs = doc.paragraphs.copy()
tables = doc.tables.copy()
body_elements = cast(Iterable[object], getattr(doc.element, "body", []))
for element in body_elements:
for element in doc.element.body:
if hasattr(element, "tag"):
if isinstance(element.tag, str) and element.tag.endswith("p"): # paragraph
para = paragraphs.pop(0)

View File

@@ -3,7 +3,7 @@ import datetime
import logging
import time
from collections.abc import Mapping
from typing import Any, Literal, TypedDict
from typing import Any
from flask import current_app
from sqlalchemy import delete, func, select
@@ -20,16 +20,6 @@ from .processor.paragraph_index_processor import ParagraphIndexProcessor
logger = logging.getLogger(__name__)
class IndexAndCleanResult(TypedDict):
dataset_id: str
dataset_name: str
batch: str
document_id: str
document_name: str
created_at: float
display_status: Literal["completed"]
class IndexProcessor:
def format_preview(self, chunk_structure: str, chunks: Any) -> Preview:
index_processor = IndexProcessorFactory(chunk_structure).init_index_processor()
@@ -61,9 +51,9 @@ class IndexProcessor:
document_id: str,
original_document_id: str,
chunks: Mapping[str, Any],
batch: str,
batch: Any,
summary_index_setting: SummaryIndexSettingDict | None = None,
) -> IndexAndCleanResult:
):
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if not document:

View File

@@ -122,7 +122,6 @@ class BaseIndexProcessor(ABC):
"""
Get the NodeParser object according to the processing rule.
"""
character_splitter: TextSplitter
if processing_rule_mode in ["custom", "hierarchical"]:
# The user-defined segmentation rule
max_segmentation_tokens_length = dify_config.INDEXING_MAX_SEGMENTATION_TOKENS_LENGTH
@@ -148,7 +147,7 @@ class BaseIndexProcessor(ABC):
embedding_model_instance=embedding_model_instance,
)
return character_splitter
return character_splitter # type: ignore
def _get_content_files(self, document: Document, current_user: Account | None = None) -> list[AttachmentDocument]:
"""
@@ -159,7 +158,7 @@ class BaseIndexProcessor(ABC):
images = self._extract_markdown_images(text)
if not images:
return multi_model_documents
upload_file_id_list: list[str] = []
upload_file_id_list = []
for image in images:
# Collect all upload_file_ids including duplicates to preserve occurrence count

View File

@@ -10,7 +10,7 @@ from core.rag.index_processor.processor.qa_index_processor import QAIndexProcess
class IndexProcessorFactory:
"""IndexProcessorInit."""
def __init__(self, index_type: str | None) -> None:
def __init__(self, index_type: str | None):
self._index_type = index_type
def init_index_processor(self) -> BaseIndexProcessor:
@@ -19,12 +19,11 @@ class IndexProcessorFactory:
if not self._index_type:
raise ValueError("Index type must be specified.")
match self._index_type:
case IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
case IndexStructureType.QA_INDEX:
return QAIndexProcessor()
case IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
case _:
raise ValueError(f"Index type {self._index_type} is not supported.")
if self._index_type == IndexStructureType.PARAGRAPH_INDEX:
return ParagraphIndexProcessor()
elif self._index_type == IndexStructureType.QA_INDEX:
return QAIndexProcessor()
elif self._index_type == IndexStructureType.PARENT_CHILD_INDEX:
return ParentChildIndexProcessor()
else:
raise ValueError(f"Index type {self._index_type} is not supported.")

View File

@@ -30,7 +30,7 @@ class EnhanceRecursiveCharacterTextSplitter(RecursiveCharacterTextSplitter):
allowed_special: Union[Literal["all"], Set[str]] = set(), # noqa: UP037
disallowed_special: Union[Literal["all"], Collection[str]] = "all", # noqa: UP037
**kwargs: Any,
) -> TS:
):
def _token_encoder(texts: list[str]) -> list[int]:
if not texts:
return []

View File

@@ -1045,9 +1045,10 @@ class ToolManager:
continue
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
parameter_value = variable.value
elif tool_input.type == "constant":
parameter_value = tool_input.value

View File

@@ -1,13 +1,24 @@
from enum import IntEnum, StrEnum, auto
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel
from enum import IntEnum, StrEnum, auto
from typing import Literal, TypeAlias
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.tools.entities.tool_entities import ToolSelector
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
AgentInputConstantValue: TypeAlias = (
list[ToolSelector] | str | int | float | bool | dict[str, object] | list[object] | None
)
VariableSelector: TypeAlias = list[str]
_AGENT_INPUT_VALUE_ADAPTER: TypeAdapter[AgentInputConstantValue] = TypeAdapter(AgentInputConstantValue)
_AGENT_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class AgentNodeData(BaseNodeData):
type: NodeType = BuiltinNodeTypes.AGENT
@@ -21,8 +32,20 @@ class AgentNodeData(BaseNodeData):
tool_node_version: str | None = None
class AgentInput(BaseModel):
value: Union[list[str], list[ToolSelector], Any]
type: Literal["mixed", "variable", "constant"]
value: AgentInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> AgentInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "variable":
return _AGENT_VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type in {"mixed", "constant"}:
return _AGENT_INPUT_VALUE_ADAPTER.validate_python(value)
raise ValueError(f"Unknown agent input type: {input_type}")
agent_parameters: dict[str, AgentInput]

View File

@@ -1,16 +1,17 @@
from __future__ import annotations
import json
from collections.abc import Sequence
from typing import Any, cast
from collections.abc import Mapping, Sequence
from typing import TypeAlias
from packaging.version import Version
from pydantic import ValidationError
from pydantic import TypeAdapter, ValidationError
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.app.entities.app_invoke_entities import InvokeFrom
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.plugin.entities.request import InvokeCredentials
@@ -28,6 +29,14 @@ from .entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGen
from .exceptions import AgentInputTypeError, AgentVariableNotFoundError
from .strategy_protocols import ResolvedAgentStrategy
JsonObject: TypeAlias = dict[str, object]
JsonObjectList: TypeAlias = list[JsonObject]
VariableSelector: TypeAlias = list[str]
_JSON_OBJECT_ADAPTER = TypeAdapter(JsonObject)
_JSON_OBJECT_LIST_ADAPTER = TypeAdapter(JsonObjectList)
_VARIABLE_SELECTOR_ADAPTER = TypeAdapter(VariableSelector)
class AgentRuntimeSupport:
def build_parameters(
@@ -39,12 +48,12 @@ class AgentRuntimeSupport:
strategy: ResolvedAgentStrategy,
tenant_id: str,
app_id: str,
invoke_from: Any,
invoke_from: InvokeFrom,
for_log: bool = False,
) -> dict[str, Any]:
) -> dict[str, object]:
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter_name in node_data.agent_parameters:
parameter = agent_parameters_dictionary.get(parameter_name)
if not parameter:
@@ -54,9 +63,10 @@ class AgentRuntimeSupport:
agent_input = node_data.agent_parameters[parameter_name]
match agent_input.type:
case "variable":
variable = variable_pool.get(agent_input.value) # type: ignore[arg-type]
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(agent_input.value)
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(agent_input.value))
raise AgentVariableNotFoundError(str(variable_selector))
parameter_value = variable.value
case "mixed" | "constant":
try:
@@ -79,60 +89,38 @@ class AgentRuntimeSupport:
value = parameter_value
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = [tool for tool in value if tool.get("enabled", False)]
value = self._filter_mcp_type_tool(strategy, value)
for tool in value:
if "schemas" in tool:
tool.pop("schemas")
parameters = tool.get("parameters", {})
if all(isinstance(v, dict) for _, v in parameters.items()):
params = {}
for key, param in parameters.items():
if param.get("auto", ParamsAutoGenerated.OPEN) in (
ParamsAutoGenerated.CLOSE,
0,
):
value_param = param.get("value", {})
if value_param and value_param.get("type", "") == "variable":
variable_selector = value_param.get("value")
if not variable_selector:
raise ValueError("Variable selector is missing for a variable-type parameter.")
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
params[key] = variable.value
else:
params[key] = value_param.get("value", "") if value_param is not None else None
else:
params[key] = None
parameters = params
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
tool["parameters"] = parameters
tool_payloads = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
value = self._normalize_tool_payloads(
strategy=strategy,
tools=tool_payloads,
variable_pool=variable_pool,
)
if not for_log:
if parameter.type == "array[tools]":
value = cast(list[dict[str, Any]], value)
value = _JSON_OBJECT_LIST_ADAPTER.validate_python(value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
provider_type = self._coerce_tool_provider_type(tool.get("type"))
setting_params = self._coerce_json_object(tool.get("settings")) or {}
parameters = self._coerce_json_object(tool.get("parameters")) or {}
manual_input_params = [key for key, value in parameters.items() if value is not None]
parameters = {**parameters, **setting_params}
provider_id = self._coerce_optional_string(tool.get("provider_name")) or ""
tool_name = self._coerce_optional_string(tool.get("tool_name")) or ""
plugin_unique_identifier = self._coerce_optional_string(tool.get("plugin_unique_identifier"))
credential_id = self._coerce_optional_string(tool.get("credential_id"))
entity = AgentToolEntity(
provider_id=tool.get("provider_name", ""),
provider_id=provider_id,
provider_type=provider_type,
tool_name=tool.get("tool_name", ""),
tool_name=tool_name,
tool_parameters=parameters,
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
credential_id=tool.get("credential_id", None),
plugin_unique_identifier=plugin_unique_identifier,
credential_id=credential_id,
)
extra = tool.get("extra", {})
extra = self._coerce_json_object(tool.get("extra")) or {}
runtime_variable_pool: VariablePool | None = None
if node_data.version != "1" or node_data.tool_node_version is not None:
@@ -145,8 +133,9 @@ class AgentRuntimeSupport:
runtime_variable_pool,
)
if tool_runtime.entity.description:
description_override = self._coerce_optional_string(extra.get("description"))
tool_runtime.entity.description.llm = (
extra.get("description", "") or tool_runtime.entity.description.llm
description_override or tool_runtime.entity.description.llm
)
for tool_runtime_params in tool_runtime.entity.parameters:
tool_runtime_params.form = (
@@ -167,13 +156,13 @@ class AgentRuntimeSupport:
{
**tool_runtime.entity.model_dump(mode="json"),
"runtime_parameters": runtime_parameters,
"credential_id": tool.get("credential_id", None),
"credential_id": credential_id,
"provider_type": provider_type.value,
}
)
value = tool_value
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
value = cast(dict[str, Any], value)
value = _JSON_OBJECT_ADAPTER.validate_python(value)
model_instance, model_schema = self.fetch_model(tenant_id=tenant_id, value=value)
history_prompt_messages = []
if node_data.memory:
@@ -199,17 +188,27 @@ class AgentRuntimeSupport:
return result
def build_credentials(self, *, parameters: dict[str, Any]) -> InvokeCredentials:
def build_credentials(self, *, parameters: Mapping[str, object]) -> InvokeCredentials:
credentials = InvokeCredentials()
credentials.tool_credentials = {}
for tool in parameters.get("tools", []):
tools = parameters.get("tools")
if not isinstance(tools, list):
return credentials
for raw_tool in tools:
tool = self._coerce_json_object(raw_tool)
if tool is None:
continue
if not tool.get("credential_id"):
continue
try:
identity = ToolIdentity.model_validate(tool.get("identity", {}))
except ValidationError:
continue
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
credential_id = self._coerce_optional_string(tool.get("credential_id"))
if credential_id is None:
continue
credentials.tool_credentials[identity.provider] = credential_id
return credentials
def fetch_memory(
@@ -232,14 +231,14 @@ class AgentRuntimeSupport:
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
def fetch_model(self, *, tenant_id: str, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
def fetch_model(self, *, tenant_id: str, value: Mapping[str, object]) -> tuple[ModelInstance, AIModelEntity | None]:
provider_manager = ProviderManager()
provider_model_bundle = provider_manager.get_provider_model_bundle(
tenant_id=tenant_id,
provider=value.get("provider", ""),
provider=str(value.get("provider", "")),
model_type=ModelType.LLM,
)
model_name = value.get("model", "")
model_name = str(value.get("model", ""))
model_credentials = provider_model_bundle.configuration.get_current_credentials(
model_type=ModelType.LLM,
model=model_name,
@@ -249,7 +248,7 @@ class AgentRuntimeSupport:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
provider=provider_name,
model_type=ModelType(value.get("model_type", "")),
model_type=ModelType(str(value.get("model_type", ""))),
model=model_name,
)
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
@@ -268,9 +267,88 @@ class AgentRuntimeSupport:
@staticmethod
def _filter_mcp_type_tool(
strategy: ResolvedAgentStrategy,
tools: list[dict[str, Any]],
) -> list[dict[str, Any]]:
tools: JsonObjectList,
) -> JsonObjectList:
meta_version = strategy.meta_version
if meta_version and Version(meta_version) > Version("0.0.1"):
return tools
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
def _normalize_tool_payloads(
self,
*,
strategy: ResolvedAgentStrategy,
tools: JsonObjectList,
variable_pool: VariablePool,
) -> JsonObjectList:
enabled_tools = [dict(tool) for tool in tools if bool(tool.get("enabled", False))]
normalized_tools = self._filter_mcp_type_tool(strategy, enabled_tools)
for tool in normalized_tools:
tool.pop("schemas", None)
tool["parameters"] = self._resolve_tool_parameters(tool=tool, variable_pool=variable_pool)
tool["settings"] = self._resolve_tool_settings(tool)
return normalized_tools
def _resolve_tool_parameters(self, *, tool: Mapping[str, object], variable_pool: VariablePool) -> JsonObject:
parameter_configs = self._coerce_named_json_objects(tool.get("parameters"))
if parameter_configs is None:
raw_parameters = self._coerce_json_object(tool.get("parameters"))
return raw_parameters or {}
resolved_parameters: JsonObject = {}
for key, parameter_config in parameter_configs.items():
if parameter_config.get("auto", ParamsAutoGenerated.OPEN) in (ParamsAutoGenerated.CLOSE, 0):
value_param = self._coerce_json_object(parameter_config.get("value"))
if value_param and value_param.get("type") == "variable":
variable_selector = _VARIABLE_SELECTOR_ADAPTER.validate_python(value_param.get("value"))
variable = variable_pool.get(variable_selector)
if variable is None:
raise AgentVariableNotFoundError(str(variable_selector))
resolved_parameters[key] = variable.value
else:
resolved_parameters[key] = value_param.get("value", "") if value_param is not None else None
else:
resolved_parameters[key] = None
return resolved_parameters
@staticmethod
def _resolve_tool_settings(tool: Mapping[str, object]) -> JsonObject:
settings = AgentRuntimeSupport._coerce_named_json_objects(tool.get("settings"))
if settings is None:
return {}
return {key: setting.get("value") for key, setting in settings.items()}
@staticmethod
def _coerce_json_object(value: object) -> JsonObject | None:
try:
return _JSON_OBJECT_ADAPTER.validate_python(value)
except ValidationError:
return None
@staticmethod
def _coerce_optional_string(value: object) -> str | None:
return value if isinstance(value, str) else None
@staticmethod
def _coerce_tool_provider_type(value: object) -> ToolProviderType:
if isinstance(value, ToolProviderType):
return value
if isinstance(value, str):
return ToolProviderType(value)
return ToolProviderType.BUILT_IN
@classmethod
def _coerce_named_json_objects(cls, value: object) -> dict[str, JsonObject] | None:
if not isinstance(value, dict):
return None
coerced: dict[str, JsonObject] = {}
for key, item in value.items():
if not isinstance(key, str):
return None
json_object = cls._coerce_json_object(item)
if json_object is None:
return None
coerced[key] = json_object
return coerced

View File

@@ -1,7 +1,7 @@
import logging
import time
from collections.abc import Generator, Mapping, Sequence
from typing import Any, cast
from typing import Any, TypeAlias, cast
from configs import dify_config
from core.app.apps.exc import GenerateTaskStoppedError
@@ -32,6 +32,13 @@ from models.workflow import Workflow
logger = logging.getLogger(__name__)
SpecialValueScalar: TypeAlias = str | int | float | bool | None
SpecialValue: TypeAlias = SpecialValueScalar | File | Mapping[str, "SpecialValue"] | list["SpecialValue"]
SerializedSpecialValue: TypeAlias = (
SpecialValueScalar | dict[str, "SerializedSpecialValue"] | list["SerializedSpecialValue"]
)
SingleNodeGraphConfig: TypeAlias = dict[str, list[dict[str, object]]]
class _WorkflowChildEngineBuilder:
@staticmethod
@@ -276,10 +283,10 @@ class WorkflowEntry:
@staticmethod
def _create_single_node_graph(
node_id: str,
node_data: dict[str, Any],
node_data: Mapping[str, object],
node_width: int = 114,
node_height: int = 514,
) -> dict[str, Any]:
) -> SingleNodeGraphConfig:
"""
Create a minimal graph structure for testing a single node in isolation.
@@ -289,14 +296,14 @@ class WorkflowEntry:
:param node_height: height for UI layout (default: 100)
:return: graph dictionary with start node and target node
"""
node_config = {
node_config: dict[str, object] = {
"id": node_id,
"width": node_width,
"height": node_height,
"type": "custom",
"data": node_data,
"data": dict(node_data),
}
start_node_config = {
start_node_config: dict[str, object] = {
"id": "start",
"width": node_width,
"height": node_height,
@@ -321,7 +328,12 @@ class WorkflowEntry:
@classmethod
def run_free_node(
cls, node_data: dict[str, Any], node_id: str, tenant_id: str, user_id: str, user_inputs: dict[str, Any]
cls,
node_data: Mapping[str, object],
node_id: str,
tenant_id: str,
user_id: str,
user_inputs: Mapping[str, object],
) -> tuple[Node, Generator[GraphNodeEventBase, None, None]]:
"""
Run free node
@@ -339,6 +351,8 @@ class WorkflowEntry:
graph_dict = cls._create_single_node_graph(node_id, node_data)
node_type = node_data.get("type", "")
if not isinstance(node_type, str):
raise ValueError("Node type must be a string")
if node_type not in {BuiltinNodeTypes.PARAMETER_EXTRACTOR, BuiltinNodeTypes.QUESTION_CLASSIFIER}:
raise ValueError(f"Node type {node_type} not supported")
@@ -369,7 +383,7 @@ class WorkflowEntry:
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init workflow run state
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": node_data})
node_config = NodeConfigDictAdapter.validate_python({"id": node_id, "data": dict(node_data)})
node_factory = DifyNodeFactory(
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
@@ -405,30 +419,34 @@ class WorkflowEntry:
raise WorkflowNodeRunFailedError(node=node, err_msg=str(e))
@staticmethod
def handle_special_values(value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
def handle_special_values(value: Mapping[str, SpecialValue] | None) -> dict[str, SerializedSpecialValue] | None:
# NOTE(QuantumGhost): Avoid using this function in new code.
# Keep values structured as long as possible and only convert to dict
# immediately before serialization (e.g., JSON serialization) to maintain
# data integrity and type information.
result = WorkflowEntry._handle_special_values(value)
return result if isinstance(result, Mapping) or result is None else dict(result)
if result is None:
return None
if isinstance(result, dict):
return result
raise TypeError("handle_special_values expects a mapping input")
@staticmethod
def _handle_special_values(value: Any):
def _handle_special_values(value: SpecialValue) -> SerializedSpecialValue:
if value is None:
return value
if isinstance(value, dict):
res = {}
if isinstance(value, Mapping):
res: dict[str, SerializedSpecialValue] = {}
for k, v in value.items():
res[k] = WorkflowEntry._handle_special_values(v)
return res
if isinstance(value, list):
res_list = []
res_list: list[SerializedSpecialValue] = []
for item in value:
res_list.append(WorkflowEntry._handle_special_values(item))
return res_list
if isinstance(value, File):
return value.to_dict()
return dict(value.to_dict())
return value
@classmethod

View File

@@ -112,6 +112,8 @@ def _get_encoded_string(f: File, /) -> str:
data = _download_file_content(f.storage_key)
case FileTransferMethod.DATASOURCE_FILE:
data = _download_file_content(f.storage_key)
case _:
raise ValueError(f"Unsupported transfer method: {f.transfer_method}")
return base64.b64encode(data).decode("utf-8")

View File

@@ -133,6 +133,8 @@ class ExecutionLimitsLayer(GraphEngineLayer):
elif limit_type == LimitType.TIME_LIMIT:
elapsed_time = time.time() - self.start_time if self.start_time else 0
reason = f"Maximum execution time exceeded: {elapsed_time:.2f}s > {self.max_time}s"
else:
return
self.logger.warning("Execution limit exceeded: %s", reason)

View File

@@ -336,12 +336,7 @@ class Node(Generic[NodeDataT]):
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_executions = graph_execution.node_executions
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
@@ -395,8 +390,7 @@ class Node(Generic[NodeDataT]):
if isinstance(event, NodeEventBase): # pyright: ignore[reportUnnecessaryIsInstance]
yield self._dispatch(event)
elif isinstance(event, GraphNodeEventBase) and not event.in_iteration_id and not event.in_loop_id: # pyright: ignore[reportUnnecessaryIsInstance]
event.id = self.execution_id
yield event
yield event.model_copy(update={"id": self.execution_id})
else:
yield event
except Exception as e:

View File

@@ -443,7 +443,10 @@ def _extract_text_from_docx(file_content: bytes) -> str:
# Keep track of paragraph and table positions
content_items: list[tuple[int, str, Table | Paragraph]] = []
it = iter(doc.element.body)
doc_body = getattr(doc.element, "body", None)
if doc_body is None:
raise TextExtractionError("DOCX body not found")
it = iter(doc_body)
part = next(it, None)
i = 0
while part is not None:

View File

@@ -1,7 +1,8 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Literal, NotRequired
from pydantic import BaseModel, Field, field_validator
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -10,11 +11,17 @@ from dify_graph.model_runtime.entities import ImagePromptMessageContent, LLMMode
from dify_graph.nodes.base.entities import VariableSelector
class StructuredOutputConfig(TypedDict):
schema: Mapping[str, object]
name: NotRequired[str]
description: NotRequired[str]
class ModelConfig(BaseModel):
provider: str
name: str
mode: LLMMode
completion_params: dict[str, Any] = Field(default_factory=dict)
completion_params: dict[str, object] = Field(default_factory=dict)
class ContextConfig(BaseModel):
@@ -33,7 +40,7 @@ class VisionConfig(BaseModel):
@field_validator("configs", mode="before")
@classmethod
def convert_none_configs(cls, v: Any):
def convert_none_configs(cls, v: object):
if v is None:
return VisionConfigOptions()
return v
@@ -44,7 +51,7 @@ class PromptConfig(BaseModel):
@field_validator("jinja2_variables", mode="before")
@classmethod
def convert_none_jinja2_variables(cls, v: Any):
def convert_none_jinja2_variables(cls, v: object):
if v is None:
return []
return v
@@ -67,7 +74,7 @@ class LLMNodeData(BaseNodeData):
memory: MemoryConfig | None = None
context: ContextConfig
vision: VisionConfig = Field(default_factory=VisionConfig)
structured_output: Mapping[str, Any] | None = None
structured_output: StructuredOutputConfig | None = None
# We used 'structured_output_enabled' in the past, but it's not a good name.
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
reasoning_format: Literal["separated", "tagged"] = Field(
@@ -90,11 +97,30 @@ class LLMNodeData(BaseNodeData):
@field_validator("prompt_config", mode="before")
@classmethod
def convert_none_prompt_config(cls, v: Any):
def convert_none_prompt_config(cls, v: object):
if v is None:
return PromptConfig()
return v
@field_validator("structured_output", mode="before")
@classmethod
def convert_legacy_structured_output(cls, v: object) -> StructuredOutputConfig | None | object:
if not isinstance(v, Mapping):
return v
schema = v.get("schema")
if schema is None:
return None
normalized: StructuredOutputConfig = {"schema": schema}
name = v.get("name")
description = v.get("description")
if isinstance(name, str):
normalized["name"] = name
if isinstance(description, str):
normalized["description"] = description
return normalized
@property
def structured_output_enabled(self) -> bool:
return self.structured_output_switch_on and self.structured_output is not None

View File

@@ -9,6 +9,7 @@ import time
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal
from pydantic import TypeAdapter
from sqlalchemy import select
from core.llm_generator.output_parser.errors import OutputParserError
@@ -74,6 +75,7 @@ from .entities import (
LLMNodeChatModelMessage,
LLMNodeCompletionModelPromptTemplate,
LLMNodeData,
StructuredOutputConfig,
)
from .exc import (
InvalidContextStructureError,
@@ -88,6 +90,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
class LLMNode(Node[LLMNodeData]):
@@ -354,7 +357,7 @@ class LLMNode(Node[LLMNodeData]):
stop: Sequence[str] | None = None,
user_id: str,
structured_output_enabled: bool,
structured_output: Mapping[str, Any] | None = None,
structured_output: StructuredOutputConfig | None = None,
file_saver: LLMFileSaver,
file_outputs: list[File],
node_id: str,
@@ -367,8 +370,10 @@ class LLMNode(Node[LLMNodeData]):
model_schema = llm_utils.fetch_model_schema(model_instance=model_instance)
if structured_output_enabled:
if structured_output is None:
raise LLMNodeError("Please provide a valid structured output schema")
output_schema = LLMNode.fetch_structured_output_schema(
structured_output=structured_output or {},
structured_output=structured_output,
)
request_start_time = time.perf_counter()
@@ -920,6 +925,12 @@ class LLMNode(Node[LLMNodeData]):
# Extract clean text and reasoning from <think> tags
clean_text, reasoning_content = LLMNode._split_reasoning(full_text, reasoning_format)
structured_output = (
dict(invoke_result.structured_output)
if isinstance(invoke_result, LLMResultWithStructuredOutput) and invoke_result.structured_output is not None
else None
)
event = ModelInvokeCompletedEvent(
# Use clean_text for separated mode, full_text for tagged mode
text=clean_text if reasoning_format == "separated" else full_text,
@@ -928,7 +939,7 @@ class LLMNode(Node[LLMNodeData]):
# Reasoning content for workflow variables and downstream nodes
reasoning_content=reasoning_content,
# Pass structured output if enabled
structured_output=getattr(invoke_result, "structured_output", None),
structured_output=structured_output,
)
if request_latency is not None:
event.usage.latency = round(request_latency, 3)
@@ -962,27 +973,18 @@ class LLMNode(Node[LLMNodeData]):
@staticmethod
def fetch_structured_output_schema(
*,
structured_output: Mapping[str, Any],
) -> dict[str, Any]:
structured_output: StructuredOutputConfig,
) -> dict[str, object]:
"""
Fetch the structured output schema from the node data.
Returns:
dict[str, Any]: The structured output schema
dict[str, object]: The structured output schema
"""
if not structured_output:
schema = structured_output.get("schema")
if not schema:
raise LLMNodeError("Please provide a valid structured output schema")
structured_output_schema = json.dumps(structured_output.get("schema", {}), ensure_ascii=False)
if not structured_output_schema:
raise LLMNodeError("Please provide a valid structured output schema")
try:
schema = json.loads(structured_output_schema)
if not isinstance(schema, dict):
raise LLMNodeError("structured_output_schema must be a JSON object")
return schema
except json.JSONDecodeError:
raise LLMNodeError("structured_output_schema is not valid JSON format")
return _JSON_OBJECT_ADAPTER.validate_python(schema)
@staticmethod
def _save_multimodal_output_and_convert_result_to_markdown(

View File

@@ -1,7 +1,10 @@
from enum import StrEnum
from typing import Annotated, Any, Literal
from __future__ import annotations
from pydantic import AfterValidator, BaseModel, Field, field_validator
from enum import StrEnum
from typing import Annotated, Any, Literal, TypeAlias, cast
from pydantic import AfterValidator, BaseModel, Field, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
@@ -9,6 +12,12 @@ from dify_graph.nodes.base import BaseLoopNodeData, BaseLoopState
from dify_graph.utils.condition.entities import Condition
from dify_graph.variables.types import SegmentType
LoopValue: TypeAlias = str | int | float | bool | None | dict[str, Any] | list[Any]
LoopValueMapping: TypeAlias = dict[str, LoopValue]
VariableSelector: TypeAlias = list[str]
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
_VALID_VAR_TYPE = frozenset(
[
SegmentType.STRING,
@@ -29,6 +38,36 @@ def _is_valid_var_type(seg_type: SegmentType) -> SegmentType:
return seg_type
def _validate_loop_value(value: object) -> LoopValue:
if value is None or isinstance(value, (str, int, float, bool)):
return cast(LoopValue, value)
if isinstance(value, list):
return [_validate_loop_value(item) for item in value]
if isinstance(value, dict):
normalized: dict[str, LoopValue] = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop values only support string object keys")
normalized[key] = _validate_loop_value(item)
return normalized
raise TypeError("Loop values must be JSON-like primitives, arrays, or objects")
def _validate_loop_value_mapping(value: object) -> LoopValueMapping:
if not isinstance(value, dict):
raise TypeError("Loop outputs must be an object")
normalized: LoopValueMapping = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("Loop output keys must be strings")
normalized[key] = _validate_loop_value(item)
return normalized
class LoopVariableData(BaseModel):
"""
Loop Variable Data.
@@ -37,7 +76,29 @@ class LoopVariableData(BaseModel):
label: str
var_type: Annotated[SegmentType, AfterValidator(_is_valid_var_type)]
value_type: Literal["variable", "constant"]
value: Any | list[str] | None = None
value: LoopValue | VariableSelector | None = None
@field_validator("value", mode="before")
@classmethod
def validate_value(cls, value: object, validation_info: ValidationInfo) -> LoopValue | VariableSelector | None:
value_type = validation_info.data.get("value_type")
if value_type == "variable":
if value is None:
raise ValueError("Variable loop inputs require a selector")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if value_type == "constant":
return _validate_loop_value(value)
raise ValueError(f"Unknown loop variable value type: {value_type}")
def require_variable_selector(self) -> VariableSelector:
if self.value_type != "variable":
raise ValueError(f"Expected variable loop input, got {self.value_type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def require_constant_value(self) -> LoopValue:
if self.value_type != "constant":
raise ValueError(f"Expected constant loop input, got {self.value_type}")
return _validate_loop_value(self.value)
class LoopNodeData(BaseLoopNodeData):
@@ -46,14 +107,14 @@ class LoopNodeData(BaseLoopNodeData):
break_conditions: list[Condition] # Conditions to break the loop
logical_operator: Literal["and", "or"]
loop_variables: list[LoopVariableData] | None = Field(default_factory=list[LoopVariableData])
outputs: dict[str, Any] = Field(default_factory=dict)
outputs: LoopValueMapping = Field(default_factory=dict)
@field_validator("outputs", mode="before")
@classmethod
def validate_outputs(cls, v):
if v is None:
def validate_outputs(cls, value: object) -> LoopValueMapping:
if value is None:
return {}
return v
return _validate_loop_value_mapping(value)
class LoopStartNodeData(BaseNodeData):
@@ -77,8 +138,8 @@ class LoopState(BaseLoopState):
Loop State.
"""
outputs: list[Any] = Field(default_factory=list)
current_output: Any = None
outputs: list[LoopValue] = Field(default_factory=list)
current_output: LoopValue | None = None
class MetaData(BaseLoopState.MetaData):
"""
@@ -87,7 +148,7 @@ class LoopState(BaseLoopState):
loop_length: int
def get_last_output(self) -> Any:
def get_last_output(self) -> LoopValue | None:
"""
Get last output.
"""
@@ -95,7 +156,7 @@ class LoopState(BaseLoopState):
return self.outputs[-1]
return None
def get_current_output(self) -> Any:
def get_current_output(self) -> LoopValue | None:
"""
Get current output.
"""

View File

@@ -3,7 +3,7 @@ import json
import logging
from collections.abc import Callable, Generator, Mapping, Sequence
from datetime import datetime
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import TYPE_CHECKING, Literal, cast
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.enums import (
@@ -29,7 +29,7 @@ from dify_graph.node_events import (
)
from dify_graph.nodes.base import LLMUsageTrackingMixin
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopVariableData
from dify_graph.nodes.loop.entities import LoopCompletedReason, LoopNodeData, LoopValue, LoopVariableData
from dify_graph.utils.condition.processor import ConditionProcessor
from dify_graph.variables import Segment, SegmentType
from factories.variable_factory import TypeMismatchError, build_segment_with_type, segment_to_variable
@@ -60,7 +60,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
break_conditions = self.node_data.break_conditions
logical_operator = self.node_data.logical_operator
inputs = {"loop_count": loop_count}
inputs: dict[str, object] = {"loop_count": loop_count}
if not self.node_data.start_node_id:
raise ValueError(f"field start_node_id in loop {self._node_id} not found")
@@ -68,12 +68,14 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
root_node_id = self.node_data.start_node_id
# Initialize loop variables in the original variable pool
loop_variable_selectors = {}
loop_variable_selectors: dict[str, list[str]] = {}
if self.node_data.loop_variables:
value_processor: dict[Literal["constant", "variable"], Callable[[LoopVariableData], Segment | None]] = {
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.value),
"constant": lambda var: self._get_segment_for_constant(var.var_type, var.require_constant_value()),
"variable": lambda var: (
self.graph_runtime_state.variable_pool.get(var.value) if isinstance(var.value, list) else None
self.graph_runtime_state.variable_pool.get(var.require_variable_selector())
if var.value is not None
else None
),
}
for loop_variable in self.node_data.loop_variables:
@@ -95,7 +97,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
condition_processor = ConditionProcessor()
loop_duration_map: dict[str, float] = {}
single_loop_variable_map: dict[str, dict[str, Any]] = {} # single loop variable output
single_loop_variable_map: dict[str, dict[str, LoopValue]] = {} # single loop variable output
loop_usage = LLMUsage.empty_usage()
loop_node_ids = self._extract_loop_node_ids_from_config(self.graph_config, self._node_id)
@@ -146,7 +148,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
loop_usage = self._merge_usage(loop_usage, graph_engine.graph_runtime_state.llm_usage)
# Collect loop variable values after iteration
single_loop_variable = {}
single_loop_variable: dict[str, LoopValue] = {}
for key, selector in loop_variable_selectors.items():
segment = self.graph_runtime_state.variable_pool.get(selector)
single_loop_variable[key] = segment.value if segment else None
@@ -297,20 +299,29 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
graph_config: Mapping[str, object],
node_id: str,
node_data: LoopNodeData,
) -> Mapping[str, Sequence[str]]:
variable_mapping = {}
variable_mapping: dict[str, Sequence[str]] = {}
# Extract loop node IDs statically from graph_config
loop_node_ids = cls._extract_loop_node_ids_from_config(graph_config, node_id)
# Get node configs from graph_config
node_configs = {node["id"]: node for node in graph_config.get("nodes", []) if "id" in node}
raw_nodes = graph_config.get("nodes")
node_configs: dict[str, Mapping[str, object]] = {}
if isinstance(raw_nodes, list):
for raw_node in raw_nodes:
if not isinstance(raw_node, dict):
continue
raw_node_id = raw_node.get("id")
if isinstance(raw_node_id, str):
node_configs[raw_node_id] = raw_node
for sub_node_id, sub_node_config in node_configs.items():
if sub_node_config.get("data", {}).get("loop_id") != node_id:
sub_node_data = sub_node_config.get("data")
if not isinstance(sub_node_data, dict) or sub_node_data.get("loop_id") != node_id:
continue
# variable selector to variable mapping
@@ -341,9 +352,8 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
for loop_variable in node_data.loop_variables or []:
if loop_variable.value_type == "variable":
assert loop_variable.value is not None, "Loop variable value must be provided for variable type"
# add loop variable to variable mapping
selector = loop_variable.value
selector = loop_variable.require_variable_selector()
variable_mapping[f"{node_id}.{loop_variable.label}"] = selector
# remove variable out from loop
@@ -352,7 +362,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return variable_mapping
@classmethod
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, Any], loop_node_id: str) -> set[str]:
def _extract_loop_node_ids_from_config(cls, graph_config: Mapping[str, object], loop_node_id: str) -> set[str]:
"""
Extract node IDs that belong to a specific loop from graph configuration.
@@ -363,12 +373,19 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
:param loop_node_id: the ID of the loop node
:return: set of node IDs that belong to the loop
"""
loop_node_ids = set()
loop_node_ids: set[str] = set()
# Find all nodes that belong to this loop
nodes = graph_config.get("nodes", [])
for node in nodes:
node_data = node.get("data", {})
raw_nodes = graph_config.get("nodes")
if not isinstance(raw_nodes, list):
return loop_node_ids
for node in raw_nodes:
if not isinstance(node, dict):
continue
node_data = node.get("data")
if not isinstance(node_data, dict):
continue
if node_data.get("loop_id") == loop_node_id:
node_id = node.get("id")
if node_id:
@@ -377,7 +394,7 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
return loop_node_ids
@staticmethod
def _get_segment_for_constant(var_type: SegmentType, original_value: Any) -> Segment:
def _get_segment_for_constant(var_type: SegmentType, original_value: LoopValue | None) -> Segment:
"""Get the appropriate segment type for a constant value."""
# TODO: Refactor for maintainability:
# 1. Ensure type handling logic stays synchronized with _VALID_VAR_TYPE (entities.py)
@@ -389,11 +406,12 @@ class LoopNode(LLMUsageTrackingMixin, Node[LoopNodeData]):
SegmentType.ARRAY_OBJECT,
SegmentType.ARRAY_STRING,
]:
if original_value and isinstance(original_value, str):
value = json.loads(original_value)
# New typed payloads may already provide native lists, while legacy
# configs still serialize array constants as JSON strings.
if isinstance(original_value, str):
value = json.loads(original_value) if original_value else []
else:
logger.warning("unexpected value for LoopNode, value_type=%s, value=%s", original_value, var_type)
value = []
value = original_value
else:
raise AssertionError("this statement should be unreachable.")
try:

View File

@@ -1,4 +1,4 @@
from typing import Annotated, Any, Literal
from typing import Annotated, Literal
from pydantic import (
BaseModel,
@@ -6,6 +6,7 @@ from pydantic import (
Field,
field_validator,
)
from typing_extensions import TypedDict
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from dify_graph.entities.base_node_data import BaseNodeData
@@ -55,7 +56,7 @@ class ParameterConfig(BaseModel):
@field_validator("name", mode="before")
@classmethod
def validate_name(cls, value) -> str:
def validate_name(cls, value: object) -> str:
if not value:
raise ValueError("Parameter name is required")
if value in {"__reason", "__is_success"}:
@@ -79,6 +80,23 @@ class ParameterConfig(BaseModel):
return element_type
class JsonSchemaArrayItems(TypedDict):
type: str
class ParameterJsonSchemaProperty(TypedDict, total=False):
description: str
type: str
items: JsonSchemaArrayItems
enum: list[str]
class ParameterJsonSchema(TypedDict):
type: Literal["object"]
properties: dict[str, ParameterJsonSchemaProperty]
required: list[str]
class ParameterExtractorNodeData(BaseNodeData):
"""
Parameter Extractor Node Data.
@@ -95,19 +113,19 @@ class ParameterExtractorNodeData(BaseNodeData):
@field_validator("reasoning_mode", mode="before")
@classmethod
def set_reasoning_mode(cls, v) -> str:
return v or "function_call"
def set_reasoning_mode(cls, v: object) -> str:
return str(v) if v else "function_call"
def get_parameter_json_schema(self):
def get_parameter_json_schema(self) -> ParameterJsonSchema:
"""
Get parameter json schema.
:return: parameter json schema
"""
parameters: dict[str, Any] = {"type": "object", "properties": {}, "required": []}
parameters: ParameterJsonSchema = {"type": "object", "properties": {}, "required": []}
for parameter in self.parameters:
parameter_schema: dict[str, Any] = {"description": parameter.description}
parameter_schema: ParameterJsonSchemaProperty = {"description": parameter.description}
if parameter.type == SegmentType.STRING:
parameter_schema["type"] = "string"
@@ -118,7 +136,7 @@ class ParameterExtractorNodeData(BaseNodeData):
raise AssertionError("element type should not be None.")
parameter_schema["items"] = {"type": element_type.value}
else:
parameter_schema["type"] = parameter.type
parameter_schema["type"] = parameter.type.value
if parameter.options:
parameter_schema["enum"] = parameter.options

View File

@@ -5,6 +5,8 @@ import uuid
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from pydantic import TypeAdapter
from core.model_manager import ModelInstance
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -63,6 +65,7 @@ from .prompts import (
)
logger = logging.getLogger(__name__)
_JSON_OBJECT_ADAPTER = TypeAdapter(dict[str, object])
if TYPE_CHECKING:
from dify_graph.entities import GraphInitParams
@@ -70,7 +73,7 @@ if TYPE_CHECKING:
from dify_graph.runtime import GraphRuntimeState
def extract_json(text):
def extract_json(text: str) -> str | None:
"""
From a given JSON started from '{' or '[' extract the complete JSON object.
"""
@@ -392,10 +395,15 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
)
# generate tool
parameter_schema = node_data.get_parameter_json_schema()
tool = PromptMessageTool(
name=FUNCTION_CALLING_EXTRACTOR_NAME,
description="Extract parameters from the natural language text",
parameters=node_data.get_parameter_json_schema(),
parameters={
"type": parameter_schema["type"],
"properties": dict(parameter_schema["properties"]),
"required": list(parameter_schema["required"]),
},
)
return prompt_messages, [tool]
@@ -602,19 +610,21 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
else:
return None
def _transform_result(self, data: ParameterExtractorNodeData, result: dict):
def _transform_result(self, data: ParameterExtractorNodeData, result: Mapping[str, object]) -> dict[str, object]:
"""
Transform result into standard format.
"""
transformed_result: dict[str, Any] = {}
transformed_result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.name in result:
param_value = result[parameter.name]
# transform value
if parameter.type == SegmentType.NUMBER:
transformed = self._transform_number(param_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
if isinstance(param_value, (bool, int, float, str)):
numeric_value: bool | int | float | str = param_value
transformed = self._transform_number(numeric_value)
if transformed is not None:
transformed_result[parameter.name] = transformed
elif parameter.type == SegmentType.BOOLEAN:
if isinstance(result[parameter.name], (bool, int)):
transformed_result[parameter.name] = bool(result[parameter.name])
@@ -661,7 +671,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
return transformed_result
def _extract_complete_json_response(self, result: str) -> dict | None:
def _extract_complete_json_response(self, result: str) -> dict[str, object] | None:
"""
Extract complete json response.
"""
@@ -672,11 +682,11 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict | None:
def _extract_json_from_tool_call(self, tool_call: AssistantPromptMessage.ToolCall) -> dict[str, object] | None:
"""
Extract json from tool call.
"""
@@ -690,16 +700,16 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
json_str = extract_json(result[idx:])
if json_str:
with contextlib.suppress(Exception):
return cast(dict, json.loads(json_str))
return _JSON_OBJECT_ADAPTER.validate_python(json.loads(json_str))
logger.info("extra error: %s", result)
return None
def _generate_default_result(self, data: ParameterExtractorNodeData):
def _generate_default_result(self, data: ParameterExtractorNodeData) -> dict[str, object]:
"""
Generate default result.
"""
result: dict[str, Any] = {}
result: dict[str, object] = {}
for parameter in data.parameters:
if parameter.type == "number":
result[parameter.name] = 0

View File

@@ -1,12 +1,66 @@
from typing import Any, Literal, Union
from __future__ import annotations
from pydantic import BaseModel, field_validator
from typing import Literal, TypeAlias, cast
from pydantic import BaseModel, TypeAdapter, field_validator
from pydantic_core.core_schema import ValidationInfo
from typing_extensions import TypedDict
from core.tools.entities.tool_entities import ToolProviderType
from dify_graph.entities.base_node_data import BaseNodeData
from dify_graph.enums import BuiltinNodeTypes, NodeType
ToolConfigurationValue: TypeAlias = str | int | float | bool
ToolInputConstantValue: TypeAlias = str | int | float | bool | dict[str, object] | list[object] | None
VariableSelector: TypeAlias = list[str]
_TOOL_INPUT_MIXED_ADAPTER: TypeAdapter[str] = TypeAdapter(str)
_TOOL_INPUT_CONSTANT_ADAPTER: TypeAdapter[ToolInputConstantValue] = TypeAdapter(ToolInputConstantValue)
_VARIABLE_SELECTOR_ADAPTER: TypeAdapter[VariableSelector] = TypeAdapter(VariableSelector)
class WorkflowToolInputValue(TypedDict):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
ToolConfigurationEntry: TypeAlias = ToolConfigurationValue | WorkflowToolInputValue
ToolConfigurations: TypeAlias = dict[str, ToolConfigurationEntry]
class ToolInputPayload(BaseModel):
type: Literal["mixed", "variable", "constant"]
value: ToolInputConstantValue | VariableSelector
@field_validator("value", mode="before")
@classmethod
def validate_value(
cls, value: object, validation_info: ValidationInfo
) -> ToolInputConstantValue | VariableSelector:
input_type = validation_info.data.get("type")
if input_type == "mixed":
return _TOOL_INPUT_MIXED_ADAPTER.validate_python(value)
if input_type == "variable":
return _VARIABLE_SELECTOR_ADAPTER.validate_python(value)
if input_type == "constant":
return _TOOL_INPUT_CONSTANT_ADAPTER.validate_python(value)
raise ValueError(f"Unknown tool input type: {input_type}")
def require_variable_selector(self) -> VariableSelector:
if self.type != "variable":
raise ValueError(f"Expected variable tool input, got {self.type}")
return _VARIABLE_SELECTOR_ADAPTER.validate_python(self.value)
def _validate_tool_configuration_entry(value: object) -> ToolConfigurationEntry:
if isinstance(value, (str, int, float, bool)):
return cast(ToolConfigurationEntry, value)
if isinstance(value, dict):
return cast(ToolConfigurationEntry, ToolInputPayload.model_validate(value).model_dump())
raise TypeError("Tool configuration values must be primitives or workflow tool input objects")
class ToolEntity(BaseModel):
provider_id: str
@@ -14,52 +68,29 @@ class ToolEntity(BaseModel):
provider_name: str # redundancy
tool_name: str
tool_label: str # redundancy
tool_configurations: dict[str, Any]
tool_configurations: ToolConfigurations
credential_id: str | None = None
plugin_unique_identifier: str | None = None # redundancy
@field_validator("tool_configurations", mode="before")
@classmethod
def validate_tool_configurations(cls, value, values: ValidationInfo):
def validate_tool_configurations(cls, value: object, _validation_info: ValidationInfo) -> ToolConfigurations:
if not isinstance(value, dict):
raise ValueError("tool_configurations must be a dictionary")
raise TypeError("tool_configurations must be a dictionary")
for key in values.data.get("tool_configurations", {}):
value = values.data.get("tool_configurations", {}).get(key)
if not isinstance(value, str | int | float | bool):
raise ValueError(f"{key} must be a string")
return value
normalized: ToolConfigurations = {}
for key, item in value.items():
if not isinstance(key, str):
raise TypeError("tool_configurations keys must be strings")
normalized[key] = _validate_tool_configuration_entry(item)
return normalized
class ToolNodeData(BaseNodeData, ToolEntity):
type: NodeType = BuiltinNodeTypes.TOOL
class ToolInput(BaseModel):
# TODO: check this type
value: Union[Any, list[str]]
type: Literal["mixed", "variable", "constant"]
@field_validator("type", mode="before")
@classmethod
def check_type(cls, value, validation_info: ValidationInfo):
typ = value
value = validation_info.data.get("value")
if value is None:
return typ
if typ == "mixed" and not isinstance(value, str):
raise ValueError("value must be a string")
elif typ == "variable":
if not isinstance(value, list):
raise ValueError("value must be a list")
for val in value:
if not isinstance(val, str):
raise ValueError("value must be a list of strings")
elif typ == "constant" and not isinstance(value, (allowed_types := (str, int, float, bool, dict, list))):
raise ValueError(f"value must be one of: {', '.join(t.__name__ for t in allowed_types)}")
return typ
class ToolInput(ToolInputPayload):
pass
tool_parameters: dict[str, ToolInput]
# The version of the tool parameter.
@@ -69,7 +100,7 @@ class ToolNodeData(BaseNodeData, ToolEntity):
@field_validator("tool_parameters", mode="before")
@classmethod
def filter_none_tool_inputs(cls, value):
def filter_none_tool_inputs(cls, value: object) -> object:
if not isinstance(value, dict):
return value
@@ -80,8 +111,10 @@ class ToolNodeData(BaseNodeData, ToolEntity):
}
@staticmethod
def _has_valid_value(tool_input):
def _has_valid_value(tool_input: object) -> bool:
"""Check if the value is valid"""
if isinstance(tool_input, dict):
return tool_input.get("value") is not None
return getattr(tool_input, "value", None) is not None
if isinstance(tool_input, ToolNodeData.ToolInput):
return tool_input.value is not None
return False

View File

@@ -225,10 +225,11 @@ class ToolNode(Node[ToolNodeData]):
continue
tool_input = node_data.tool_parameters[parameter_name]
if tool_input.type == "variable":
variable = variable_pool.get(tool_input.value)
variable_selector = tool_input.require_variable_selector()
variable = variable_pool.get(variable_selector)
if variable is None:
if parameter.required:
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
raise ToolParameterError(f"Variable {variable_selector} does not exist")
continue
parameter_value = variable.value
elif tool_input.type in {"mixed", "constant"}:
@@ -510,8 +511,9 @@ class ToolNode(Node[ToolNodeData]):
for selector in selectors:
result[selector.variable] = selector.value_selector
case "variable":
selector_key = ".".join(input.value)
result[f"#{selector_key}#"] = input.value
variable_selector = input.require_variable_selector()
selector_key = ".".join(variable_selector)
result[f"#{selector_key}#"] = variable_selector
case "constant":
pass

View File

@@ -9,7 +9,7 @@ from dify_graph.node_events import NodeRunResult
from dify_graph.nodes.base.node import Node
from dify_graph.nodes.variable_assigner.common import helpers as common_helpers
from dify_graph.nodes.variable_assigner.common.exc import VariableOperatorNodeError
from dify_graph.variables import SegmentType, VariableBase
from dify_graph.variables import Segment, SegmentType, VariableBase
from .node_data import VariableAssignerData, WriteMode
@@ -74,23 +74,29 @@ class VariableAssignerNode(Node[VariableAssignerData]):
if not isinstance(original_variable, VariableBase):
raise VariableOperatorNodeError("assigned variable not found")
income_value: Segment
updated_variable: VariableBase
match self.node_data.write_mode:
case WriteMode.OVER_WRITE:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_variable = original_variable.model_copy(update={"value": income_value.value})
case WriteMode.APPEND:
income_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if not income_value:
input_value = self.graph_runtime_state.variable_pool.get(self.node_data.input_variable_selector)
if input_value is None:
raise VariableOperatorNodeError("input value not found")
income_value = input_value
updated_value = original_variable.value + [income_value.value]
updated_variable = original_variable.model_copy(update={"value": updated_value})
case WriteMode.CLEAR:
income_value = SegmentType.get_zero_value(original_variable.value_type)
updated_variable = original_variable.model_copy(update={"value": income_value.to_object()})
case _:
raise VariableOperatorNodeError(f"unsupported write mode: {self.node_data.write_mode}")
# Over write the variable.
self.graph_runtime_state.variable_pool.add(assigned_variable_selector, updated_variable)

View File

@@ -66,6 +66,11 @@ class GraphExecutionProtocol(Protocol):
exceptions_count: int
pause_reasons: list[PauseReason]
@property
def node_executions(self) -> Mapping[str, NodeExecutionProtocol]:
"""Return node execution state keyed by node id for resume support."""
...
def start(self) -> None:
"""Transition execution into the running state."""
...
@@ -91,6 +96,12 @@ class GraphExecutionProtocol(Protocol):
...
class NodeExecutionProtocol(Protocol):
"""Structural interface for per-node execution state used during resume."""
execution_id: str | None
class ResponseStreamCoordinatorProtocol(Protocol):
"""Structural interface for response stream coordinator."""

View File

@@ -8,7 +8,7 @@ class BaseStorage(ABC):
"""Interface for file storage."""
@abstractmethod
def save(self, filename: str, data: bytes) -> None:
def save(self, filename: str, data: bytes):
raise NotImplementedError
@abstractmethod
@@ -16,7 +16,7 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def load_stream(self, filename: str) -> Generator[bytes, None, None]:
def load_stream(self, filename: str) -> Generator:
raise NotImplementedError
@abstractmethod
@@ -28,10 +28,10 @@ class BaseStorage(ABC):
raise NotImplementedError
@abstractmethod
def delete(self, filename: str) -> None:
def delete(self, filename: str):
raise NotImplementedError
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
def scan(self, path, files=True, directories=False) -> list[str]:
"""
Scan files and directories in the given path.
This method is implemented only in some storage backends.

View File

@@ -30,8 +30,8 @@ class GitHubEmailRecord(TypedDict, total=False):
class GitHubRawUserInfo(TypedDict):
id: int | str
login: str
name: NotRequired[str | None]
email: NotRequired[str | None]
name: NotRequired[str]
email: NotRequired[str]
class GoogleRawUserInfo(TypedDict):
@@ -138,7 +138,7 @@ class GitHubOAuth(OAuth):
email = payload.get("email")
if not email:
email = f"{payload['id']}+{payload['login']}@users.noreply.github.com"
return OAuthUserInfo(id=str(payload["id"]), name=payload.get("name") or "", email=email)
return OAuthUserInfo(id=str(payload["id"]), name=str(payload.get("name", "")), email=email)
class GoogleOAuth(OAuth):

View File

@@ -20,7 +20,7 @@ else:
class NotionPageSummary(TypedDict):
page_id: str
page_name: str
page_icon: dict[str, object] | None
page_icon: dict[str, str] | None
parent_id: str
type: Literal["page", "database"]

View File

@@ -13,21 +13,6 @@ controllers/console/workspace/trigger_providers.py
controllers/service_api/app/annotation.py
controllers/web/workflow_events.py
core/agent/fc_agent_runner.py
core/app/apps/advanced_chat/app_generator.py
core/app/apps/advanced_chat/app_runner.py
core/app/apps/advanced_chat/generate_task_pipeline.py
core/app/apps/agent_chat/app_generator.py
core/app/apps/base_app_generate_response_converter.py
core/app/apps/base_app_generator.py
core/app/apps/chat/app_generator.py
core/app/apps/common/workflow_response_converter.py
core/app/apps/completion/app_generator.py
core/app/apps/pipeline/pipeline_generator.py
core/app/apps/pipeline/pipeline_runner.py
core/app/apps/workflow/app_generator.py
core/app/apps/workflow/app_runner.py
core/app/apps/workflow/generate_task_pipeline.py
core/app/apps/workflow_app_runner.py
core/app/task_pipeline/easy_ui_based_generate_task_pipeline.py
core/datasource/datasource_manager.py
core/external_data_tool/api/api.py
@@ -43,6 +28,58 @@ core/ops/tencent_trace/utils.py
core/plugin/backwards_invocation/base.py
core/plugin/backwards_invocation/model.py
core/prompt/utils/extract_thread_messages.py
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/lindorm/lindorm_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/pdf_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
core/rag/extractor/word_extractor.py
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py
core/rag/retrieval/router/multi_dataset_function_call_router.py
core/rag/summary_index/summary_index.py
core/repositories/sqlalchemy_workflow_execution_repository.py
@@ -56,42 +93,30 @@ core/tools/workflow_as_tool/provider.py
core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py
dify_graph/graph_engine/layers/execution_limits.py
dify_graph/nodes/agent/agent_node.py
dify_graph/nodes/base/node.py
dify_graph/nodes/code/code_node.py
dify_graph/nodes/datasource/datasource_node.py
dify_graph/nodes/document_extractor/node.py
dify_graph/nodes/human_input/human_input_node.py
dify_graph/nodes/if_else/if_else_node.py
dify_graph/nodes/iteration/iteration_node.py
dify_graph/nodes/knowledge_index/knowledge_index_node.py
core/workflow/nodes/knowledge_retrieval/knowledge_retrieval_node.py
dify_graph/nodes/list_operator/node.py
dify_graph/nodes/llm/node.py
dify_graph/nodes/loop/loop_node.py
dify_graph/nodes/parameter_extractor/parameter_extractor_node.py
dify_graph/nodes/question_classifier/question_classifier_node.py
dify_graph/nodes/start/start_node.py
dify_graph/nodes/template_transform/template_transform_node.py
dify_graph/nodes/tool/tool_node.py
dify_graph/nodes/trigger_plugin/trigger_event_node.py
dify_graph/nodes/trigger_schedule/trigger_schedule_node.py
dify_graph/nodes/trigger_webhook/node.py
dify_graph/nodes/variable_aggregator/variable_aggregator_node.py
dify_graph/nodes/variable_assigner/v1/node.py
dify_graph/nodes/variable_assigner/v2/node.py
extensions/logstore/repositories/logstore_api_workflow_run_repository.py
extensions/otel/instrumentation.py
extensions/otel/runtime.py
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
libs/gmpy2_pkcs10aep_cipher.py
schedule/queue_monitor_task.py
services/account_service.py
services/audio_service.py
services/auth/firecrawl/firecrawl.py
services/auth/jina.py
services/auth/jina/jina.py
services/auth/watercrawl/watercrawl.py
services/conversation_service.py
services/dataset_service.py
services/document_indexing_proxy/document_indexing_task_proxy.py
@@ -119,75 +144,3 @@ tasks/disable_segment_from_index_task.py
tasks/enable_segment_to_index_task.py
tasks/remove_document_from_index_task.py
tasks/workflow_execution_tasks.py
# no need to fix for now: storage adapters
extensions/storage/aliyun_oss_storage.py
extensions/storage/aws_s3_storage.py
extensions/storage/azure_blob_storage.py
extensions/storage/baidu_obs_storage.py
extensions/storage/clickzetta_volume/clickzetta_volume_storage.py
extensions/storage/clickzetta_volume/file_lifecycle.py
extensions/storage/google_cloud_storage.py
extensions/storage/huawei_obs_storage.py
extensions/storage/opendal_storage.py
extensions/storage/oracle_oci_storage.py
extensions/storage/supabase_storage.py
extensions/storage/tencent_cos_storage.py
extensions/storage/volcengine_tos_storage.py
# no need to fix for now: keyword adapters
core/rag/datasource/keyword/jieba/jieba.py
core/rag/datasource/keyword/jieba/jieba_keyword_table_handler.py
# no need to fix for now: vector db adapters
core/rag/datasource/vdb/analyticdb/analyticdb_vector.py
core/rag/datasource/vdb/analyticdb/analyticdb_vector_openapi.py
core/rag/datasource/vdb/baidu/baidu_vector.py
core/rag/datasource/vdb/chroma/chroma_vector.py
core/rag/datasource/vdb/clickzetta/clickzetta_vector.py
core/rag/datasource/vdb/couchbase/couchbase_vector.py
core/rag/datasource/vdb/elasticsearch/elasticsearch_vector.py
core/rag/datasource/vdb/huawei/huawei_cloud_vector.py
core/rag/datasource/vdb/matrixone/matrixone_vector.py
core/rag/datasource/vdb/milvus/milvus_vector.py
core/rag/datasource/vdb/myscale/myscale_vector.py
core/rag/datasource/vdb/oceanbase/oceanbase_vector.py
core/rag/datasource/vdb/opensearch/opensearch_vector.py
core/rag/datasource/vdb/oracle/oraclevector.py
core/rag/datasource/vdb/pgvecto_rs/pgvecto_rs.py
core/rag/datasource/vdb/relyt/relyt_vector.py
core/rag/datasource/vdb/tablestore/tablestore_vector.py
core/rag/datasource/vdb/tencent/tencent_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_on_qdrant_vector.py
core/rag/datasource/vdb/tidb_on_qdrant/tidb_service.py
core/rag/datasource/vdb/tidb_vector/tidb_vector.py
core/rag/datasource/vdb/upstash/upstash_vector.py
core/rag/datasource/vdb/vikingdb/vikingdb_vector.py
core/rag/datasource/vdb/weaviate/weaviate_vector.py
# no need to fix for now: extractors
core/rag/extractor/csv_extractor.py
core/rag/extractor/excel_extractor.py
core/rag/extractor/firecrawl/firecrawl_app.py
core/rag/extractor/firecrawl/firecrawl_web_extractor.py
core/rag/extractor/html_extractor.py
core/rag/extractor/jina_reader_extractor.py
core/rag/extractor/markdown_extractor.py
core/rag/extractor/notion_extractor.py
core/rag/extractor/text_extractor.py
core/rag/extractor/unstructured/unstructured_doc_extractor.py
core/rag/extractor/unstructured/unstructured_eml_extractor.py
core/rag/extractor/unstructured/unstructured_epub_extractor.py
core/rag/extractor/unstructured/unstructured_markdown_extractor.py
core/rag/extractor/unstructured/unstructured_msg_extractor.py
core/rag/extractor/unstructured/unstructured_ppt_extractor.py
core/rag/extractor/unstructured/unstructured_pptx_extractor.py
core/rag/extractor/unstructured/unstructured_xml_extractor.py
core/rag/extractor/watercrawl/client.py
core/rag/extractor/watercrawl/extractor.py
core/rag/extractor/watercrawl/provider.py
# no need to fix for now: index processors
core/rag/index_processor/processor/paragraph_index_processor.py
core/rag/index_processor/processor/parent_child_index_processor.py
core/rag/index_processor/processor/qa_index_processor.py

View File

@@ -1,24 +1,10 @@
from abc import ABC, abstractmethod
from typing_extensions import TypedDict
class ApiKeyAuthConfig(TypedDict, total=False):
api_key: str
base_url: str
class ApiKeyAuthCredentials(TypedDict):
auth_type: object
config: ApiKeyAuthConfig
class ApiKeyAuthBase(ABC):
credentials: ApiKeyAuthCredentials
def __init__(self, credentials: ApiKeyAuthCredentials) -> None:
def __init__(self, credentials: dict):
self.credentials = credentials
@abstractmethod
def validate_credentials(self) -> bool:
def validate_credentials(self):
raise NotImplementedError

View File

@@ -1,20 +1,18 @@
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.auth_type import AuthProvider, AuthType
from services.auth.api_key_auth_base import ApiKeyAuthBase
from services.auth.auth_type import AuthType
class ApiKeyAuthFactory:
auth: ApiKeyAuthBase
def __init__(self, provider: AuthProvider, credentials: ApiKeyAuthCredentials) -> None:
def __init__(self, provider: str, credentials: dict):
auth_factory = self.get_apikey_auth_factory(provider)
self.auth = auth_factory(credentials)
def validate_credentials(self) -> bool:
def validate_credentials(self):
return self.auth.validate_credentials()
@staticmethod
def get_apikey_auth_factory(provider: AuthProvider) -> type[ApiKeyAuthBase]:
match ApiKeyAuthFactory._normalize_provider(provider):
def get_apikey_auth_factory(provider: str) -> type[ApiKeyAuthBase]:
match provider:
case AuthType.FIRECRAWL:
from services.auth.firecrawl.firecrawl import FirecrawlAuth
@@ -29,13 +27,3 @@ class ApiKeyAuthFactory:
return JinaAuth
case _:
raise ValueError("Invalid provider")
@staticmethod
def _normalize_provider(provider: AuthProvider) -> AuthType | str:
if isinstance(provider, AuthType):
return provider
try:
return AuthType(provider)
except ValueError:
return provider

View File

@@ -1,63 +1,40 @@
import json
from typing import cast
from pydantic import TypeAdapter
from sqlalchemy import select
from typing_extensions import TypedDict
from core.helper import encrypter
from extensions.ext_database import db
from models.source import DataSourceApiKeyAuthBinding
from services.auth.api_key_auth_base import ApiKeyAuthCredentials
from services.auth.api_key_auth_factory import ApiKeyAuthFactory
class ApiKeyAuthCreateArgs(TypedDict):
category: str
provider: str
credentials: ApiKeyAuthCredentials
AUTH_CREATE_ARGS_ADAPTER = TypeAdapter(ApiKeyAuthCreateArgs)
AUTH_CREDENTIALS_ADAPTER = TypeAdapter(dict[str, object])
class ApiKeyAuthService:
@staticmethod
def get_provider_auth_list(tenant_id: str) -> list[DataSourceApiKeyAuthBinding]:
def get_provider_auth_list(tenant_id: str):
data_source_api_key_bindings = db.session.scalars(
select(DataSourceApiKeyAuthBinding).where(
DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.disabled.is_(False)
)
).all()
return list(data_source_api_key_bindings)
return data_source_api_key_bindings
@staticmethod
def create_provider_auth(tenant_id: str, args: dict[str, object]) -> None:
validated_args = ApiKeyAuthService.validate_api_key_auth_args(args)
raw_credentials = ApiKeyAuthService._get_credentials_dict(args)
auth_result = ApiKeyAuthFactory(
validated_args["provider"], validated_args["credentials"]
).validate_credentials()
def create_provider_auth(tenant_id: str, args: dict):
auth_result = ApiKeyAuthFactory(args["provider"], args["credentials"]).validate_credentials()
if auth_result:
api_key_value = validated_args["credentials"]["config"].get("api_key")
if api_key_value is None:
raise KeyError("api_key")
api_key = encrypter.encrypt_token(tenant_id, api_key_value)
raw_config = ApiKeyAuthService._get_config_dict(raw_credentials)
raw_config["api_key"] = api_key
# Encrypt the api key
api_key = encrypter.encrypt_token(tenant_id, args["credentials"]["config"]["api_key"])
args["credentials"]["config"]["api_key"] = api_key
data_source_api_key_binding = DataSourceApiKeyAuthBinding(
tenant_id=tenant_id,
category=validated_args["category"],
provider=validated_args["provider"],
tenant_id=tenant_id, category=args["category"], provider=args["provider"]
)
data_source_api_key_binding.credentials = json.dumps(raw_credentials, ensure_ascii=False)
data_source_api_key_binding.credentials = json.dumps(args["credentials"], ensure_ascii=False)
db.session.add(data_source_api_key_binding)
db.session.commit()
@staticmethod
def get_auth_credentials(tenant_id: str, category: str, provider: str) -> dict[str, object] | None:
def get_auth_credentials(tenant_id: str, category: str, provider: str):
data_source_api_key_bindings = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(
@@ -73,10 +50,10 @@ class ApiKeyAuthService:
if not data_source_api_key_bindings.credentials:
return None
credentials = json.loads(data_source_api_key_bindings.credentials)
return AUTH_CREDENTIALS_ADAPTER.validate_python(credentials)
return credentials
@staticmethod
def delete_provider_auth(tenant_id: str, binding_id: str) -> None:
def delete_provider_auth(tenant_id: str, binding_id: str):
data_source_api_key_binding = (
db.session.query(DataSourceApiKeyAuthBinding)
.where(DataSourceApiKeyAuthBinding.tenant_id == tenant_id, DataSourceApiKeyAuthBinding.id == binding_id)
@@ -86,10 +63,8 @@ class ApiKeyAuthService:
db.session.delete(data_source_api_key_binding)
db.session.commit()
@staticmethod
def validate_api_key_auth_args(args: dict[str, object] | None) -> ApiKeyAuthCreateArgs:
if args is None:
raise TypeError("argument of type 'NoneType' is not iterable")
@classmethod
def validate_api_key_auth_args(cls, args):
if "category" not in args or not args["category"]:
raise ValueError("category is required")
if "provider" not in args or not args["provider"]:
@@ -100,18 +75,3 @@ class ApiKeyAuthService:
raise ValueError("credentials must be a dictionary")
if "auth_type" not in args["credentials"] or not args["credentials"]["auth_type"]:
raise ValueError("auth_type is required")
return AUTH_CREATE_ARGS_ADAPTER.validate_python(args)
@staticmethod
def _get_credentials_dict(args: dict[str, object]) -> dict[str, object]:
credentials = args["credentials"]
if not isinstance(credentials, dict):
raise ValueError("credentials must be a dictionary")
return cast(dict[str, object], credentials)
@staticmethod
def _get_config_dict(credentials: dict[str, object]) -> dict[str, object]:
config = credentials["config"]
if not isinstance(config, dict):
raise TypeError(f"credentials['config'] must be a dictionary, got {type(config).__name__}")
return cast(dict[str, object], config)

View File

@@ -5,6 +5,3 @@ class AuthType(StrEnum):
FIRECRAWL = "firecrawl"
WATERCRAWL = "watercrawl"
JINA = "jinareader"
AuthProvider = AuthType | str

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.api_key_auth_base import ApiKeyAuthBase
class FirecrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: ApiKeyAuthCredentials):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: ApiKeyAuthCredentials):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -2,11 +2,11 @@ import json
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.api_key_auth_base import ApiKeyAuthBase
class JinaAuth(ApiKeyAuthBase):
def __init__(self, credentials: ApiKeyAuthCredentials):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "bearer":

View File

@@ -3,11 +3,11 @@ from urllib.parse import urljoin
import httpx
from services.auth.api_key_auth_base import ApiKeyAuthBase, ApiKeyAuthCredentials
from services.auth.api_key_auth_base import ApiKeyAuthBase
class WatercrawlAuth(ApiKeyAuthBase):
def __init__(self, credentials: ApiKeyAuthCredentials):
def __init__(self, credentials: dict):
super().__init__(credentials)
auth_type = credentials.get("auth_type")
if auth_type != "x-api-key":

View File

@@ -1013,7 +1013,7 @@ class TestAdvancedChatAppGeneratorInternals:
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.db", SimpleNamespace(engine=object()))
refreshed = _refresh_model(session=SimpleNamespace(), model=source_model)
refreshed = _refresh_model(session=None, model=source_model)
assert refreshed is detached_model

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import Any
from copy import deepcopy
from unittest.mock import MagicMock, patch
import pytest
@@ -33,8 +33,8 @@ def _make_graph_state():
],
)
def test_run_uses_single_node_execution_branch(
single_iteration_run: Any,
single_loop_run: Any,
single_iteration_run: WorkflowAppGenerateEntity.SingleIterationRunEntity | None,
single_loop_run: WorkflowAppGenerateEntity.SingleLoopRunEntity | None,
) -> None:
app_config = MagicMock()
app_config.app_id = "app"
@@ -130,10 +130,23 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
"break_conditions": [],
"logical_operator": "and",
},
},
{
"id": "other-node",
"data": {
"type": "answer",
"title": "Answer",
},
},
],
"edges": [
{
"source": "other-node",
"target": "loop-node",
}
],
"edges": [],
}
original_graph_dict = deepcopy(workflow.graph_dict)
_, _, graph_runtime_state = _make_graph_state()
seen_configs: list[object] = []
@@ -143,13 +156,19 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
seen_configs.append(value)
return original_validate_python(value)
class FakeNodeClass:
@staticmethod
def extract_variable_selector_to_variable_mapping(**_kwargs):
return {}
monkeypatch.setattr(NodeConfigDictAdapter, "validate_python", record_validate_python)
with (
patch("core.app.apps.workflow_app_runner.DifyNodeFactory"),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()),
patch("core.app.apps.workflow_app_runner.Graph.init", return_value=MagicMock()) as graph_init,
patch("core.app.apps.workflow_app_runner.load_into_variable_pool"),
patch("core.app.apps.workflow_app_runner.WorkflowEntry.mapping_user_inputs_to_variable_pool"),
patch("core.app.apps.workflow_app_runner.resolve_workflow_node_class", return_value=FakeNodeClass),
):
runner._get_graph_and_variable_pool_for_single_node_run(
workflow=workflow,
@@ -161,3 +180,8 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
)
assert seen_configs == [workflow.graph_dict["nodes"][0]]
assert workflow.graph_dict == original_graph_dict
graph_config = graph_init.call_args.kwargs["graph_config"]
assert graph_config is not workflow.graph_dict
assert graph_config["nodes"] == [workflow.graph_dict["nodes"][0]]
assert graph_config["edges"] == []

View File

@@ -200,29 +200,6 @@ class TestExtractProcessorFileRouting:
with pytest.raises(AssertionError, match="upload_file is required"):
ExtractProcessor.extract(setting)
@pytest.mark.parametrize(
("extension", "etl_type", "expected_extractor"),
[
(".pdf", "Unstructured", "PdfExtractor"),
(".docx", "Unstructured", "WordExtractor"),
(".pdf", "SelfHosted", "PdfExtractor"),
(".docx", "SelfHosted", "WordExtractor"),
],
)
def test_extract_allows_url_file_paths_without_upload_context(
self, monkeypatch, extension: str, etl_type: str, expected_extractor: str
):
factory = _patch_all_extractors(monkeypatch)
monkeypatch.setattr(processor_module.dify_config, "ETL_TYPE", etl_type)
setting = SimpleNamespace(datasource_type=DatasourceType.FILE, upload_file=None)
docs = ExtractProcessor.extract(setting, file_path=f"/tmp/example{extension}")
assert docs[0].page_content == f"extracted-by-{expected_extractor}"
assert factory.calls[-1][0] == expected_extractor
assert factory.calls[-1][1] == (f"/tmp/example{extension}", None, None)
class TestExtractProcessorDatasourceRouting:
def test_extract_routes_notion_datasource(self, monkeypatch):

View File

@@ -184,21 +184,3 @@ def test_extract_images_failures(mock_dependencies):
assert len(saves) == 1
assert saves[0][1] == jpeg_bytes
assert db_stub.session.committed is True
def test_extract_images_skips_persistence_without_upload_context(mock_dependencies):
mock_page = MagicMock()
mock_image_obj = MagicMock()
mock_image_obj.extract.side_effect = lambda buf, fb_format=None: buf.write(b"\xff\xd8\xff image")
mock_page.get_objects.return_value = [mock_image_obj]
extractor = pe.PdfExtractor(file_path="test.pdf", tenant_id=None, user_id=None)
with patch("pypdfium2.raw", autospec=True) as mock_raw:
mock_raw.FPDF_PAGEOBJ_IMAGE = 1
result = extractor._extract_images(mock_page)
assert result == ""
assert mock_dependencies.saves == []
assert mock_dependencies.db.session.added == []
assert mock_dependencies.db.session.committed is False

View File

@@ -179,27 +179,6 @@ def test_extract_images_from_docx(monkeypatch):
assert db_stub.session.committed is True
def test_extract_images_from_docx_skips_persistence_without_upload_context(monkeypatch):
saves: list[tuple[str, bytes]] = []
monkeypatch.setattr(we, "storage", SimpleNamespace(save=lambda key, data: saves.append((key, data))))
db_stub = SimpleNamespace(session=SimpleNamespace(add=lambda obj: None, commit=lambda: None))
monkeypatch.setattr(we, "db", db_stub)
rel_ext = SimpleNamespace(is_external=True, target_ref="https://example.com/image.png")
doc = SimpleNamespace(part=SimpleNamespace(rels={"rId1": rel_ext}))
extractor = object.__new__(WordExtractor)
extractor.tenant_id = None
extractor.user_id = None
image_map = extractor._extract_images_from_docx(doc)
assert image_map == {}
assert saves == []
def test_extract_images_from_docx_uses_internal_files_url():
"""Test that INTERNAL_FILES_URL takes precedence over FILES_URL for plugin access."""
# Test the URL generation logic directly

View File

@@ -1,6 +1,9 @@
import pytest
from dify_graph.entities.graph_config import NodeConfigDictAdapter
from dify_graph.nodes.loop.entities import LoopNodeData
from dify_graph.nodes.loop.entities import LoopNodeData, LoopValue
from dify_graph.nodes.loop.loop_node import LoopNode
from dify_graph.variables.types import SegmentType
def test_extract_variable_selector_to_variable_mapping_validates_child_node_configs(monkeypatch) -> None:
@@ -50,3 +53,21 @@ def test_extract_variable_selector_to_variable_mapping_validates_child_node_conf
)
assert seen_configs == [child_node_config]
@pytest.mark.parametrize(
("var_type", "original_value", "expected_value"),
[
(SegmentType.ARRAY_STRING, ["alpha", "beta"], ["alpha", "beta"]),
(SegmentType.ARRAY_NUMBER, [1, 2.5], [1, 2.5]),
(SegmentType.ARRAY_OBJECT, [{"name": "item"}], [{"name": "item"}]),
(SegmentType.ARRAY_STRING, '["legacy", "json"]', ["legacy", "json"]),
],
)
def test_get_segment_for_constant_accepts_native_array_values(
var_type: SegmentType, original_value: LoopValue, expected_value: LoopValue
) -> None:
segment = LoopNode._get_segment_for_constant(var_type, original_value)
assert segment.value_type == var_type
assert segment.value == expected_value

View File

@@ -13,11 +13,8 @@ class TestApiKeyAuthFactory:
("provider", "auth_class_path"),
[
(AuthType.FIRECRAWL, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.FIRECRAWL.value, "services.auth.firecrawl.firecrawl.FirecrawlAuth"),
(AuthType.WATERCRAWL, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.WATERCRAWL.value, "services.auth.watercrawl.watercrawl.WatercrawlAuth"),
(AuthType.JINA, "services.auth.jina.jina.JinaAuth"),
(AuthType.JINA.value, "services.auth.jina.jina.JinaAuth"),
],
)
def test_get_apikey_auth_factory_valid_providers(self, provider, auth_class_path):

View File

@@ -1,5 +1,4 @@
import json
from copy import deepcopy
from unittest.mock import Mock, patch
import pytest
@@ -69,16 +68,7 @@ class TestApiKeyAuthService:
# Mock successful auth validation
mock_auth_instance = Mock()
mock_auth_instance.validate_credentials.return_value = True
captured_provider = None
captured_credentials = None
def factory_side_effect(provider, credentials):
nonlocal captured_provider, captured_credentials
captured_provider = provider
captured_credentials = deepcopy(credentials)
return mock_auth_instance
mock_factory.side_effect = factory_side_effect
mock_factory.return_value = mock_auth_instance
# Mock encryption
encrypted_key = "encrypted_test_key_123"
@@ -87,14 +77,11 @@ class TestApiKeyAuthService:
# Mock database operations
mock_session.add = Mock()
mock_session.commit = Mock()
expected_credentials = deepcopy(self.mock_credentials)
ApiKeyAuthService.create_provider_auth(self.tenant_id, self.mock_args)
# Verify factory class calls
assert mock_factory.call_count == 1
assert captured_provider == self.provider
assert captured_credentials == expected_credentials
mock_factory.assert_called_once_with(self.provider, self.mock_credentials)
mock_auth_instance.validate_credentials.assert_called_once()
# Verify encryption calls

View File

@@ -4,13 +4,11 @@ from __future__ import annotations
from dataclasses import dataclass
from typing import Any
from unittest.mock import MagicMock, Mock, patch
from unittest.mock import MagicMock
import pytest
import services.vector_service as vector_service_module
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import AttachmentDocument, ChildDocument, Document
from services.vector_service import VectorService
@@ -704,105 +702,3 @@ def test_update_multimodel_vector_rolls_back_and_reraises_on_error(monkeypatch:
logger_mock.exception.assert_called_once()
db_mock.session.rollback.assert_called_once()
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
def test_vector_create_normalizes_child_documents(mock_get_embeddings: Mock, mock_init_vector: Mock) -> None:
dataset = _make_dataset()
documents = [ChildDocument(page_content="Child content", metadata={"doc_id": "child-1", "dataset_id": "dataset-1"})]
mock_embeddings = Mock()
mock_embeddings.embed_documents.return_value = [[0.1] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create(texts=documents)
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.page_content == "Child content"
assert normalized_document.metadata["doc_id"] == "child-1"
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
@patch("core.rag.datasource.vdb.vector_factory.storage")
@patch("core.rag.datasource.vdb.vector_factory.db.session")
def test_vector_create_multimodal_normalizes_attachment_documents(
mock_session: Mock,
mock_storage: Mock,
mock_get_embeddings: Mock,
mock_init_vector: Mock,
) -> None:
dataset = _make_dataset()
file_document = AttachmentDocument(
page_content="Attachment content",
provider="custom-provider",
metadata={"doc_id": "file-1", "doc_type": "image/png"},
)
upload_file = Mock(id="file-1", key="upload-key")
mock_scalars = Mock()
mock_scalars.all.return_value = [upload_file]
mock_session.scalars.return_value = mock_scalars
mock_storage.load_once.return_value = b"binary-content"
mock_embeddings = Mock()
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create_multimodal(file_documents=[file_document])
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.provider == "custom-provider"
assert normalized_document.metadata["doc_id"] == "file-1"
@patch("core.rag.datasource.vdb.vector_factory.Vector._init_vector")
@patch("core.rag.datasource.vdb.vector_factory.Vector._get_embeddings")
@patch("core.rag.datasource.vdb.vector_factory.storage")
@patch("core.rag.datasource.vdb.vector_factory.db.session")
def test_vector_create_multimodal_falls_back_to_dify_provider_when_attachment_provider_is_none(
mock_session: Mock,
mock_storage: Mock,
mock_get_embeddings: Mock,
mock_init_vector: Mock,
) -> None:
dataset = _make_dataset()
file_document = AttachmentDocument(
page_content="Attachment content",
provider=None,
metadata={"doc_id": "file-1", "doc_type": "image/png"},
)
upload_file = Mock(id="file-1", key="upload-key")
mock_scalars = Mock()
mock_scalars.all.return_value = [upload_file]
mock_session.scalars.return_value = mock_scalars
mock_storage.load_once.return_value = b"binary-content"
mock_embeddings = Mock()
mock_embeddings.embed_multimodal_documents.return_value = [[0.2] * 1536]
mock_get_embeddings.return_value = mock_embeddings
mock_vector_processor = Mock()
mock_init_vector.return_value = mock_vector_processor
vector = Vector(dataset=dataset)
vector.create_multimodal(file_documents=[file_document])
normalized_document = mock_vector_processor.create.call_args.kwargs["texts"][0]
assert isinstance(normalized_document, Document)
assert normalized_document.provider == "dify"