Mirror of https://github.com/langgenius/dify.git (synced 2026-02-05 15:43:59 +00:00)

Compare commits: 93 commits (fix/api-to … zhsama/ass)
Commit SHA1s:
8d643e4b85, 4ee49552ce, 40caaaab23, 1bc1c04be5, 18abc66585, e85e31773a, e5336a2d75, 7222a896d8,
b5712bf8b0, 7bc2e33e83, a7826d9ea4, 72eb29c01b, 0f3156dfbe, b21875eaaf, 2591615a3c, 691554ad1c,
f43fde5797, f247ebfbe1, d641c845dd, 2e10d67610, e89d4e14ea, 5525f63032, 8ee643e88d, ccb337e8eb,
1ff677c300, 04145b19a1, 56e537786f, 810f9eaaad, 4828348532, c8c048c3a3, 495d575ebc, b9052bc244,
b7025ad9d6, c5482c2503, d394adfaf7, bc771d9c50, 96ec176b83, f57d2ef31f, e80bc78780, ddbbddbd14,
9b961fb41e, 4f79d09d7b, dbed937fc6, 969c96b070, 03e0c4c617, 47790b49d4, b25b069917, bb190f9610,
d65ae68668, f625350439, f4e8f64bf7, d91087492d, cab7cd37b8, f925266c1b, 6e2cf23a73, 8b0bc6937d,
872fd98eda, 5bcd3b6fe6, 1aed585a19, 831eba8b1c, 8b8e521c4e, 88248ad2d3, 760a739e91, d92c476388,
9012dced6a, 50bed78d7a, 60250355cb, 75afc2dc0e, 225b13da93, 37c748192d, b7a2957340, a6ce6a249b,
8834e6e531, 39010fd153, bd338a9043, 39d6383474, add8980790, 5157e1a96c, 4bb76acc37, b513933040,
18ea9d3f18, 7b660a9ebc, 783a49bd97, d3c6b09354, 3d61496d25, 16bff9e82f, 22f25731e8, 035f51ad58,
e9795bd772, 93b516a4ec, fc9d5b2a62, e3bfb95c52, 752cb9e4f4
.gitignore (vendored), 1 line added

@@ -209,6 +209,7 @@ api/.vscode
 .history
 .idea/
+web/migration/

 # pnpm
 /.pnpm-store
@@ -55,6 +55,35 @@ class InstructionTemplatePayload(BaseModel):
    type: str = Field(..., description="Instruction template type")


class ContextGeneratePayload(BaseModel):
    """Payload for generating extractor code node."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name to generate code for")
    language: str = Field(default="python3", description="Code language (python3/javascript)")
    prompt_messages: list[dict[str, Any]] = Field(
        ..., description="Multi-turn conversation history, last message is the current instruction"
    )
    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")


class SuggestedQuestionsPayload(BaseModel):
    """Payload for generating suggested questions."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name")
    language: str = Field(
        default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
    )
    model_config_data: dict[str, Any] | None = Field(
        default=None,
        alias="model_config",
        description="Model configuration (optional, uses system default if not provided)",
    )


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -64,6 +93,8 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)


@console_ns.route("/rule-generate")

@@ -278,3 +309,74 @@ class InstructionGenerationTemplateApi(Resource):
                return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
            case _:
                raise ValueError(f"Invalid type: {args.type}")


@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
    @console_ns.doc("generate_with_context")
    @console_ns.doc(description="Generate with multi-turn conversation context")
    @console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
    @console_ns.response(200, "Content generated successfully")
    @console_ns.response(400, "Invalid request parameters or workflow not found")
    @console_ns.response(402, "Provider quota exceeded")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        from core.llm_generator.utils import deserialize_prompt_messages

        args = ContextGeneratePayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        prompt_messages = deserialize_prompt_messages(args.prompt_messages)

        try:
            return LLMGenerator.generate_with_context(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                prompt_messages=prompt_messages,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)


@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
    @console_ns.doc("generate_suggested_questions")
    @console_ns.doc(description="Generate suggested questions for context generation")
    @console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
    @console_ns.response(200, "Questions generated successfully")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        try:
            return LLMGenerator.generate_suggested_questions(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
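For orientation, a rough sketch of a request body the new `/context-generate` route would accept, based on the `ContextGeneratePayload` fields above; every ID and the model block are illustrative placeholders, not values taken from this diff.

```python
# Hypothetical body for a POST to the /context-generate console route.
# All IDs, the provider, and the model name below are placeholders.
payload = {
    "workflow_id": "b2f1c0de-0000-0000-0000-000000000000",  # Workflow ID
    "node_id": "tool_node_1",                                # current tool/llm node
    "parameter_name": "query",                               # parameter to generate extractor code for
    "language": "python3",                                   # or "javascript"
    "prompt_messages": [
        {"role": "user", "content": "Extract the city name from the HTTP response body."},
    ],
    "model_config": {                                        # mapped to model_config_data via the alias
        "provider": "openai",
        "name": "gpt-4o",
        "completion_params": {"temperature": 0.2},
    },
}
```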
@@ -17,7 +17,7 @@ from controllers.console.wraps import account_initialization_required, edit_perm
 from controllers.web.error import InvalidArgumentError, NotFoundError
 from core.file import helpers as file_helpers
 from core.variables.segment_group import SegmentGroup
-from core.variables.segments import ArrayFileSegment, FileSegment, Segment
+from core.variables.segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
 from core.variables.types import SegmentType
 from core.workflow.constants import CONVERSATION_VARIABLE_NODE_ID, SYSTEM_VARIABLE_NODE_ID
 from extensions.ext_database import db

@@ -58,6 +58,8 @@ def _convert_values_to_json_serializable_object(value: Segment):
         return value.value.model_dump()
     elif isinstance(value, ArrayFileSegment):
         return [i.model_dump() for i in value.value]
+    elif isinstance(value, ArrayPromptMessageSegment):
+        return value.to_object()
     elif isinstance(value, SegmentGroup):
         return [_convert_values_to_json_serializable_object(i) for i in value.value]
     else:
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk

     @classmethod
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                 response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

             yield response_chunk
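The converter changes in this and the following files all swap `model_dump(mode="json")` for `model_dump(mode="json", exclude_none=True)`, so fields that are `None` (for example the new `mention_parent_id`) are dropped from streamed chunks instead of being serialized as `null`. A minimal Pydantic sketch of the difference; the model here is illustrative, not one from this diff:

```python
from pydantic import BaseModel


class NodeData(BaseModel):
    node_id: str
    mention_parent_id: str | None = None


data = NodeData(node_id="n1")
print(data.model_dump(mode="json"))                     # {'node_id': 'n1', 'mention_parent_id': None}
print(data.model_dump(mode="json", exclude_none=True))  # {'node_id': 'n1'}
```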
@@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk

     @classmethod
@@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

             yield response_chunk
@@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk

     @classmethod
@@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                 metadata = sub_stream_response_dict.get("metadata", {})
                 sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                 response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

             yield response_chunk
@@ -70,6 +70,8 @@ class _NodeSnapshot:
     """Empty string means the node is not executing inside an iteration."""
     loop_id: str = ""
     """Empty string means the node is not executing inside a loop."""
+    mention_parent_id: str = ""
+    """Empty string means the node is not an extractor node."""


 class WorkflowResponseConverter:
@@ -131,6 +133,7 @@ class WorkflowResponseConverter:
             start_at=event.start_at,
             iteration_id=event.in_iteration_id or "",
             loop_id=event.in_loop_id or "",
+            mention_parent_id=event.in_mention_parent_id or "",
         )
         node_execution_id = NodeExecutionId(event.node_execution_id)
         self._node_snapshots[node_execution_id] = snapshot
@@ -287,6 +290,7 @@ class WorkflowResponseConverter:
                 created_at=int(snapshot.start_at.timestamp()),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
+                mention_parent_id=event.in_mention_parent_id,
                 agent_strategy=event.agent_strategy,
             ),
         )
@@ -373,6 +377,7 @@ class WorkflowResponseConverter:
                 files=self.fetch_files_from_node_outputs(event.outputs or {}),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
+                mention_parent_id=event.in_mention_parent_id,
             ),
         )

@@ -422,6 +427,7 @@ class WorkflowResponseConverter:
                 files=self.fetch_files_from_node_outputs(event.outputs or {}),
                 iteration_id=event.in_iteration_id,
                 loop_id=event.in_loop_id,
+                mention_parent_id=event.in_mention_parent_id,
                 retry_index=event.retry_index,
             ),
         )
@@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk

     @classmethod
@@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
             }

             if isinstance(sub_stream_response, MessageEndStreamResponse):
-                sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+                sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                 metadata = sub_stream_response_dict.get("metadata", {})
                 if not isinstance(metadata, dict):
                     metadata = {}
@@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

             yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(cast(dict, data))
             else:
-                response_chunk.update(sub_stream_response.model_dump())
+                response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
             yield response_chunk

     @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                 response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
             else:
-                response_chunk.update(sub_stream_response.model_dump())
+                response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
             yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                 data = cls._error_to_stream_response(sub_stream_response.err)
                 response_chunk.update(data)
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk

     @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
             elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                 response_chunk.update(sub_stream_response.to_ignore_detail_dict())
             else:
-                response_chunk.update(sub_stream_response.model_dump(mode="json"))
+                response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
             yield response_chunk
@@ -385,6 +385,7 @@ class WorkflowBasedAppRunner:
                     start_at=event.start_at,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                     inputs=inputs,
                     process_data=process_data,
                     outputs=outputs,
@@ -405,6 +406,7 @@ class WorkflowBasedAppRunner:
                     start_at=event.start_at,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                     agent_strategy=event.agent_strategy,
                     provider_type=event.provider_type,
                     provider_id=event.provider_id,
@@ -428,6 +430,7 @@ class WorkflowBasedAppRunner:
                     execution_metadata=execution_metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                 )
             )
         elif isinstance(event, NodeRunFailedEvent):
@@ -444,6 +447,7 @@ class WorkflowBasedAppRunner:
                     execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                 )
             )
         elif isinstance(event, NodeRunExceptionEvent):
@@ -460,6 +464,7 @@ class WorkflowBasedAppRunner:
                     execution_metadata=event.node_run_result.metadata,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                 )
             )
         elif isinstance(event, NodeRunStreamChunkEvent):
@@ -469,6 +474,7 @@ class WorkflowBasedAppRunner:
                     from_variable_selector=list(event.selector),
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                 )
             )
         elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -477,6 +483,7 @@ class WorkflowBasedAppRunner:
                     retriever_resources=event.retriever_resources,
                     in_iteration_id=event.in_iteration_id,
                     in_loop_id=event.in_loop_id,
+                    in_mention_parent_id=event.in_mention_parent_id,
                 )
             )
         elif isinstance(event, NodeRunAgentLogEvent):
@@ -190,6 +190,8 @@ class QueueTextChunkEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""


 class QueueAgentMessageEvent(AppQueueEvent):
@@ -229,6 +231,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""


 class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -306,6 +310,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
     node_run_index: int = 1  # FIXME(-LAN-): may not used
     in_iteration_id: str | None = None
     in_loop_id: str | None = None
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""
     start_at: datetime
     agent_strategy: AgentNodeStrategyInit | None = None

@@ -328,6 +334,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""
     start_at: datetime

     inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -383,6 +391,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""
     start_at: datetime

     inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -407,6 +417,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
     """iteration id if node is in iteration"""
     in_loop_id: str | None = None
     """loop id if node is in loop"""
+    in_mention_parent_id: str | None = None
+    """parent node id if this is an extractor node event"""
     start_at: datetime

     inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -262,6 +262,7 @@ class NodeStartStreamResponse(StreamResponse):
         extras: dict[str, object] = Field(default_factory=dict)
         iteration_id: str | None = None
         loop_id: str | None = None
+        mention_parent_id: str | None = None
         agent_strategy: AgentNodeStrategyInit | None = None

     event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,6 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
                 "extras": {},
                 "iteration_id": self.data.iteration_id,
                 "loop_id": self.data.loop_id,
+                "mention_parent_id": self.data.mention_parent_id,
             },
         }

@@ -320,6 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
         files: Sequence[Mapping[str, Any]] | None = []
         iteration_id: str | None = None
         loop_id: str | None = None
+        mention_parent_id: str | None = None

     event: StreamEvent = StreamEvent.NODE_FINISHED
     workflow_run_id: str
@@ -349,6 +352,7 @@ class NodeFinishStreamResponse(StreamResponse):
                 "files": [],
                 "iteration_id": self.data.iteration_id,
                 "loop_id": self.data.loop_id,
+                "mention_parent_id": self.data.mention_parent_id,
             },
         }

@@ -384,6 +388,7 @@ class NodeRetryStreamResponse(StreamResponse):
         files: Sequence[Mapping[str, Any]] | None = []
         iteration_id: str | None = None
         loop_id: str | None = None
+        mention_parent_id: str | None = None
         retry_index: int = 0

     event: StreamEvent = StreamEvent.NODE_RETRY
@@ -414,6 +419,7 @@ class NodeRetryStreamResponse(StreamResponse):
                 "files": [],
                 "iteration_id": self.data.iteration_id,
                 "loop_id": self.data.loop_id,
+                "mention_parent_id": self.data.mention_parent_id,
                 "retry_index": self.data.retry_index,
             },
         }
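Taken together with the converter changes above, the ignore-detail payloads for node events now carry `mention_parent_id` next to `iteration_id` and `loop_id`. A rough sketch of the affected keys in a streamed `node_started` chunk; only the key names mirror `to_ignore_detail_dict()` above, the values are placeholders:

```python
# Hypothetical excerpt of the "data" object in a node_started stream event.
node_started_data = {
    "extras": {},
    "iteration_id": None,          # node is not inside an iteration
    "loop_id": None,               # node is not inside a loop
    "mention_parent_id": "llm_1",  # parent node id when this is an extractor node event
}
```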
@@ -1,4 +1,5 @@
 import base64
+import logging
 from collections.abc import Mapping

 from configs import dify_config
@@ -10,7 +11,10 @@ from core.model_runtime.entities import (
     TextPromptMessageContent,
     VideoPromptMessageContent,
 )
-from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
+from core.model_runtime.entities.message_entities import (
+    MultiModalPromptMessageContent,
+    PromptMessageContentUnionTypes,
+)
 from core.tools.signature import sign_tool_file
 from extensions.ext_storage import storage

@@ -18,6 +22,8 @@ from . import helpers
 from .enums import FileAttribute
 from .models import File, FileTransferMethod, FileType

+logger = logging.getLogger(__name__)
+

 def get_attr(*, file: File, attr: FileAttribute):
     match attr:
@@ -89,6 +95,8 @@ def to_prompt_message_content(
         "format": f.extension.removeprefix("."),
         "mime_type": f.mime_type,
         "filename": f.filename or "",
+        # Encoded file reference for context restoration: "transfer_method:related_id" or "remote:url"
+        "file_ref": _encode_file_ref(f),
     }
     if f.type == FileType.IMAGE:
         params["detail"] = image_detail_config or ImagePromptMessageContent.DETAIL.LOW

@@ -96,6 +104,17 @@ def to_prompt_message_content(
    return prompt_class_map[f.type].model_validate(params)


def _encode_file_ref(f: File) -> str | None:
    """Encode file reference as 'transfer_method:id_or_url' string."""
    if f.transfer_method == FileTransferMethod.REMOTE_URL:
        return f"remote:{f.remote_url}" if f.remote_url else None
    elif f.transfer_method == FileTransferMethod.LOCAL_FILE:
        return f"local:{f.related_id}" if f.related_id else None
    elif f.transfer_method == FileTransferMethod.TOOL_FILE:
        return f"tool:{f.related_id}" if f.related_id else None
    return None


def download(f: File, /):
    if f.transfer_method in (
        FileTransferMethod.TOOL_FILE,

@@ -164,3 +183,128 @@ def _to_url(f: File, /):
        return sign_tool_file(tool_file_id=f.related_id, extension=f.extension)
    else:
        raise ValueError(f"Unsupported transfer method: {f.transfer_method}")


def restore_multimodal_content(
    content: MultiModalPromptMessageContent,
) -> MultiModalPromptMessageContent:
    """
    Restore base64_data or url for multimodal content from file_ref.

    file_ref format: "transfer_method:id_or_url" (e.g., "local:abc123", "remote:https://...")

    Args:
        content: MultiModalPromptMessageContent with file_ref field

    Returns:
        MultiModalPromptMessageContent with restored base64_data or url
    """
    # Skip if no file reference or content already has data
    if not content.file_ref:
        return content
    if content.base64_data or content.url:
        return content

    try:
        file = _build_file_from_ref(
            file_ref=content.file_ref,
            file_format=content.format,
            mime_type=content.mime_type,
            filename=content.filename,
        )
        if not file:
            return content

        # Restore content based on config
        if dify_config.MULTIMODAL_SEND_FORMAT == "base64":
            restored_base64 = _get_encoded_string(file)
            return content.model_copy(update={"base64_data": restored_base64})
        else:
            restored_url = _to_url(file)
            return content.model_copy(update={"url": restored_url})

    except Exception as e:
        logger.warning("Failed to restore multimodal content: %s", e)
        return content


def _build_file_from_ref(
    file_ref: str,
    file_format: str | None,
    mime_type: str | None,
    filename: str | None,
) -> File | None:
    """
    Build a File object from encoded file_ref string.

    Args:
        file_ref: Encoded reference "transfer_method:id_or_url"
        file_format: The file format/extension (without dot)
        mime_type: The mime type
        filename: The filename

    Returns:
        File object with storage_key loaded, or None if not found
    """
    from sqlalchemy import select
    from sqlalchemy.orm import Session

    from extensions.ext_database import db
    from models.model import UploadFile
    from models.tools import ToolFile

    # Parse file_ref: "method:value"
    if ":" not in file_ref:
        logger.warning("Invalid file_ref format: %s", file_ref)
        return None

    method, value = file_ref.split(":", 1)
    extension = f".{file_format}" if file_format else None

    if method == "remote":
        return File(
            tenant_id="",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url=value,
            extension=extension,
            mime_type=mime_type,
            filename=filename,
            storage_key="",
        )

    # Query database for storage_key
    with Session(db.engine) as session:
        if method == "local":
            stmt = select(UploadFile).where(UploadFile.id == value)
            upload_file = session.scalar(stmt)
            if upload_file:
                return File(
                    tenant_id=upload_file.tenant_id,
                    type=FileType(upload_file.extension)
                    if hasattr(FileType, upload_file.extension.upper())
                    else FileType.IMAGE,
                    transfer_method=FileTransferMethod.LOCAL_FILE,
                    related_id=value,
                    extension=extension or ("." + upload_file.extension if upload_file.extension else None),
                    mime_type=mime_type or upload_file.mime_type,
                    filename=filename or upload_file.name,
                    storage_key=upload_file.key,
                )
        elif method == "tool":
            stmt = select(ToolFile).where(ToolFile.id == value)
            tool_file = session.scalar(stmt)
            if tool_file:
                return File(
                    tenant_id=tool_file.tenant_id,
                    type=FileType.IMAGE,
                    transfer_method=FileTransferMethod.TOOL_FILE,
                    related_id=value,
                    extension=extension,
                    mime_type=mime_type or tool_file.mimetype,
                    filename=filename or tool_file.name,
                    storage_key=tool_file.file_key,
                )

    logger.warning("File not found for file_ref: %s", file_ref)
    return None
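The `file_ref` strings produced by `_encode_file_ref` and parsed back by `_build_file_from_ref` follow the `"transfer_method:id_or_url"` convention shown above. A small sketch with made-up identifiers:

```python
# Made-up refs in the "transfer_method:id_or_url" format used by _encode_file_ref.
refs = [
    "remote:https://example.com/cat.png",           # REMOTE_URL -> the remote URL itself
    "local:9f8e7d6c-0000-0000-0000-000000000000",   # LOCAL_FILE -> UploadFile.id
    "tool:1a2b3c4d-0000-0000-0000-000000000000",    # TOOL_FILE  -> ToolFile.id
]

for ref in refs:
    # _build_file_from_ref splits the same way: once, on the first colon.
    method, value = ref.split(":", 1)
    print(method, value)
```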
@@ -1,8 +1,8 @@
 import json
 import logging
 import re
-from collections.abc import Sequence
-from typing import Protocol, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Protocol, cast

 import json_repair
@@ -398,6 +398,488 @@ class LLMGenerator:
            logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    @classmethod
    def generate_with_context(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        prompt_messages: list[PromptMessage],
        model_config: dict,
    ) -> dict:
        """
        Generate extractor code node based on conversation context.

        Args:
            tenant_id: Tenant/workspace ID
            workflow_id: Workflow ID
            node_id: Current tool/llm node ID
            parameter_name: Parameter name to generate code for
            language: Code language (python3/javascript)
            prompt_messages: Multi-turn conversation history (last message is instruction)
            model_config: Model configuration (provider, name, completion_params)

        Returns:
            dict with CodeNodeData format:
            - variables: Input variable selectors
            - code_language: Code language
            - code: Generated code
            - outputs: Output definitions
            - message: Explanation
            - error: Error message if any
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from services.workflow_service import WorkflowService

        # Get workflow
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return cls._error_response(f"App {workflow_id} not found")

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return cls._error_response(f"Workflow for app {workflow_id} not found")

        # Get upstream nodes via edge backtracking
        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)

        # Get current node info
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return cls._error_response(f"Node {node_id} not found")

        # Get parameter info
        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build system prompt
        system_prompt = cls._build_extractor_system_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        # Construct complete prompt_messages with system prompt
        complete_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
            *prompt_messages,
        ]

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output

        # Get model instance and schema
        provider = model_config.get("provider", "")
        model_name = model_config.get("name", "")
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=provider,
            model=model_name,
        )

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return cls._error_response(f"Model schema not found for {model_name}")

        model_parameters = model_config.get("completion_params", {})
        json_schema = cls._get_code_node_json_schema()

        try:
            response = invoke_llm_with_structured_output(
                provider=provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=complete_messages,
                json_schema=json_schema,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            return cls._parse_code_node_output(
                response.structured_output, language, parameter_info.get("type", "string")
            )

        except InvokeError as e:
            return cls._error_response(str(e))
        except Exception as e:
            logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
            return cls._error_response(f"An unexpected error occurred: {str(e)}")
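On success the method returns a plain dict in the CodeNodeData shape described in its docstring and assembled by `_parse_code_node_output` further down. A sketch with placeholder values:

```python
# Hypothetical return value of LLMGenerator.generate_with_context; the shape comes
# from the docstring above, the code and selectors are placeholders.
result = {
    "variables": [{"variable": "body", "value_selector": ["http_request_1", "body"]}],
    "code_language": "python3",
    "code": "def main(body: str) -> dict:\n    return {\"result\": body.strip()}\n",
    "outputs": {"result": {"type": "string"}},
    "message": "Trims whitespace from the upstream HTTP body.",
    "error": "",
}
```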
    @classmethod
    def _error_response(cls, error: str) -> dict:
        """Return error response in CodeNodeData format."""
        return {
            "variables": [],
            "code_language": "python3",
            "code": "",
            "outputs": {},
            "message": "",
            "error": error,
        }

    @classmethod
    def generate_suggested_questions(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        model_config: dict | None = None,
    ) -> dict:
        """
        Generate suggested questions for context generation.

        Returns dict with questions array and error field.
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
        from services.workflow_service import WorkflowService

        # Get workflow context (reuse existing logic)
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return {"questions": [], "error": f"App {workflow_id} not found"}

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}

        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return {"questions": [], "error": f"Node {node_id} not found"}

        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build prompt
        system_prompt = cls._build_suggested_questions_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        prompt_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
        ]

        # Get model instance - use default if model_config not provided
        model_manager = ModelManager()
        if model_config:
            provider = model_config.get("provider", "")
            model_name = model_config.get("name", "")
            model_instance = model_manager.get_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
                provider=provider,
                model=model_name,
            )
        else:
            model_instance = model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
            )
            model_name = model_instance.model

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return {"questions": [], "error": f"Model schema not found for {model_name}"}

        completion_params = model_config.get("completion_params", {}) if model_config else {}
        model_parameters = {**completion_params, "max_tokens": 256}
        json_schema = cls._get_suggested_questions_json_schema()

        try:
            response = invoke_llm_with_structured_output(
                provider=model_instance.provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=json_schema,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            questions = response.structured_output.get("questions", []) if response.structured_output else []
            return {"questions": questions, "error": ""}

        except InvokeError as e:
            return {"questions": [], "error": str(e)}
        except Exception as e:
            logger.exception("Failed to generate suggested questions, model: %s", model_name)
            return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}
    @classmethod
    def _build_suggested_questions_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str = "English",
    ) -> str:
        """Build minimal prompt for suggested questions generation."""
        # Simplify upstream nodes to reduce tokens
        sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
        param_type = parameter_info.get("type", "string")
        param_desc = parameter_info.get("description", "")[:100]

        return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""

    @classmethod
    def _get_suggested_questions_json_schema(cls) -> dict:
        """Return JSON Schema for suggested questions."""
        return {
            "type": "object",
            "properties": {
                "questions": {
                    "type": "array",
                    "items": {"type": "string"},
                    "minItems": 3,
                    "maxItems": 3,
                    "description": "3 suggested questions",
                },
            },
            "required": ["questions"],
        }

    @classmethod
    def _get_code_node_json_schema(cls) -> dict:
        """Return JSON Schema for structured output."""
        return {
            "type": "object",
            "properties": {
                "variables": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "variable": {"type": "string", "description": "Variable name in code"},
                            "value_selector": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Path like [node_id, output_name]",
                            },
                        },
                        "required": ["variable", "value_selector"],
                    },
                },
                "code": {"type": "string", "description": "Generated code with main function"},
                "outputs": {
                    "type": "object",
                    "additionalProperties": {
                        "type": "object",
                        "properties": {"type": {"type": "string"}},
                    },
                    "description": "Output definitions, key is output name",
                },
                "explanation": {"type": "string", "description": "Brief explanation of the code"},
            },
            "required": ["variables", "code", "outputs", "explanation"],
        }
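Both helpers above return plain JSON Schema, so their constraints can be checked locally. A minimal sketch using the third-party `jsonschema` package, which is an assumption here and not part of this diff:

```python
import jsonschema  # assumed to be installed; not used by the code in this diff

# The suggested-questions schema requires exactly three string questions.
schema = {
    "type": "object",
    "properties": {
        "questions": {
            "type": "array",
            "items": {"type": "string"},
            "minItems": 3,
            "maxItems": 3,
        },
    },
    "required": ["questions"],
}

jsonschema.validate({"questions": ["Q1", "Q2", "Q3"]}, schema)   # passes
# jsonschema.validate({"questions": ["only one"]}, schema)       # would raise ValidationError
```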
    @classmethod
    def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
        """
        Get all upstream nodes via edge backtracking.

        Traverses the graph backwards from node_id to collect all reachable nodes.
        """
        from collections import defaultdict

        nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
        edges = graph_dict.get("edges", [])

        # Build reverse adjacency list
        reverse_adj: dict[str, list[str]] = defaultdict(list)
        for edge in edges:
            reverse_adj[edge["target"]].append(edge["source"])

        # BFS to find all upstream nodes
        visited: set[str] = set()
        queue = [node_id]
        upstream: list[dict] = []

        while queue:
            current = queue.pop(0)
            for source in reverse_adj.get(current, []):
                if source not in visited:
                    visited.add(source)
                    queue.append(source)
                    if source in nodes:
                        upstream.append(cls._extract_node_info(nodes[source]))

        return upstream
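A tiny worked example of the reverse-adjacency walk above, on a made-up three-node graph:

```python
# Made-up graph: start -> llm -> tool. Asking for the nodes upstream of "tool".
graph_dict = {
    "nodes": [
        {"id": "start", "data": {"type": "start", "title": "Start", "variables": []}},
        {"id": "llm", "data": {"type": "llm", "title": "LLM"}},
        {"id": "tool", "data": {"type": "tool", "title": "Search"}},
    ],
    "edges": [
        {"source": "start", "target": "llm"},
        {"source": "llm", "target": "tool"},
    ],
}

# _get_upstream_nodes(graph_dict, "tool") walks the edges backwards (tool -> llm -> start)
# and returns the _extract_node_info summaries for "llm" first, then "start".
```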
    @classmethod
    def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
        """Get node by ID from graph."""
        for node in graph_dict.get("nodes", []):
            if node["id"] == node_id:
                return node
        return None

    @classmethod
    def _extract_node_info(cls, node: dict) -> dict:
        """Extract minimal node info with outputs based on node type."""
        node_type = node["data"]["type"]
        node_data = node.get("data", {})

        # Build outputs based on node type (only type, no description to reduce tokens)
        outputs: dict[str, str] = {}
        match node_type:
            case "start":
                for var in node_data.get("variables", []):
                    name = var.get("variable", var.get("name", ""))
                    outputs[name] = var.get("type", "string")
            case "llm":
                outputs["text"] = "string"
            case "code":
                for name, output in node_data.get("outputs", {}).items():
                    outputs[name] = output.get("type", "string")
            case "http-request":
                outputs = {"body": "string", "status_code": "number", "headers": "object"}
            case "knowledge-retrieval":
                outputs["result"] = "array[object]"
            case "tool":
                outputs = {"text": "string", "json": "object"}
            case _:
                outputs["output"] = "string"

        info: dict = {
            "id": node["id"],
            "title": node_data.get("title", node["id"]),
            "outputs": outputs,
        }
        # Only include description if not empty
        desc = node_data.get("desc", "")
        if desc:
            info["desc"] = desc

        return info

    @classmethod
    def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
        """Get parameter info from tool schema using ToolManager."""
        default_info = {"name": parameter_name, "type": "string", "description": ""}

        if node_data.get("type") != "tool":
            return default_info

        try:
            from core.app.entities.app_invoke_entities import InvokeFrom
            from core.tools.entities.tool_entities import ToolProviderType
            from core.tools.tool_manager import ToolManager

            provider_type_str = node_data.get("provider_type", "")
            provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN

            tool_runtime = ToolManager.get_tool_runtime(
                provider_type=provider_type,
                provider_id=node_data.get("provider_id", ""),
                tool_name=node_data.get("tool_name", ""),
                tenant_id=tenant_id,
                invoke_from=InvokeFrom.DEBUGGER,
            )

            parameters = tool_runtime.get_merged_runtime_parameters()
            for param in parameters:
                if param.name == parameter_name:
                    return {
                        "name": param.name,
                        "type": param.type.value if hasattr(param.type, "value") else str(param.type),
                        "description": param.llm_description
                        or (param.human_description.en_US if param.human_description else ""),
                        "required": param.required,
                    }
        except Exception as e:
            logger.debug("Failed to get parameter info from ToolManager: %s", e)

        return default_info

    @classmethod
    def _build_extractor_system_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str,
    ) -> str:
        """Build system prompt for extractor code generation."""
        upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
        param_type = parameter_info.get("type", "string")
        return f"""You are a code generator for workflow automation.

Generate {language} code to extract/transform upstream node outputs for the target parameter.

## Upstream Nodes
{upstream_json}

## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}

## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""

    @classmethod
    def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
        """
        Parse structured output to CodeNodeData format.

        Args:
            content: Structured output dict from invoke_llm_with_structured_output
            language: Code language
            parameter_type: Expected parameter type

        Returns dict with variables, code_language, code, outputs, message, error.
        """
        if content is None:
            return cls._error_response("Empty or invalid response from LLM")

        # Validate and normalize variables
        variables = [
            {"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
            for v in content.get("variables", [])
            if isinstance(v, dict)
        ]

        outputs = content.get("outputs", {"result": {"type": parameter_type}})

        return {
            "variables": variables,
            "code_language": language,
            "code": content.get("code", ""),
            "outputs": outputs,
            "message": content.get("explanation", ""),
            "error": "",
        }

    @staticmethod
    def instruction_modify_legacy(
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
api/core/llm_generator/output_parser/file_ref.py (new file, 188 lines)

@@ -0,0 +1,188 @@
"""
File reference detection and conversion for structured output.

This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""

import uuid
from collections.abc import Mapping
from typing import Any

from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping

FILE_REF_FORMAT = "dify-file-ref"


def is_file_ref_property(schema: dict) -> bool:
    """Check if a schema property is a file reference."""
    return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT


def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """
    Recursively detect file reference fields in schema.

    Args:
        schema: JSON Schema to analyze
        path: Current path in the schema (used for recursion)

    Returns:
        List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
    """
    file_ref_paths: list[str] = []
    schema_type = schema.get("type")

    if schema_type == "object":
        for prop_name, prop_schema in schema.get("properties", {}).items():
            current_path = f"{path}.{prop_name}" if path else prop_name

            if is_file_ref_property(prop_schema):
                file_ref_paths.append(current_path)
            elif isinstance(prop_schema, dict):
                file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))

    elif schema_type == "array":
        items_schema = schema.get("items", {})
        array_path = f"{path}[*]" if path else "[*]"

        if is_file_ref_property(items_schema):
            file_ref_paths.append(array_path)
        elif isinstance(items_schema, dict):
            file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))

    return file_ref_paths


def convert_file_refs_in_output(
    output: Mapping[str, Any],
    json_schema: Mapping[str, Any],
    tenant_id: str,
) -> dict[str, Any]:
    """
    Convert file ID strings to File objects based on schema.

    Args:
        output: The structured_output from LLM result
        json_schema: The original JSON schema (to detect file ref fields)
        tenant_id: Tenant ID for file lookup

    Returns:
        Output with file references converted to File objects
    """
    file_ref_paths = detect_file_ref_fields(json_schema)
    if not file_ref_paths:
        return dict(output)

    result = _deep_copy_dict(output)

    for path in file_ref_paths:
        _convert_path_in_place(result, path.split("."), tenant_id)

    return result


def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
    """Deep copy a mapping to a mutable dict."""
    result: dict[str, Any] = {}
    for key, value in obj.items():
        if isinstance(value, Mapping):
            result[key] = _deep_copy_dict(value)
        elif isinstance(value, list):
            result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
        else:
            result[key] = value
    return result


def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
    """Convert file refs at the given path in place, wrapping in Segment types."""
    if not path_parts:
        return

    current = path_parts[0]
    remaining = path_parts[1:]

    # Handle array notation like "files[*]"
    if current.endswith("[*]"):
        key = current[:-3] if current != "[*]" else None
        target = obj.get(key) if key else obj

        if isinstance(target, list):
            if remaining:
                # Nested array with remaining path - recurse into each item
                for item in target:
                    if isinstance(item, dict):
                        _convert_path_in_place(item, remaining, tenant_id)
            else:
                # Array of file IDs - convert all and wrap in ArrayFileSegment
                files: list[File] = []
                for item in target:
                    file = _convert_file_id(item, tenant_id)
                    if file is not None:
                        files.append(file)
                # Replace the array with ArrayFileSegment
                if key:
                    obj[key] = ArrayFileSegment(value=files)
        return

    if not remaining:
        # Leaf node - convert the value and wrap in FileSegment
        if current in obj:
            file = _convert_file_id(obj[current], tenant_id)
            if file is not None:
                obj[current] = FileSegment(value=file)
            else:
                obj[current] = None
    else:
        # Recurse into nested object
        if current in obj and isinstance(obj[current], dict):
            _convert_path_in_place(obj[current], remaining, tenant_id)


def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
    """
    Convert a file ID string to a File object.

    Tries multiple file sources in order:
    1. ToolFile (files generated by tools/workflows)
    2. UploadFile (files uploaded by users)
    """
    if not isinstance(file_id, str):
        return None

    # Validate UUID format
    try:
        uuid.UUID(file_id)
    except ValueError:
        return None

    # Try ToolFile first (files generated by tools/workflows)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "tool_file",
                "tool_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # Try UploadFile (files uploaded by users)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "local_file",
                "upload_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # File not found in any source
    return None
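A short sketch of how the detection walks a nested schema; the schema itself is made up, and the resulting path syntax matches the docstring example above:

```python
from core.llm_generator.output_parser.file_ref import detect_file_ref_fields

# Made-up schema: one scalar file reference plus an array of file references.
schema = {
    "type": "object",
    "properties": {
        "image_id": {"type": "string", "format": "dify-file-ref"},
        "report": {
            "type": "object",
            "properties": {
                "attachments": {
                    "type": "array",
                    "items": {"type": "string", "format": "dify-file-ref"},
                },
            },
        },
    },
}

print(detect_file_ref_fields(schema))
# ['image_id', 'report.attachments[*]']; these paths drive convert_file_refs_in_output
```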
@@ -8,6 +8,7 @@ import json_repair
 from pydantic import TypeAdapter, ValidationError

 from core.llm_generator.output_parser.errors import OutputParserError
+from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
 from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
 from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
@@ -57,6 +58,7 @@ def invoke_llm_with_structured_output(
     stream: Literal[True],
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
 @overload
 def invoke_llm_with_structured_output(
@@ -72,6 +74,7 @@ def invoke_llm_with_structured_output(
     stream: Literal[False],
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> LLMResultWithStructuredOutput: ...
 @overload
 def invoke_llm_with_structured_output(
@@ -87,6 +90,7 @@ def invoke_llm_with_structured_output(
     stream: bool = True,
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
 def invoke_llm_with_structured_output(
     *,
@@ -101,20 +105,28 @@ def invoke_llm_with_structured_output(
     stream: bool = True,
     user: str | None = None,
     callbacks: list[Callback] | None = None,
+    tenant_id: str | None = None,
 ) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
     """
-    Invoke large language model with structured output
-    1. This method invokes model_instance.invoke_llm with json_schema
-    2. Try to parse the result as structured output
+    Invoke large language model with structured output.
+
+    This method invokes model_instance.invoke_llm with json_schema and parses
+    the result as structured output.

     :param provider: model provider name
     :param model_schema: model schema entity
     :param model_instance: model instance to invoke
     :param prompt_messages: prompt messages
-    :param json_schema: json schema
+    :param json_schema: json schema for structured output
     :param model_parameters: model parameters
     :param tools: tools for tool calling
     :param stop: stop words
     :param stream: is stream response
     :param user: unique user id
     :param callbacks: callbacks
+    :param tenant_id: tenant ID for file reference conversion. When provided and
+        json_schema contains file reference fields (format: "dify-file-ref"),
+        file IDs in the output will be automatically converted to File objects.
     :return: full response or stream response chunk generator result
     """

@@ -153,8 +165,18 @@ def invoke_llm_with_structured_output(
                 f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
             )

+        structured_output = _parse_structured_output(llm_result.message.content)
+
+        # Convert file references if tenant_id is provided
+        if tenant_id is not None:
+            structured_output = convert_file_refs_in_output(
+                output=structured_output,
+                json_schema=json_schema,
+                tenant_id=tenant_id,
+            )
+
         return LLMResultWithStructuredOutput(
-            structured_output=_parse_structured_output(llm_result.message.content),
+            structured_output=structured_output,
             model=llm_result.model,
             message=llm_result.message,
             usage=llm_result.usage,
@@ -186,8 +208,18 @@ def invoke_llm_with_structured_output(
                     delta=event.delta,
                 )

+            structured_output = _parse_structured_output(result_text)
+
+            # Convert file references if tenant_id is provided
+            if tenant_id is not None:
+                structured_output = convert_file_refs_in_output(
+                    output=structured_output,
+                    json_schema=json_schema,
+                    tenant_id=tenant_id,
+                )
+
             yield LLMResultChunkWithStructuredOutput(
-                structured_output=_parse_structured_output(result_text),
+                structured_output=structured_output,
                 model=model_schema.model,
                 prompt_messages=prompt_messages,
                 system_fingerprint=system_fingerprint,
45
api/core/llm_generator/utils.py
Normal file
@@ -0,0 +1,45 @@
"""Utility functions for LLM generator."""
|
||||
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize list of dicts to list[PromptMessage].
|
||||
|
||||
Expected format:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
]
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg in messages:
|
||||
role = PromptMessageRole.value_of(msg["role"])
|
||||
content = msg.get("content", "")
|
||||
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
result.append(UserPromptMessage(content=content))
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
result.append(AssistantPromptMessage(content=content))
|
||||
case PromptMessageRole.SYSTEM:
|
||||
result.append(SystemPromptMessage(content=content))
|
||||
case PromptMessageRole.TOOL:
|
||||
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Serialize list[PromptMessage] to list of dicts.
|
||||
"""
|
||||
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
|
||||
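A quick round-trip sketch of these two helpers (the message values are illustrative; content is assumed to be plain text):

```python
from core.llm_generator.utils import deserialize_prompt_messages, serialize_prompt_messages

history = deserialize_prompt_messages(
    [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Summarize the last run."},
    ]
)
# ...append further PromptMessage turns, then persist them again as plain dicts
payload = serialize_prompt_messages(history)
```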
267
api/core/memory/README.md
Normal file
@@ -0,0 +1,267 @@
# Memory Module

This module provides memory management for LLM conversations, enabling context retention across dialogue turns.

## Overview

The memory module contains two types of memory implementations:

1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (**Chatflow only**)

> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.

```
|
||||
┌─────────────────────────────────────────────────────────────────────────────┐
|
||||
│ Memory Architecture │
|
||||
├─────────────────────────────────────────────────────────────────────────────┤
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ TokenBufferMemory │ │
|
||||
│ │ Scope: Conversation │ │
|
||||
│ │ Storage: Database (Message table) │ │
|
||||
│ │ Key: conversation_id │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
│ ┌─────────────────────────────────────────────────────────────────────-┐ │
|
||||
│ │ NodeTokenBufferMemory │ │
|
||||
│ │ Scope: Node within Conversation │ │
|
||||
│ │ Storage: WorkflowNodeExecutionModel.outputs["context"] │ │
|
||||
│ │ Key: (conversation_id, node_id, workflow_run_id) │ │
|
||||
│ └─────────────────────────────────────────────────────────────────────-┘ │
|
||||
│ │
|
||||
└─────────────────────────────────────────────────────────────────────────────┘
|
||||
```
|
||||
|
||||
---

## TokenBufferMemory (Existing)

### Purpose

`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.

### Key Features

- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)

### Data Flow

```
|
||||
Message Table TokenBufferMemory LLM
|
||||
│ │ │
|
||||
│ SELECT * FROM messages │ │
|
||||
│ WHERE conversation_id = ? │ │
|
||||
│ ORDER BY created_at DESC │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ extract_thread_messages() │
|
||||
│ │ │
|
||||
│ build_prompt_message_with_files() │
|
||||
│ │ │
|
||||
│ truncate by max_token_limit │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage]
|
||||
│ ├───────────────────────▶│
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Thread Extraction

When a user regenerates a response, a new thread is created:

```
|
||||
Message A (user)
|
||||
└── Message A' (assistant)
|
||||
└── Message B (user)
|
||||
└── Message B' (assistant)
|
||||
└── Message A'' (assistant, regenerated) ← New thread
|
||||
└── Message C (user)
|
||||
└── Message C' (assistant)
|
||||
```
|
||||
|
||||
`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`

### Usage

```python
from core.memory.token_buffer_memory import TokenBufferMemory

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```

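When a plain-text transcript is needed instead (for example, for completion-style prompts), the same memory object also exposes `get_history_prompt_text`, defined on `BaseMemory` in `base.py`:

```python
history_text = memory.get_history_prompt_text(
    human_prefix="Human",
    ai_prefix="Assistant",
    max_token_limit=2000,
)
```
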
---

## NodeTokenBufferMemory

### Purpose

`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.

### Use Cases

1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history

### Design: Zero Extra Storage

**Key insight**: The LLM node already saves the complete dialogue context in `outputs["context"]`.

Each LLM node execution outputs:

```python
outputs = {
    "text": clean_text,
    "context": self._build_context(prompt_messages, clean_text),  # Complete dialogue history!
    ...
}
```

This `outputs["context"]` contains:

- All previous user/assistant messages (excluding the system prompt)
- The current assistant response

**No separate storage needed** - we just read from the last execution's `outputs["context"]`.

### Benefits

| Aspect      | Old Design (Object Storage)   | New Design (outputs["context"])        |
|-------------|-------------------------------|----------------------------------------|
| Storage     | Separate JSON file            | Already in WorkflowNodeExecutionModel  |
| Concurrency | Race condition risk           | No issue (each execution is an INSERT) |
| Cleanup     | Needs a separate cleanup task | Follows the node execution lifecycle   |
| Migration   | Required                      | None                                   |
| Complexity  | High                          | Low                                    |

### Data Flow

```
|
||||
WorkflowNodeExecutionModel NodeTokenBufferMemory LLM Node
|
||||
│ │ │
|
||||
│ │◀── get_history_prompt_messages()
|
||||
│ │ │
|
||||
│ SELECT outputs FROM │ │
|
||||
│ workflow_node_executions │ │
|
||||
│ WHERE workflow_run_id = ? │ │
|
||||
│ AND node_id = ? │ │
|
||||
│◀─────────────────────────────────┤ │
|
||||
│ │ │
|
||||
│ outputs["context"] │ │
|
||||
├─────────────────────────────────▶│ │
|
||||
│ │ │
|
||||
│ deserialize PromptMessages │
|
||||
│ │ │
|
||||
│ truncate by max_token_limit │
|
||||
│ │ │
|
||||
│ │ Sequence[PromptMessage] │
|
||||
│ ├──────────────────────────▶│
|
||||
│ │ │
|
||||
```
|
||||
|
||||
### Thread Tracking

Thread extraction still uses the `Message` table's `parent_message_id` structure:

1. Query the `Message` table for the conversation → get the thread's `workflow_run_ids`
2. Get the last completed `workflow_run_id` in the thread
3. Query `WorkflowNodeExecutionModel` for that execution's `outputs["context"]` (see the sketch below)

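A condensed sketch of that lookup, following the `NodeTokenBufferMemory` implementation in this PR (the helper name `load_node_context` is illustrative):

```python
from sqlalchemy import select
from sqlalchemy.orm import Session

from core.prompt.utils.extract_thread_messages import extract_thread_messages
from extensions.ext_database import db
from models.model import Message
from models.workflow import WorkflowNodeExecutionModel


def load_node_context(conversation_id: str, node_id: str) -> list[dict]:
    """Illustrative helper: fetch the node's saved outputs["context"] for the current thread."""
    with Session(db.engine, expire_on_commit=False) as session:
        # 1. Load the conversation's messages (newest first) and reduce them to the current thread
        messages = list(
            session.scalars(
                select(Message)
                .where(Message.conversation_id == conversation_id)
                .order_by(Message.created_at.desc())
            ).all()
        )
        run_ids = [m.workflow_run_id for m in reversed(extract_thread_messages(messages)) if m.workflow_run_id]
        if not run_ids:
            return []
        # 2./3. Take the last completed run in the thread, then read that node's succeeded execution
        execution = session.scalars(
            select(WorkflowNodeExecutionModel).where(
                WorkflowNodeExecutionModel.workflow_run_id == run_ids[-1],
                WorkflowNodeExecutionModel.node_id == node_id,
                WorkflowNodeExecutionModel.status == "succeeded",
            )
        ).first()
        return (execution.outputs_dict or {}).get("context", []) if execution else []
```
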
### API

```python
class NodeTokenBufferMemory:
    def __init__(
        self,
        app_id: str,
        conversation_id: str,
        node_id: str,
        tenant_id: str,
        model_instance: ModelInstance,
    ):
        """Initialize node-level memory."""
        ...

    def get_history_prompt_messages(
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as PromptMessage sequence.

        Reads from last completed execution's outputs["context"].
        """
        ...

    # Legacy methods (no-op, kept for compatibility)
    def add_messages(self, *args, **kwargs) -> None: pass
    def flush(self) -> None: pass
    def clear(self) -> None: pass
```

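For symmetry with the `TokenBufferMemory` example above, a minimal usage sketch (the surrounding variables are assumed to come from the running node):

```python
from core.memory import NodeTokenBufferMemory

memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,
    node_id=node_id,
    tenant_id=tenant_id,
    model_instance=model_instance,
)
history = memory.get_history_prompt_messages(max_token_limit=2000)
```
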
### Configuration

Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:

```python
class MemoryMode(StrEnum):
    CONVERSATION = "conversation"  # Use TokenBufferMemory (default)
    NODE = "node"                  # Use NodeTokenBufferMemory (Chatflow only)


class MemoryConfig(BaseModel):
    role_prefix: RolePrefix | None = None
    window: MemoryWindowConfig | None = None
    query_prompt_template: str | None = None
    mode: MemoryMode = MemoryMode.CONVERSATION
```

**Mode Behavior:**

| Mode           | Memory Class          | Scope                    | Availability  |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory     | Entire conversation      | All app modes |
| `node`         | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |

> When `mode=node` is used in a non-Chatflow context (no conversation_id), it falls back to no memory.

---

## Comparison

| Feature        | TokenBufferMemory        | NodeTokenBufferMemory              |
| -------------- | ------------------------ | ---------------------------------- |
| Scope          | Conversation             | Node within Conversation           |
| Storage        | Database (Message table) | WorkflowNodeExecutionModel.outputs |
| Thread Support | Yes                      | Yes                                |
| File Support   | Yes (via MessageFile)    | Yes (via context serialization)    |
| Token Limit    | Yes                      | Yes                                |
| Use Case       | Standard chat apps       | Complex workflows                  |

---

## Extending to Other Nodes

Currently only the **LLM Node** writes `context` into its outputs. To enable node memory for other nodes:

1. Add `outputs["context"] = self._build_context(prompt_messages, response)` in the node (see the sketch below)
2. `NodeTokenBufferMemory` will automatically pick it up

Nodes that could potentially support this:

- `question_classifier`
- `parameter_extractor`
- `agent`

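A minimal sketch of step 1 inside a hypothetical node; the class, `_finish`, and the `_build_context` body here are simplifications of what the LLM node does, not the actual implementation:

```python
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    SystemPromptMessage,
)


class MyCustomNode:  # hypothetical node, for illustration only
    def _build_context(self, prompt_messages: list[PromptMessage], response_text: str) -> list[PromptMessage]:
        # previous user/assistant turns plus the current response (system prompt excluded)
        history = [m for m in prompt_messages if not isinstance(m, SystemPromptMessage)]
        return [*history, AssistantPromptMessage(content=response_text)]

    def _finish(self, prompt_messages: list[PromptMessage], response_text: str) -> dict:
        return {
            "text": response_text,
            # NodeTokenBufferMemory reads this back on the next turn
            "context": self._build_context(prompt_messages, response_text),
        }
```
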
---

## Future Considerations

1. **Cleanup**: Node memory lifecycle follows `WorkflowNodeExecutionModel`, which already has cleanup mechanisms
2. **Compression**: For very long conversations, consider summarization strategies
3. **Extension**: Other nodes may benefit from node-level memory
11
api/core/memory/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
83
api/core/memory/base.py
Normal file
@@ -0,0 +1,83 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from core.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
197
api/core/memory/node_token_buffer_memory.py
Normal file
@@ -0,0 +1,197 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
|
||||
- No separate storage needed - the context is already saved during node execution
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.file import file_manager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
msg = self._deserialize_prompt_message(msg_dict)
|
||||
msg = self._restore_multimodal_content(msg)
|
||||
messages.append(msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Restore multimodal content (base64 or url) from file_ref.
|
||||
|
||||
When context is saved, base64_data is cleared to save storage space.
|
||||
This method restores the content by parsing file_ref (format: "method:id_or_url").
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, restoring multimodal data from file references
|
||||
restored_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# restore_multimodal_content preserves the concrete subclass type
|
||||
restored_item = file_manager.restore_multimodal_content(item)
|
||||
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
|
||||
else:
|
||||
restored_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": restored_content})
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
|
||||
from core.file import file_manager
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
|
||||
class TokenBufferMemory:
|
||||
class TokenBufferMemory(BaseMemory):
|
||||
def __init__(
|
||||
self,
|
||||
conversation: Conversation,
|
||||
@@ -115,10 +115,14 @@ class TokenBufferMemory:
|
||||
return AssistantPromptMessage(content=prompt_message_contents)
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self, max_token_limit: int = 2000, message_limit: int | None = None
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
"""
|
||||
@@ -200,44 +204,3 @@ class TokenBufferMemory:
|
||||
curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt text.
|
||||
:param human_prefix: human prefix
|
||||
:param ai_prefix: ai prefix
|
||||
:param max_token_limit: max token limit
|
||||
:param message_limit: message limit
|
||||
:return:
|
||||
"""
|
||||
prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
|
||||
@@ -91,6 +91,9 @@ class MultiModalPromptMessageContent(PromptMessageContent):
|
||||
mime_type: str = Field(default=..., description="the mime type of multi-modal file")
|
||||
filename: str = Field(default="", description="the filename of multi-modal file")
|
||||
|
||||
# File reference for context restoration, format: "transfer_method:related_id" or "remote:url"
|
||||
file_ref: str | None = Field(default=None, description="Encoded file reference for restoration")
|
||||
|
||||
@property
|
||||
def data(self):
|
||||
return self.url or f"data:{self.mime_type};base64,{self.base64_data}"
|
||||
@@ -276,7 +279,5 @@ class ToolPromptMessage(PromptMessage):
|
||||
|
||||
:return: True if prompt message is empty, False otherwise
|
||||
"""
|
||||
if not super().is_empty() and not self.tool_call_id:
|
||||
return False
|
||||
|
||||
return True
|
||||
# ToolPromptMessage is not empty if it has content OR has a tool_call_id
|
||||
return super().is_empty() and not self.tool_call_id
|
||||
|
||||
@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti
|
||||
from core.file import file_manager
|
||||
from core.file.models import File
|
||||
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
files: Sequence[File],
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
files: Sequence[File],
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
files: Sequence[File],
|
||||
context: str | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):
|
||||
|
||||
def _set_histories_variable(
|
||||
self,
|
||||
memory: TokenBufferMemory,
|
||||
memory: BaseMemory,
|
||||
memory_config: MemoryConfig,
|
||||
raw_prompt: str,
|
||||
role_prefix: MemoryConfig.RolePrefix,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel
|
||||
@@ -5,6 +6,13 @@ from pydantic import BaseModel
|
||||
from core.model_runtime.entities.message_entities import PromptMessageRole
|
||||
|
||||
|
||||
class MemoryMode(StrEnum):
|
||||
"""Memory mode for LLM nodes."""
|
||||
|
||||
CONVERSATION = "conversation" # Use TokenBufferMemory (default, existing behavior)
|
||||
NODE = "node" # Use NodeTokenBufferMemory (Chatflow only)
|
||||
|
||||
|
||||
class ChatModelMessage(BaseModel):
|
||||
"""
|
||||
Chat Message.
|
||||
@@ -48,3 +56,4 @@ class MemoryConfig(BaseModel):
|
||||
role_prefix: RolePrefix | None = None
|
||||
window: WindowConfig
|
||||
query_prompt_template: str | None = None
|
||||
mode: MemoryMode = MemoryMode.CONVERSATION
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities.message_entities import PromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
class PromptTransform:
|
||||
def _append_chat_histories(
|
||||
self,
|
||||
memory: TokenBufferMemory,
|
||||
memory: BaseMemory,
|
||||
memory_config: MemoryConfig,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
@@ -52,7 +52,7 @@ class PromptTransform:
|
||||
|
||||
def _get_history_messages_from_memory(
|
||||
self,
|
||||
memory: TokenBufferMemory,
|
||||
memory: BaseMemory,
|
||||
memory_config: MemoryConfig,
|
||||
max_token_limit: int,
|
||||
human_prefix: str | None = None,
|
||||
@@ -73,7 +73,7 @@ class PromptTransform:
|
||||
return memory.get_history_prompt_text(**kwargs)
|
||||
|
||||
def _get_history_messages_list_from_memory(
|
||||
self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
|
||||
self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
|
||||
) -> list[PromptMessage]:
|
||||
"""Get memory messages."""
|
||||
return list(
|
||||
|
||||
@@ -1047,6 +1047,8 @@ class ToolManager:
|
||||
continue
|
||||
tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
|
||||
if tool_input.type == "variable":
|
||||
if not isinstance(tool_input.value, list):
|
||||
raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if variable is None:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
@@ -1056,6 +1058,11 @@ class ToolManager:
|
||||
elif tool_input.type == "mixed":
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
parameter_value = segment_group.text
|
||||
elif tool_input.type == "mention":
|
||||
# Mention type not supported in agent mode
|
||||
raise ToolParameterError(
|
||||
f"Mention type not supported in agent for parameter '{parameter.name}'"
|
||||
)
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
runtime_parameters[parameter.name] = parameter_value
|
||||
|
||||
@@ -4,6 +4,7 @@ from .segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
FileSegment,
|
||||
@@ -20,6 +21,7 @@ from .variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayPromptMessageVariable,
|
||||
ArrayStringVariable,
|
||||
ArrayVariable,
|
||||
FileVariable,
|
||||
@@ -42,6 +44,8 @@ __all__ = [
|
||||
"ArrayNumberVariable",
|
||||
"ArrayObjectSegment",
|
||||
"ArrayObjectVariable",
|
||||
"ArrayPromptMessageSegment",
|
||||
"ArrayPromptMessageVariable",
|
||||
"ArraySegment",
|
||||
"ArrayStringSegment",
|
||||
"ArrayStringVariable",
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import Annotated, Any, TypeAlias
|
||||
from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator
|
||||
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
from .types import SegmentType
|
||||
|
||||
@@ -208,6 +209,15 @@ class ArrayBooleanSegment(ArraySegment):
|
||||
value: Sequence[bool]
|
||||
|
||||
|
||||
class ArrayPromptMessageSegment(ArraySegment):
|
||||
value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
|
||||
value: Sequence[PromptMessage]
|
||||
|
||||
def to_object(self):
|
||||
"""Convert to JSON-serializable format for database storage and frontend."""
|
||||
return [msg.model_dump() for msg in self.value]
|
||||
|
||||
|
||||
def get_segment_discriminator(v: Any) -> SegmentType | None:
|
||||
if isinstance(v, Segment):
|
||||
return v.value_type
|
||||
@@ -248,6 +258,7 @@ SegmentUnion: TypeAlias = Annotated[
|
||||
| Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
]
|
||||
|
||||
@@ -45,6 +45,7 @@ class SegmentType(StrEnum):
|
||||
ARRAY_OBJECT = "array[object]"
|
||||
ARRAY_FILE = "array[file]"
|
||||
ARRAY_BOOLEAN = "array[boolean]"
|
||||
ARRAY_PROMPT_MESSAGE = "array[message]"
|
||||
|
||||
NONE = "none"
|
||||
|
||||
|
||||
@@ -3,8 +3,10 @@ from typing import Any
|
||||
|
||||
import orjson
|
||||
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
|
||||
from .segment_group import SegmentGroup
|
||||
from .segments import ArrayFileSegment, FileSegment, Segment
|
||||
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment
|
||||
|
||||
|
||||
def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:
|
||||
@@ -16,7 +18,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[
|
||||
|
||||
def segment_orjson_default(o: Any):
|
||||
"""Default function for orjson serialization of Segment types"""
|
||||
if isinstance(o, ArrayFileSegment):
|
||||
if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
|
||||
return [v.model_dump() for v in o.value]
|
||||
elif isinstance(o, FileSegment):
|
||||
return o.value.model_dump()
|
||||
@@ -24,6 +26,8 @@ def segment_orjson_default(o: Any):
|
||||
return [segment_orjson_default(seg) for seg in o.value]
|
||||
elif isinstance(o, Segment):
|
||||
return o.value
|
||||
elif isinstance(o, PromptMessage):
|
||||
return o.model_dump()
|
||||
raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from .segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
@@ -110,6 +111,10 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
|
||||
pass
|
||||
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
@@ -160,6 +165,7 @@ Variable: TypeAlias = Annotated[
|
||||
| Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
|
||||
| Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
|
||||
| Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
|
||||
| Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
|
||||
| Annotated[SecretVariable, Tag(SegmentType.SECRET)]
|
||||
),
|
||||
Discriminator(get_segment_discriminator),
|
||||
|
||||
1418
api/core/workflow/docs/variable_extraction_design.md
Normal file
File diff suppressed because it is too large
@@ -63,6 +63,7 @@ class NodeType(StrEnum):
|
||||
TRIGGER_SCHEDULE = "trigger-schedule"
|
||||
TRIGGER_PLUGIN = "trigger-plugin"
|
||||
HUMAN_INPUT = "human-input"
|
||||
GROUP = "group"
|
||||
|
||||
@property
|
||||
def is_trigger_node(self) -> bool:
|
||||
@@ -252,6 +253,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
|
||||
LOOP_VARIABLE_MAP = "loop_variable_map" # single loop variable output
|
||||
DATASOURCE_INFO = "datasource_info"
|
||||
COMPLETED_REASON = "completed_reason" # completed reason for loop node
|
||||
MENTION_PARENT_ID = "mention_parent_id" # parent node id for extractor nodes
|
||||
|
||||
|
||||
class WorkflowNodeExecutionStatus(StrEnum):
|
||||
|
||||
@@ -307,7 +307,14 @@ class Graph:
|
||||
if not node_configs:
|
||||
raise ValueError("Graph must have at least one node")
|
||||
|
||||
node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
|
||||
# Filter out UI-only node types:
|
||||
# - custom-note: top-level type (node_config.type == "custom-note")
|
||||
# - group: data-level type (node_config.data.type == "group")
|
||||
node_configs = [
|
||||
node_config
|
||||
for node_config in node_configs
|
||||
if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
|
||||
]
|
||||
|
||||
# Parse node configurations
|
||||
node_configs_map = cls._parse_node_configs(node_configs)
|
||||
|
||||
@@ -93,8 +93,8 @@ class EventHandler:
|
||||
Args:
|
||||
event: The event to handle
|
||||
"""
|
||||
# Events in loops or iterations are always collected
|
||||
if event.in_loop_id or event.in_iteration_id:
|
||||
# Events in loops, iterations, or extractor groups are always collected
|
||||
if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
|
||||
self._event_collector.collect(event)
|
||||
return
|
||||
return self._dispatch(event)
|
||||
@@ -125,6 +125,11 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node started event
|
||||
"""
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_started(event)
|
||||
return
|
||||
|
||||
# Track execution in domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
is_initial_attempt = node_execution.retry_count == 0
|
||||
@@ -164,6 +169,11 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node succeeded event
|
||||
"""
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_success(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_taken()
|
||||
@@ -226,6 +236,11 @@ class EventHandler:
|
||||
Args:
|
||||
event: The node failed event
|
||||
"""
|
||||
# Check if this is an extractor node (has parent_node_id)
|
||||
if self._is_extractor_node(event.node_id):
|
||||
self._handle_extractor_node_failed(event)
|
||||
return
|
||||
|
||||
# Update domain model
|
||||
node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
|
||||
node_execution.mark_failed(event.error)
|
||||
@@ -345,3 +360,57 @@ class EventHandler:
|
||||
self._graph_runtime_state.set_output("answer", value)
|
||||
else:
|
||||
self._graph_runtime_state.set_output(key, value)
|
||||
|
||||
def _is_extractor_node(self, node_id: str) -> bool:
|
||||
"""
|
||||
Check if node_id represents an extractor node (has parent_node_id).
|
||||
|
||||
Extractor nodes extract values from list[PromptMessage] for their parent node.
|
||||
They have a parent_node_id field pointing to their parent node.
|
||||
"""
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
return False
|
||||
return node.node_data.is_extractor_node
|
||||
|
||||
def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
|
||||
"""
|
||||
Handle extractor node started event.
|
||||
|
||||
Extractor nodes don't need full execution tracking, just collect the event.
|
||||
"""
|
||||
# Track in response coordinator for stream ordering
|
||||
self._response_coordinator.track_node_execution(event.node_id, event.id)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
|
||||
"""
|
||||
Handle extractor node success event.
|
||||
|
||||
Extractor nodes need special handling:
|
||||
- Store outputs in variable pool (for reference by other nodes)
|
||||
- Accumulate token usage
|
||||
- Collect the event for logging
|
||||
- Do NOT process edges or enqueue next nodes (parent node handles that)
|
||||
"""
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Store outputs in variable pool
|
||||
self._store_node_outputs(event.node_id, event.node_run_result.outputs)
|
||||
|
||||
# Collect the event
|
||||
self._event_collector.collect(event)
|
||||
|
||||
def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
|
||||
"""
|
||||
Handle extractor node failed event.
|
||||
|
||||
Extractor node failures are collected for logging,
|
||||
but the parent node is responsible for handling the error.
|
||||
"""
|
||||
self._accumulate_node_usage(event.node_run_result.llm_usage)
|
||||
|
||||
# Collect the event for logging
|
||||
self._event_collector.collect(event)
|
||||
|
||||
@@ -68,6 +68,7 @@ class _NodeRuntimeSnapshot:
|
||||
predecessor_node_id: str | None
|
||||
iteration_id: str | None
|
||||
loop_id: str | None
|
||||
mention_parent_id: str | None
|
||||
created_at: datetime
|
||||
|
||||
|
||||
@@ -230,6 +231,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
metadata = {
|
||||
WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
|
||||
WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
|
||||
WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
|
||||
}
|
||||
|
||||
domain_execution = WorkflowNodeExecution(
|
||||
@@ -256,6 +258,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
predecessor_node_id=event.predecessor_node_id,
|
||||
iteration_id=event.in_iteration_id,
|
||||
loop_id=event.in_loop_id,
|
||||
mention_parent_id=event.in_mention_parent_id,
|
||||
created_at=event.start_at,
|
||||
)
|
||||
self._node_snapshots[event.id] = snapshot
|
||||
|
||||
@@ -21,6 +21,12 @@ class GraphNodeEventBase(GraphEngineEvent):
|
||||
"""iteration id if node is in iteration"""
|
||||
in_loop_id: str | None = None
|
||||
"""loop id if node is in loop"""
|
||||
in_mention_parent_id: str | None = None
|
||||
"""Parent node id if this is an extractor node event.
|
||||
|
||||
When set, indicates this event belongs to an extractor node that
|
||||
is extracting values for the specified parent node.
|
||||
"""
|
||||
|
||||
# The version of the node, or "1" if not specified.
|
||||
node_version: str = "1"
|
||||
|
||||
@@ -12,11 +12,20 @@ from sqlalchemy.orm import Session
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from core.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryMode
|
||||
from core.provider_manager import ProviderManager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
@@ -136,6 +145,9 @@ class AgentNode(Node[AgentNodeData]):
|
||||
)
|
||||
return
|
||||
|
||||
# Fetch memory for node memory saving
|
||||
memory = self._fetch_memory_for_save()
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
@@ -149,6 +161,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
memory=memory,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
@@ -395,8 +408,20 @@ class AgentNode(Node[AgentNodeData]):
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
|
||||
# get conversation id
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
|
||||
"""
|
||||
Fetch memory based on configuration mode.
|
||||
|
||||
Returns TokenBufferMemory for conversation mode (default),
|
||||
or NodeTokenBufferMemory for node mode (Chatflow only).
|
||||
"""
|
||||
node_data = self.node_data
|
||||
memory_config = node_data.memory
|
||||
|
||||
if not memory_config:
|
||||
return None
|
||||
|
||||
# get conversation id (required for both modes in Chatflow)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
@@ -404,16 +429,26 @@ class AgentNode(Node[AgentNodeData]):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
|
||||
conversation = session.scalar(stmt)
|
||||
|
||||
if not conversation:
|
||||
return None
|
||||
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
return memory
|
||||
# Return appropriate memory type based on mode
|
||||
if memory_config.mode == MemoryMode.NODE:
|
||||
# Node-level memory (Chatflow only)
|
||||
return NodeTokenBufferMemory(
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=self._node_id,
|
||||
tenant_id=self.tenant_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
# Conversation-level memory (default)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == self.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
provider_manager = ProviderManager()
|
||||
@@ -457,6 +492,136 @@ class AgentNode(Node[AgentNodeData]):
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _fetch_memory_for_save(self) -> BaseMemory | None:
|
||||
"""
|
||||
Fetch memory instance for saving node memory.
|
||||
This is a simplified version that doesn't require model_instance.
|
||||
"""
|
||||
from core.model_manager import ModelManager
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
node_data = self.node_data
|
||||
if not node_data.memory:
|
||||
return None
|
||||
|
||||
# Get conversation_id
|
||||
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_var, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_var.value
|
||||
|
||||
# Return appropriate memory type based on mode
|
||||
if node_data.memory.mode == MemoryMode.NODE:
|
||||
# For node memory, we need a model_instance for token counting
|
||||
# Use a simple default model for this purpose
|
||||
try:
|
||||
model_instance = ModelManager().get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return NodeTokenBufferMemory(
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=self._node_id,
|
||||
tenant_id=self.tenant_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
# Conversation-level memory doesn't need saving here
|
||||
return None
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_query: str,
|
||||
assistant_response: str,
|
||||
agent_logs: list[AgentLogEvent],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from user query, tool calls, and assistant response.
|
||||
Format: user -> assistant(with tool_calls) -> tool -> assistant
|
||||
|
||||
The context includes:
|
||||
- Current user query (always present, may be empty)
|
||||
- Assistant message with tool_calls (if tools were called)
|
||||
- Tool results
|
||||
- Assistant's final response
|
||||
"""
|
||||
context_messages: list[PromptMessage] = []
|
||||
|
||||
# Always add user query (even if empty, to maintain conversation structure)
|
||||
context_messages.append(UserPromptMessage(content=user_query or ""))
|
||||
|
||||
# Extract actual tool calls from agent logs
|
||||
# Only include logs with label starting with "CALL " - these are real tool invocations
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.status == "success" and log.label and log.label.startswith("CALL "):
|
||||
# Extract tool name from label (format: "CALL tool_name")
|
||||
tool_name = log.label[5:] # Remove "CALL " prefix
|
||||
tool_call_id = log.message_id
|
||||
|
||||
# Parse tool response from data
|
||||
data = log.data or {}
|
||||
tool_response = ""
|
||||
|
||||
# Try to extract the actual tool response
|
||||
if "tool_response" in data:
|
||||
tool_response = data["tool_response"]
|
||||
elif "output" in data:
|
||||
tool_response = data["output"]
|
||||
elif "result" in data:
|
||||
tool_response = data["result"]
|
||||
|
||||
if isinstance(tool_response, dict):
|
||||
tool_response = str(tool_response)
|
||||
|
||||
# Get tool input for arguments
|
||||
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
|
||||
if isinstance(tool_input, dict):
|
||||
import json
|
||||
|
||||
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
|
||||
else:
|
||||
tool_input_str = str(tool_input) if tool_input else ""
|
||||
|
||||
if tool_response:
|
||||
tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_name,
|
||||
arguments=tool_input_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
tool_results.append((tool_call_id, tool_name, str(tool_response)))
|
||||
|
||||
# Add assistant message with tool_calls if there were tool calls
|
||||
if tool_calls:
|
||||
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
|
||||
|
||||
# Add tool result messages
|
||||
for tool_call_id, tool_name, result in tool_results:
|
||||
context_messages.append(
|
||||
ToolPromptMessage(
|
||||
content=result,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Add final assistant response
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
|
||||
return context_messages

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],
@@ -467,6 +632,7 @@ class AgentNode(Node[AgentNodeData]):
        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
        memory: BaseMemory | None = None,
    ) -> Generator[NodeEventBase, None, None]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]
@@ -711,6 +877,12 @@ class AgentNode(Node[AgentNodeData]):
            is_final=True,
        )

        # Get user query from parameters for building context
        user_query = parameters_for_log.get("query", "")

        # Build context from history, user query, tool calls and assistant response
        context = self._build_context(parameters_for_log, user_query, text, agent_logs)

        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -719,6 +891,7 @@ class AgentNode(Node[AgentNodeData]):
                    "usage": jsonable_encoder(llm_usage),
                    "files": ArrayFileSegment(value=files),
                    "json": json_output,
                    "context": context,
                    **variables,
                },
                metadata={

@@ -1,4 +1,10 @@
from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .entities import (
    BaseIterationNodeData,
    BaseIterationState,
    BaseLoopNodeData,
    BaseLoopState,
    BaseNodeData,
)
from .usage_tracking_mixin import LLMUsageTrackingMixin

__all__ = [

@@ -175,6 +175,16 @@ class BaseNodeData(ABC, BaseModel):
    default_value: list[DefaultValue] | None = None
    retry_config: RetryConfig = RetryConfig()

    # Parent node ID when this node is used as an extractor.
    # If set, this node is an "attached" extractor node that extracts values
    # from list[PromptMessage] for the parent node's parameters.
    parent_node_id: str | None = None

    @property
    def is_extractor_node(self) -> bool:
        """Check if this node is an extractor node (has parent_node_id)."""
        return self.parent_node_id is not None

    @property
    def default_value_dict(self) -> dict[str, Any]:
        if self.default_value:

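As a sketch, an attached extractor node entry in graph_config could look like the following (the node ids and omitted fields are hypothetical; parent_node_id is the field added above):

# Hypothetical graph_config entry for an attached extractor LLM node.
extractor_node_config = {
    "id": "query_extractor",                # hypothetical extractor node id
    "data": {
        "type": "llm",
        "parent_node_id": "google_search",  # ties it to its parent tool node
        # ...regular LLM node fields (model, prompt_template, ...) omitted
    },
}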
@@ -270,10 +270,87 @@ class Node(Generic[NodeDataT]):
        """Check if execution should be stopped."""
        return self.graph_runtime_state.stop_event.is_set()

    def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
        """
        Find all extractor node configurations that have parent_node_id == self._node_id.

        Returns:
            List of node configuration dicts for extractor nodes
        """
        nodes = self.graph_config.get("nodes", [])
        extractor_configs = []
        for node_config in nodes:
            node_data = node_config.get("data", {})
            if node_data.get("parent_node_id") == self._node_id:
                extractor_configs.append(node_config)
        return extractor_configs

    def _execute_mention_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
        """
        Execute all extractor nodes associated with this node.

        Extractor nodes are nodes with parent_node_id == self._node_id.
        They are executed before the main node to extract values from list[PromptMessage].
        """
        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING

        extractor_configs = self._find_extractor_node_configs()
        logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
        if not extractor_configs:
            return

        for config in extractor_configs:
            node_id = config.get("id")
            node_data = config.get("data", {})
            node_type_str = node_data.get("type")

            if not node_id or not node_type_str:
                continue

            # Get node class
            try:
                node_type = NodeType(node_type_str)
            except ValueError:
                continue

            node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
            if not node_mapping:
                continue

            node_version = str(node_data.get("version", "1"))
            node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
            if not node_cls:
                continue

            # Instantiate and execute the extractor node
            extractor_node = node_cls(
                id=node_id,
                config=config,
                graph_init_params=self._graph_init_params,
                graph_runtime_state=self.graph_runtime_state,
            )

            # Execute and process extractor node events
            for event in extractor_node.run():
                # Tag event with parent node id for stream ordering and history tracking
                if isinstance(event, GraphNodeEventBase):
                    event.in_mention_parent_id = self._node_id

                if isinstance(event, NodeRunSucceededEvent):
                    # Store extractor node outputs in variable pool
                    outputs: Mapping[str, Any] = event.node_run_result.outputs
                    for variable_name, variable_value in outputs.items():
                        self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
                if not isinstance(event, NodeRunStreamChunkEvent):
                    yield event

    def run(self) -> Generator[GraphNodeEventBase, None, None]:
        execution_id = self.ensure_execution_id()
        self._start_at = naive_utc_now()

        # Step 1: Execute associated extractor nodes before main node execution
        yield from self._execute_mention_nodes()

        # Create and push start event with required fields
        start_event = NodeRunStartedEvent(
            id=execution_id,

@@ -1,7 +1,7 @@
from collections.abc import Mapping, Sequence
from typing import Any, Literal
from typing import Annotated, Any, Literal, TypeAlias

from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, ConfigDict, Field, field_validator

from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
    jinja2_text: str | None = None


class PromptMessageContext(BaseModel):
    """Context variable reference in prompt template.

    YAML/JSON format: { "$context": ["node_id", "variable_name"] }
    This will be expanded to list[PromptMessage] at runtime.
    """

    model_config = ConfigDict(populate_by_name=True)

    value_selector: Sequence[str] = Field(alias="$context")


# Union type for prompt template items (static message or context variable reference)
PromptTemplateItem: TypeAlias = Annotated[
    LLMNodeChatModelMessage | PromptMessageContext,
    Field(discriminator=None),
]


class LLMNodeData(BaseNodeData):
    model: ModelConfig
    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
    prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
    prompt_config: PromptConfig = Field(default_factory=PromptConfig)
    memory: MemoryConfig | None = None
    context: ContextConfig

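A sketch of a chat prompt_template mixing a static message with a context reference, written in the dict form that PromptMessageContext accepts through its alias (the node id "agent_1" is hypothetical):

# Hypothetical prompt_template combining a static system message and a context reference.
prompt_template = [
    {"role": "system", "text": "You are a helpful assistant.", "edition_type": "basic"},
    {"$context": ["agent_1", "context"]},  # expanded into list[PromptMessage] at runtime
]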
@@ -8,12 +8,20 @@ from configs import dify_config
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
from core.file.models import File
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory import NodeTokenBufferMemory, TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    MultiModalPromptMessageContent,
    PromptMessage,
    PromptMessageContentUnionTypes,
    PromptMessageRole,
)
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
from core.workflow.enums import SystemVariableKey
from core.workflow.nodes.llm.entities import ModelConfig
@@ -86,25 +94,56 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc


def fetch_memory(
    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
) -> TokenBufferMemory | None:
    variable_pool: VariablePool,
    app_id: str,
    tenant_id: str,
    node_data_memory: MemoryConfig | None,
    model_instance: ModelInstance,
    node_id: str = "",
) -> BaseMemory | None:
    """
    Fetch memory based on configuration mode.

    Returns TokenBufferMemory for conversation mode (default),
    or NodeTokenBufferMemory for node mode (Chatflow only).

    :param variable_pool: Variable pool containing system variables
    :param app_id: Application ID
    :param tenant_id: Tenant ID
    :param node_data_memory: Memory configuration
    :param model_instance: Model instance for token counting
    :param node_id: Node ID in the workflow (required for node mode)
    :return: Memory instance or None if not applicable
    """
    if not node_data_memory:
        return None

    # get conversation id
    # Get conversation_id from variable pool (required for both modes in Chatflow)
    conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
    if not isinstance(conversation_id_variable, StringSegment):
        return None
    conversation_id = conversation_id_variable.value

    with Session(db.engine, expire_on_commit=False) as session:
        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
        conversation = session.scalar(stmt)
        if not conversation:
    # Return appropriate memory type based on mode
    if node_data_memory.mode == MemoryMode.NODE:
        # Node-level memory (Chatflow only)
        if not node_id:
            return None

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
            return memory
        return NodeTokenBufferMemory(
            app_id=app_id,
            conversation_id=conversation_id,
            node_id=node_id,
            tenant_id=tenant_id,
            model_instance=model_instance,
        )
    else:
        # Conversation-level memory (default)
        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
            conversation = session.scalar(stmt)
            if not conversation:
                return None
            return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
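For reference, the memory configuration that selects the node-level branch above is the one used by the pav-test-extraction.yml fixture later in this diff; as a plain dict it looks roughly like:

# Memory config selecting node-level memory (MemoryMode.NODE), mirroring the fixture below.
memory_config = {
    "mode": "node",
    "query_prompt_template": "{{#sys.query#}}",
    "window": {"enabled": True, "size": 10},
}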


def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
@@ -170,3 +209,87 @@ def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUs
        )
        session.execute(stmt)
        session.commit()


def build_context(
    prompt_messages: Sequence[PromptMessage],
    assistant_response: str,
) -> list[PromptMessage]:
    """
    Build context from prompt messages and assistant response.
    Excludes system messages and includes the current LLM response.
    Returns list[PromptMessage] for use with ArrayPromptMessageSegment.

    Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
    """
    context_messages: list[PromptMessage] = [
        _truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
    ]
    context_messages.append(AssistantPromptMessage(content=assistant_response))
    return context_messages
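A small usage sketch (assuming the usual SystemPromptMessage/UserPromptMessage classes from the same message_entities module):

# Sketch: system messages are dropped, the final assistant reply is appended.
messages = [
    SystemPromptMessage(content="You are a search assistant."),  # excluded from context
    UserPromptMessage(content="Find articles about Dify."),
]
context = build_context(messages, "Here are three articles about Dify...")
# context == [UserPromptMessage(...), AssistantPromptMessage(content="Here are three articles about Dify...")]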


def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
    """
    Truncate multi-modal content base64 data in a message to avoid storing large data.
    Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.

    If file_ref is present, clears base64_data and url (they can be restored later).
    Otherwise, truncates base64_data as fallback for legacy data.
    """
    content = message.content
    if content is None or isinstance(content, str):
        return message

    # Process list content, handling multi-modal data based on file_ref availability
    new_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, MultiModalPromptMessageContent):
            if item.file_ref:
                # Clear base64 and url, keep file_ref for later restoration
                new_content.append(item.model_copy(update={"base64_data": "", "url": ""}))
            else:
                # Fallback: truncate base64_data if no file_ref (legacy data)
                truncated_base64 = ""
                if item.base64_data:
                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
        else:
            new_content.append(item)

    return message.model_copy(update={"content": new_content})
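For scale, the fallback path above collapses an arbitrarily large payload into a fixed-size placeholder:

# Sketch of the fallback truncation: first 10 chars + "...[TRUNCATED]..." + last 10 chars,
# so a multi-megabyte base64 image string shrinks to a string of roughly 40 characters.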


def restore_multimodal_content_in_messages(messages: Sequence[PromptMessage]) -> list[PromptMessage]:
    """
    Restore multimodal content (base64 or url) in a list of PromptMessages.

    When context is saved, base64_data is cleared to save storage space.
    This function restores the content by parsing file_ref in each MultiModalPromptMessageContent.

    Args:
        messages: List of PromptMessages that may contain truncated multimodal content

    Returns:
        List of PromptMessages with restored multimodal content
    """
    from core.file import file_manager

    return [_restore_message_content(msg, file_manager) for msg in messages]


def _restore_message_content(message: PromptMessage, file_manager) -> PromptMessage:
    """Restore multimodal content in a single PromptMessage."""
    content = message.content
    if content is None or isinstance(content, str):
        return message

    restored_content: list[PromptMessageContentUnionTypes] = []
    for item in content:
        if isinstance(item, MultiModalPromptMessageContent):
            restored_item = file_manager.restore_multimodal_content(item)
            restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
        else:
            restored_content.append(item)

    return message.model_copy(update={"content": restored_content})

@@ -7,7 +7,7 @@ import logging
|
||||
import re
|
||||
import time
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
@@ -16,7 +16,7 @@ from core.file import File, FileTransferMethod, FileType, file_manager
|
||||
from core.helper.code_executor import CodeExecutor, CodeLanguage
|
||||
from core.llm_generator.output_parser.errors import OutputParserError
|
||||
from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.model_runtime.entities import (
|
||||
ImagePromptMessageContent,
|
||||
@@ -51,6 +51,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.variables import (
|
||||
ArrayFileSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
FileSegment,
|
||||
NoneSegment,
|
||||
@@ -87,6 +88,7 @@ from .entities import (
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
LLMNodeData,
|
||||
ModelConfig,
|
||||
PromptMessageContext,
|
||||
)
|
||||
from .exc import (
|
||||
InvalidContextStructureError,
|
||||
@@ -159,8 +161,9 @@ class LLMNode(Node[LLMNodeData]):
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
try:
|
||||
# init messages template
|
||||
self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
|
||||
# Parse prompt template to separate static messages and context references
|
||||
prompt_template = self.node_data.prompt_template
|
||||
static_messages, context_refs, template_order = self._parse_prompt_template()
|
||||
|
||||
# fetch variables and fetch values from variable pool
|
||||
inputs = self._fetch_inputs(node_data=self.node_data)
|
||||
@@ -208,8 +211,10 @@ class LLMNode(Node[LLMNodeData]):
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_memory=self.node_data.memory,
|
||||
model_instance=model_instance,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
|
||||
query: str | None = None
|
||||
@@ -220,21 +225,40 @@ class LLMNode(Node[LLMNodeData]):
|
||||
):
|
||||
query = query_variable.text
|
||||
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=self.node_data.prompt_template,
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
# Get prompt messages
|
||||
prompt_messages: Sequence[PromptMessage]
|
||||
stop: Sequence[str] | None
|
||||
if isinstance(prompt_template, list) and context_refs:
|
||||
prompt_messages, stop = self._build_prompt_messages_with_context(
|
||||
context_refs=context_refs,
|
||||
template_order=template_order,
|
||||
static_messages=static_messages,
|
||||
query=query,
|
||||
files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
context_files=context_files,
|
||||
)
|
||||
else:
|
||||
prompt_messages, stop = LLMNode.fetch_prompt_messages(
|
||||
sys_query=query,
|
||||
sys_files=files,
|
||||
context=context,
|
||||
memory=memory,
|
||||
model_config=model_config,
|
||||
prompt_template=cast(
|
||||
Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
self.node_data.prompt_template,
|
||||
),
|
||||
memory_config=self.node_data.memory,
|
||||
vision_enabled=self.node_data.vision.enabled,
|
||||
vision_detail=self.node_data.vision.configs.detail,
|
||||
variable_pool=variable_pool,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables,
|
||||
tenant_id=self.tenant_id,
|
||||
context_files=context_files,
|
||||
)
|
||||
|
||||
# handle invoke result
|
||||
generator = LLMNode.invoke_llm(
|
||||
@@ -250,6 +274,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
reasoning_format=self.node_data.reasoning_format,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
|
||||
structured_output: LLMStructuredOutput | None = None
|
||||
@@ -301,6 +326,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
"reasoning_content": reasoning_content,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"finish_reason": finish_reason,
|
||||
"context": llm_utils.build_context(prompt_messages, clean_text),
|
||||
}
|
||||
if structured_output:
|
||||
outputs["structured_output"] = structured_output.structured_output
|
||||
@@ -367,6 +393,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
node_id: str,
|
||||
node_type: NodeType,
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged",
|
||||
tenant_id: str | None = None,
|
||||
) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
node_data_model.name, model_instance.credentials
|
||||
@@ -390,6 +417,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
stop=list(stop or []),
|
||||
stream=True,
|
||||
user=user_id,
|
||||
tenant_id=tenant_id,
|
||||
)
|
||||
else:
|
||||
request_start_time = time.perf_counter()
|
||||
@@ -581,6 +609,212 @@ class LLMNode(Node[LLMNodeData]):
|
||||
|
||||
return messages
|
||||
|
||||
def _parse_prompt_template(
|
||||
self,
|
||||
) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
|
||||
"""
|
||||
Parse prompt_template to separate static messages and context references.
|
||||
|
||||
Returns:
|
||||
Tuple of (static_messages, context_refs, template_order)
|
||||
- static_messages: list of LLMNodeChatModelMessage
|
||||
- context_refs: list of PromptMessageContext
|
||||
- template_order: list of (index, type) tuples preserving original order
|
||||
"""
|
||||
prompt_template = self.node_data.prompt_template
|
||||
static_messages: list[LLMNodeChatModelMessage] = []
|
||||
context_refs: list[PromptMessageContext] = []
|
||||
template_order: list[tuple[int, str]] = []
|
||||
|
||||
if isinstance(prompt_template, list):
|
||||
for idx, item in enumerate(prompt_template):
|
||||
if isinstance(item, PromptMessageContext):
|
||||
context_refs.append(item)
|
||||
template_order.append((idx, "context"))
|
||||
else:
|
||||
static_messages.append(item)
|
||||
template_order.append((idx, "static"))
|
||||
# Transform static messages for jinja2
|
||||
if static_messages:
|
||||
self.node_data.prompt_template = self._transform_chat_messages(static_messages)
|
||||
|
||||
return static_messages, context_refs, template_order
|
||||
|
||||
def _build_prompt_messages_with_context(
|
||||
self,
|
||||
*,
|
||||
context_refs: list[PromptMessageContext],
|
||||
template_order: list[tuple[int, str]],
|
||||
static_messages: list[LLMNodeChatModelMessage],
|
||||
query: str | None,
|
||||
files: Sequence[File],
|
||||
context: str | None,
|
||||
memory: BaseMemory | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
context_files: list[File],
|
||||
) -> tuple[list[PromptMessage], Sequence[str] | None]:
|
||||
"""
|
||||
Build prompt messages by combining static messages and context references in DSL order.
|
||||
|
||||
Returns:
|
||||
Tuple of (prompt_messages, stop_sequences)
|
||||
"""
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
# Process messages in DSL order: iterate once and handle each type directly
|
||||
combined_messages: list[PromptMessage] = []
|
||||
context_idx = 0
|
||||
static_idx = 0
|
||||
|
||||
for _, type_ in template_order:
|
||||
if type_ == "context":
|
||||
# Handle context reference
|
||||
ctx_ref = context_refs[context_idx]
|
||||
ctx_var = variable_pool.get(ctx_ref.value_selector)
|
||||
if ctx_var is None:
|
||||
raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
|
||||
if not isinstance(ctx_var, ArrayPromptMessageSegment):
|
||||
raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
|
||||
# Restore multimodal content (base64/url) that was truncated when saving context
|
||||
restored_messages = llm_utils.restore_multimodal_content_in_messages(ctx_var.value)
|
||||
combined_messages.extend(restored_messages)
|
||||
context_idx += 1
|
||||
else:
|
||||
# Handle static message
|
||||
static_msg = static_messages[static_idx]
|
||||
processed_msgs = LLMNode.handle_list_messages(
|
||||
messages=[static_msg],
|
||||
context=context,
|
||||
jinja2_variables=self.node_data.prompt_config.jinja2_variables or [],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=self.node_data.vision.configs.detail,
|
||||
)
|
||||
combined_messages.extend(processed_msgs)
|
||||
static_idx += 1
|
||||
|
||||
# Append memory messages
|
||||
memory_messages = _handle_memory_chat_mode(
|
||||
memory=memory,
|
||||
memory_config=self.node_data.memory,
|
||||
model_config=model_config,
|
||||
)
|
||||
combined_messages.extend(memory_messages)
|
||||
|
||||
# Append current query if provided
|
||||
if query:
|
||||
query_message = LLMNodeChatModelMessage(
|
||||
text=query,
|
||||
role=PromptMessageRole.USER,
|
||||
edition_type="basic",
|
||||
)
|
||||
query_msgs = LLMNode.handle_list_messages(
|
||||
messages=[query_message],
|
||||
context="",
|
||||
jinja2_variables=[],
|
||||
variable_pool=variable_pool,
|
||||
vision_detail_config=self.node_data.vision.configs.detail,
|
||||
)
|
||||
combined_messages.extend(query_msgs)
|
||||
|
||||
# Handle files (sys_files and context_files)
|
||||
combined_messages = self._append_files_to_messages(
|
||||
messages=combined_messages,
|
||||
sys_files=files,
|
||||
context_files=context_files,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
# Filter empty messages and get stop sequences
|
||||
combined_messages = self._filter_messages(combined_messages, model_config)
|
||||
stop = self._get_stop_sequences(model_config)
|
||||
|
||||
return combined_messages, stop
|
||||
|
||||
def _append_files_to_messages(
|
||||
self,
|
||||
*,
|
||||
messages: list[PromptMessage],
|
||||
sys_files: Sequence[File],
|
||||
context_files: list[File],
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> list[PromptMessage]:
|
||||
"""Append sys_files and context_files to messages."""
|
||||
vision_enabled = self.node_data.vision.enabled
|
||||
vision_detail = self.node_data.vision.configs.detail
|
||||
|
||||
# Handle sys_files (will be deprecated later)
|
||||
if vision_enabled and sys_files:
|
||||
file_prompts = [
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail) for file in sys_files
|
||||
]
|
||||
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
||||
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
||||
else:
|
||||
messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
# Handle context_files
|
||||
if vision_enabled and context_files:
|
||||
file_prompts = [
|
||||
file_manager.to_prompt_message_content(file, image_detail_config=vision_detail)
|
||||
for file in context_files
|
||||
]
|
||||
if messages and isinstance(messages[-1], UserPromptMessage) and isinstance(messages[-1].content, list):
|
||||
messages[-1] = UserPromptMessage(content=file_prompts + messages[-1].content)
|
||||
else:
|
||||
messages.append(UserPromptMessage(content=file_prompts))
|
||||
|
||||
return messages
|
||||
|
||||
def _filter_messages(
|
||||
self, messages: list[PromptMessage], model_config: ModelConfigWithCredentialsEntity
|
||||
) -> list[PromptMessage]:
|
||||
"""Filter empty messages and unsupported content types."""
|
||||
filtered_messages: list[PromptMessage] = []
|
||||
|
||||
for message in messages:
|
||||
if isinstance(message.content, list):
|
||||
filtered_content: list[PromptMessageContentUnionTypes] = []
|
||||
for content_item in message.content:
|
||||
# Skip non-text content if features are not defined
|
||||
if not model_config.model_schema.features:
|
||||
if content_item.type != PromptMessageContentType.TEXT:
|
||||
continue
|
||||
filtered_content.append(content_item)
|
||||
continue
|
||||
|
||||
# Skip content if corresponding feature is not supported
|
||||
feature_map = {
|
||||
PromptMessageContentType.IMAGE: ModelFeature.VISION,
|
||||
PromptMessageContentType.DOCUMENT: ModelFeature.DOCUMENT,
|
||||
PromptMessageContentType.VIDEO: ModelFeature.VIDEO,
|
||||
PromptMessageContentType.AUDIO: ModelFeature.AUDIO,
|
||||
}
|
||||
required_feature = feature_map.get(content_item.type)
|
||||
if required_feature and required_feature not in model_config.model_schema.features:
|
||||
continue
|
||||
filtered_content.append(content_item)
|
||||
|
||||
# Simplify single text content
|
||||
if len(filtered_content) == 1 and filtered_content[0].type == PromptMessageContentType.TEXT:
|
||||
message.content = filtered_content[0].data
|
||||
else:
|
||||
message.content = filtered_content
|
||||
|
||||
if not message.is_empty():
|
||||
filtered_messages.append(message)
|
||||
|
||||
if not filtered_messages:
|
||||
raise NoPromptFoundError(
|
||||
"No prompt found in the LLM configuration. "
|
||||
"Please ensure a prompt is properly configured before proceeding."
|
||||
)
|
||||
|
||||
return filtered_messages
|
||||
|
||||
def _get_stop_sequences(self, model_config: ModelConfigWithCredentialsEntity) -> Sequence[str] | None:
|
||||
"""Get stop sequences from model config."""
|
||||
return model_config.stop
|
||||
|
||||
def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
@@ -778,7 +1012,7 @@ class LLMNode(Node[LLMNodeData]):
|
||||
sys_query: str | None = None,
|
||||
sys_files: Sequence[File],
|
||||
context: str | None = None,
|
||||
memory: TokenBufferMemory | None = None,
|
||||
memory: BaseMemory | None = None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
|
||||
memory_config: MemoryConfig | None = None,
|
||||
@@ -1337,7 +1571,7 @@ def _calculate_rest_token(
|
||||
|
||||
def _handle_memory_chat_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> Sequence[PromptMessage]:
|
||||
@@ -1354,7 +1588,7 @@ def _handle_memory_chat_mode(
|
||||
|
||||
def _handle_memory_completion_mode(
|
||||
*,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
memory_config: MemoryConfig | None,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
) -> str:
|
||||
|
||||
@@ -7,7 +7,7 @@ from typing import Any, cast
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.file import File
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import ImagePromptMessageContent
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
@@ -145,8 +145,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
|
||||
if (
|
||||
@@ -244,6 +246,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
# transform result into standard format
|
||||
result = self._transform_result(data=node_data, result=result or {})
|
||||
|
||||
# Build context from prompt messages and response
|
||||
assistant_response = json.dumps(result, ensure_ascii=False)
|
||||
context = llm_utils.build_context(prompt_messages, assistant_response)
|
||||
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs=inputs,
|
||||
@@ -252,6 +258,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
"__is_success": 1 if not error else 0,
|
||||
"__reason": error,
|
||||
"__usage": jsonable_encoder(usage),
|
||||
"context": context,
|
||||
**result,
|
||||
},
|
||||
metadata={
|
||||
@@ -299,7 +306,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
|
||||
@@ -381,7 +388,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -419,7 +426,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -453,7 +460,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
model_config: ModelConfigWithCredentialsEntity,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
files: Sequence[File],
|
||||
vision_detail: ImagePromptMessageContent.DETAIL | None = None,
|
||||
) -> list[PromptMessage]:
|
||||
@@ -681,7 +688,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
) -> list[ChatModelMessage]:
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
@@ -708,7 +715,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
|
||||
node_data: ParameterExtractorNodeData,
|
||||
query: str,
|
||||
variable_pool: VariablePool,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
||||
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
@@ -96,8 +96,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
memory = llm_utils.fetch_memory(
|
||||
variable_pool=variable_pool,
|
||||
app_id=self.app_id,
|
||||
tenant_id=self.tenant_id,
|
||||
node_data_memory=node_data.memory,
|
||||
model_instance=model_instance,
|
||||
node_id=self._node_id,
|
||||
)
|
||||
# fetch instruction
|
||||
node_data.instruction = node_data.instruction or ""
|
||||
@@ -197,10 +199,15 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
"model_provider": model_config.provider,
|
||||
"model_name": model_config.model,
|
||||
}
|
||||
# Build context from prompt messages and response
|
||||
assistant_response = f"class_name: {category_name}, class_id: {category_id}"
|
||||
context = llm_utils.build_context(prompt_messages, assistant_response)
|
||||
|
||||
outputs = {
|
||||
"class_name": category_name,
|
||||
"class_id": category_id,
|
||||
"usage": jsonable_encoder(usage),
|
||||
"context": context,
|
||||
}
|
||||
|
||||
return NodeRunResult(
|
||||
@@ -312,7 +319,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
|
||||
self,
|
||||
node_data: QuestionClassifierNodeData,
|
||||
query: str,
|
||||
memory: TokenBufferMemory | None,
|
||||
memory: BaseMemory | None,
|
||||
max_token_limit: int = 2000,
|
||||
):
|
||||
model_mode = ModelMode(node_data.model.mode)
|
||||
|
||||
@@ -1,11 +1,63 @@
from typing import Any, Literal, Union
import re
from collections.abc import Sequence
from typing import Any, Literal, Self, Union

from pydantic import BaseModel, field_validator
from pydantic import BaseModel, field_validator, model_validator
from pydantic_core.core_schema import ValidationInfo

from core.tools.entities.tool_entities import ToolProviderType
from core.workflow.nodes.base.entities import BaseNodeData

# Pattern to match mention value format: {{@node.context@}}instruction
# The placeholder {{@node.context@}} must appear at the beginning
# Format: {{@agent_node_id.context@}} where agent_node_id is dynamic, context is fixed
MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)


def parse_mention_value(value: str) -> tuple[str, str]:
    """Parse mention value into (node_id, instruction).

    Args:
        value: The mention value string like "{{@llm.context@}}extract keywords"

    Returns:
        Tuple of (node_id, instruction)

    Raises:
        ValueError: If value format is invalid
    """
    match = MENTION_VALUE_PATTERN.match(value)
    if not match:
        raise ValueError(
            "For mention type, value must start with {{@node.context@}} placeholder, "
            "e.g., '{{@llm.context@}}extract keywords'"
        )
    return match.group(1), match.group(2)
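A short usage sketch of the helper above, based on the example in its docstring:

# "llm" is the referenced node id; the rest of the string is the extraction instruction.
node_id, instruction = parse_mention_value("{{@llm.context@}}extract keywords")
# node_id == "llm", instruction == "extract keywords"

# A value without the leading placeholder raises ValueError:
# parse_mention_value("extract keywords")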


class MentionConfig(BaseModel):
    """Configuration for extracting value from context variable.

    Used when a tool parameter needs to be extracted from list[PromptMessage]
    context using an extractor LLM node.

    Note: instruction is embedded in the value field as "{{@node.context@}}instruction"
    """

    # ID of the extractor LLM node
    extractor_node_id: str

    # Output variable selector from extractor node
    # e.g., ["text"], ["structured_output", "query"]
    output_selector: Sequence[str]

    # Strategy when output is None
    null_strategy: Literal["raise_error", "use_default"] = "raise_error"

    # Default value when null_strategy is "use_default"
    # Type should match the parameter's expected type
    default_value: Any = None


class ToolEntity(BaseModel):
    provider_id: str
@@ -35,7 +87,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):
    class ToolInput(BaseModel):
        # TODO: check this type
        value: Union[Any, list[str]]
        type: Literal["mixed", "variable", "constant"]
        type: Literal["mixed", "variable", "constant", "mention"]
        # Required config for mention type, extracting value from context variable
        mention_config: MentionConfig | None = None

        @field_validator("type", mode="before")
        @classmethod
@@ -48,6 +102,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):

            if typ == "mixed" and not isinstance(value, str):
                raise ValueError("value must be a string")
            elif typ == "mention":
                # Skip here, will be validated in model_validator
                pass
            elif typ == "variable":
                if not isinstance(value, list):
                    raise ValueError("value must be a list")
@@ -58,6 +115,26 @@ class ToolNodeData(BaseNodeData, ToolEntity):
                raise ValueError("value must be a string, int, float, bool or dict")
            return typ

        @model_validator(mode="after")
        def check_mention_type(self) -> Self:
            """Validate mention type with mention_config."""
            if self.type != "mention":
                return self

            value = self.value
            if value is None:
                return self

            if not isinstance(value, str):
                raise ValueError("value must be a string for mention type")
            # For mention type, value must match format: {{@node.context@}}instruction
            # This will raise ValueError if format is invalid
            parse_mention_value(value)
            # mention_config is required for mention type
            if self.mention_config is None:
                raise ValueError("mention_config is required for mention type")
            return self

    tool_parameters: dict[str, ToolInput]
    # The version of the tool parameter.
    # If this value is None, it indicates this is a previous version

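Putting the pieces together, a mention-type tool input could be declared like this (the node ids and the instruction text are hypothetical):

# Hypothetical mention-type parameter: value carries the placeholder plus instruction,
# mention_config says where to read the extractor node's output from.
tool_input = ToolNodeData.ToolInput(
    type="mention",
    value="{{@llm.context@}}extract the search keywords",
    mention_config=MentionConfig(
        extractor_node_id="query_extractor",
        output_selector=["structured_output", "query"],
        null_strategy="use_default",
        default_value="",
    ),
)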
@@ -1,7 +1,10 @@
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
@@ -184,6 +187,7 @@ class ToolNode(Node[ToolNodeData]):
|
||||
tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (ToolNodeData): The data associated with the tool node.
|
||||
for_log (bool): Whether to generate parameters for logging.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
@@ -199,14 +203,37 @@ class ToolNode(Node[ToolNodeData]):
|
||||
continue
|
||||
tool_input = node_data.tool_parameters[parameter_name]
|
||||
if tool_input.type == "variable":
|
||||
variable = variable_pool.get(tool_input.value)
|
||||
if not isinstance(tool_input.value, list):
|
||||
raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
|
||||
selector = tool_input.value
|
||||
variable = variable_pool.get(selector)
|
||||
if variable is None:
|
||||
if parameter.required:
|
||||
raise ToolParameterError(f"Variable {tool_input.value} does not exist")
|
||||
raise ToolParameterError(f"Variable {selector} does not exist")
|
||||
continue
|
||||
parameter_value = variable.value
|
||||
elif tool_input.type == "mention":
|
||||
# Mention type: get value from extractor node's output
|
||||
if tool_input.mention_config is None:
|
||||
raise ToolParameterError(
|
||||
f"mention_config is required for mention type parameter '{parameter_name}'"
|
||||
)
|
||||
mention_config = tool_input.mention_config.model_dump()
|
||||
try:
|
||||
parameter_value, found = variable_pool.resolve_mention(
|
||||
mention_config, parameter_name=parameter_name
|
||||
)
|
||||
if not found and parameter.required:
|
||||
raise ToolParameterError(
|
||||
f"Extractor output not found for required parameter '{parameter_name}'"
|
||||
)
|
||||
if not found:
|
||||
continue
|
||||
except ValueError as e:
|
||||
raise ToolParameterError(str(e)) from e
|
||||
elif tool_input.type in {"mixed", "constant"}:
|
||||
segment_group = variable_pool.convert_template(str(tool_input.value))
|
||||
template = str(tool_input.value)
|
||||
segment_group = variable_pool.convert_template(template)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
else:
|
||||
raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
|
||||
@@ -488,8 +515,12 @@ class ToolNode(Node[ToolNodeData]):
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
elif input.type == "variable":
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
if isinstance(input.value, list):
|
||||
selector_key = ".".join(input.value)
|
||||
result[f"#{selector_key}#"] = input.value
|
||||
elif input.type == "mention":
|
||||
# Mention type: value is handled by extractor node, no direct variable reference
|
||||
pass
|
||||
elif input.type == "constant":
|
||||
pass
|
||||
|
||||
|
||||
@@ -268,6 +268,58 @@ class VariablePool(BaseModel):
                continue
            self.add(selector, value)

    def resolve_mention(
        self,
        mention_config: Mapping[str, Any],
        /,
        *,
        parameter_name: str = "",
    ) -> tuple[Any, bool]:
        """
        Resolve a mention parameter value from an extractor node's output.

        Mention parameters reference values extracted by an extractor LLM node
        from list[PromptMessage] context.

        Args:
            mention_config: A dict containing:
                - extractor_node_id: ID of the extractor LLM node
                - output_selector: Selector path for the output variable (e.g., ["text"])
                - null_strategy: "raise_error" or "use_default"
                - default_value: Value to use when null_strategy is "use_default"
            parameter_name: Name of the parameter being resolved (for error messages)

        Returns:
            Tuple of (resolved_value, found):
                - resolved_value: The extracted value, or default_value if not found
                - found: True if value was found, False if using default

        Raises:
            ValueError: If extractor_node_id is missing, or if null_strategy is
                "raise_error" and the value is not found
        """
        extractor_node_id = mention_config.get("extractor_node_id")
        if not extractor_node_id:
            raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")

        output_selector = list(mention_config.get("output_selector", []))
        null_strategy = mention_config.get("null_strategy", "raise_error")
        default_value = mention_config.get("default_value")

        # Build full selector: [extractor_node_id, ...output_selector]
        full_selector = [extractor_node_id] + output_selector
        variable = self.get(full_selector)

        if variable is None:
            if null_strategy == "use_default":
                return default_value, False
            raise ValueError(
                f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
                f"not found for parameter '{parameter_name}'"
            )

        return variable.value, True
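A hedged usage sketch, mirroring how the tool node resolves a mention parameter in the changes shown earlier (node id and selector are hypothetical):

# mention_config is positional-only, parameter_name keyword-only (see the signature above).
value, found = variable_pool.resolve_mention(
    {
        "extractor_node_id": "query_extractor",
        "output_selector": ["text"],
        "null_strategy": "use_default",
        "default_value": "",
    },
    parameter_name="query",
)
if not found:
    pass  # fall back or skip the optional parameter, as ToolNode does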

    @classmethod
    def empty(cls) -> VariablePool:
        """Create an empty variable pool."""

@@ -4,6 +4,7 @@ from uuid import uuid4
|
||||
|
||||
from configs import dify_config
|
||||
from core.file import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.variables.exc import VariableError
|
||||
from core.variables.segments import (
|
||||
ArrayAnySegment,
|
||||
@@ -11,6 +12,7 @@ from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArrayNumberSegment,
|
||||
ArrayObjectSegment,
|
||||
ArrayPromptMessageSegment,
|
||||
ArraySegment,
|
||||
ArrayStringSegment,
|
||||
BooleanSegment,
|
||||
@@ -29,6 +31,7 @@ from core.variables.variables import (
|
||||
ArrayFileVariable,
|
||||
ArrayNumberVariable,
|
||||
ArrayObjectVariable,
|
||||
ArrayPromptMessageVariable,
|
||||
ArrayStringVariable,
|
||||
BooleanVariable,
|
||||
FileVariable,
|
||||
@@ -61,6 +64,7 @@ SEGMENT_TO_VARIABLE_MAP = {
|
||||
ArrayFileSegment: ArrayFileVariable,
|
||||
ArrayNumberSegment: ArrayNumberVariable,
|
||||
ArrayObjectSegment: ArrayObjectVariable,
|
||||
ArrayPromptMessageSegment: ArrayPromptMessageVariable,
|
||||
ArrayStringSegment: ArrayStringVariable,
|
||||
BooleanSegment: BooleanVariable,
|
||||
FileSegment: FileVariable,
|
||||
@@ -156,7 +160,13 @@ def build_segment(value: Any, /) -> Segment:
|
||||
return ObjectSegment(value=value)
|
||||
if isinstance(value, File):
|
||||
return FileSegment(value=value)
|
||||
if isinstance(value, PromptMessage):
|
||||
# Single PromptMessage should be wrapped in a list
|
||||
return ArrayPromptMessageSegment(value=[value])
|
||||
if isinstance(value, list):
|
||||
# Check if all items are PromptMessage
|
||||
if value and all(isinstance(item, PromptMessage) for item in value):
|
||||
return ArrayPromptMessageSegment(value=value)
|
||||
items = [build_segment(item) for item in value]
|
||||
types = {item.value_type for item in items}
|
||||
if all(isinstance(item, ArraySegment) for item in items):
|
||||
@@ -200,6 +210,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
|
||||
SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
|
||||
SegmentType.ARRAY_FILE: ArrayFileSegment,
|
||||
SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
|
||||
SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
|
||||
}
|
||||
|
||||
|
||||
@@ -274,6 +285,10 @@ def build_segment_with_type(segment_type: SegmentType, value: Any) -> Segment:
|
||||
):
|
||||
segment_class = _segment_factory[inferred_type]
|
||||
return segment_class(value_type=inferred_type, value=value)
|
||||
elif segment_type == SegmentType.ARRAY_PROMPT_MESSAGE and inferred_type == SegmentType.ARRAY_OBJECT:
|
||||
# PromptMessage serializes to dict, so ARRAY_OBJECT is compatible with ARRAY_PROMPT_MESSAGE
|
||||
segment_class = _segment_factory[segment_type]
|
||||
return segment_class(value_type=segment_type, value=value)
|
||||
else:
|
||||
raise TypeMismatchError(f"Type mismatch: expected {segment_type}, but got {inferred_type}, value={value}")
|
||||
|
||||
|
||||
@@ -1289,7 +1289,7 @@ class WorkflowDraftVariable(Base):
|
||||
# which may differ from the original value's type. Typically, they are the same,
|
||||
# but in cases where the structurally truncated value still exceeds the size limit,
|
||||
# text slicing is applied, and the `value_type` is converted to `STRING`.
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
|
||||
value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))
|
||||
|
||||
# The variable's value serialized as a JSON string
|
||||
#
|
||||
@@ -1663,7 +1663,7 @@ class WorkflowDraftVariableFile(Base):
|
||||
|
||||
# The `value_type` field records the type of the original value.
|
||||
value_type: Mapped[SegmentType] = mapped_column(
|
||||
EnumText(SegmentType, length=20),
|
||||
EnumText(SegmentType, length=21),
|
||||
nullable=False,
|
||||
)
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload
|
||||
|
||||
from configs import dify_config
|
||||
from core.file.models import File
|
||||
from core.model_runtime.entities import PromptMessage
|
||||
from core.variables.segments import (
|
||||
ArrayFileSegment,
|
||||
ArraySegment,
|
||||
@@ -287,6 +288,10 @@ class VariableTruncator(BaseTruncator):
|
||||
if isinstance(item, File):
|
||||
truncated_value.append(item)
|
||||
continue
|
||||
# Handle PromptMessage types - convert to dict for truncation
|
||||
if isinstance(item, PromptMessage):
|
||||
truncated_value.append(item)
|
||||
continue
|
||||
if i >= target_length:
|
||||
return _PartResult(truncated_value, used_size, True)
|
||||
if i > 0:
|
||||
|
||||
api/tests/fixtures/file output schema.yml (new vendored file, 181 lines)
@@ -0,0 +1,181 @@
|
||||
app:
|
||||
description: ''
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: file output schema
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
|
||||
version: null
|
||||
kind: app
|
||||
version: 0.5.0
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- remote_url
|
||||
- local_file
|
||||
enabled: true
|
||||
fileUploadConfig:
|
||||
attachment_image_file_size_limit: 2
|
||||
audio_file_size_limit: 50
|
||||
batch_count_limit: 5
|
||||
file_size_limit: 15
|
||||
file_upload_limit: 10
|
||||
image_file_batch_limit: 10
|
||||
image_file_size_limit: 10
|
||||
single_chunk_attachment_limit: 10
|
||||
video_file_size_limit: 100
|
||||
workflow_file_upload_limit: 10
|
||||
number_limits: 3
|
||||
opening_statement: ''
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
id: 1768292241666-llm
|
||||
source: '1768292241666'
|
||||
sourceHandle: source
|
||||
target: llm
|
||||
targetHandle: target
|
||||
type: custom
|
||||
- data:
|
||||
sourceType: llm
|
||||
targetType: answer
|
||||
id: llm-answer
|
||||
source: llm
|
||||
sourceHandle: source
|
||||
target: answer
|
||||
targetHandle: target
|
||||
type: custom
|
||||
nodes:
|
||||
- data:
|
||||
selected: false
|
||||
title: User Input
|
||||
type: start
|
||||
variables: []
|
||||
height: 73
|
||||
id: '1768292241666'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
memory:
|
||||
query_prompt_template: '{{#sys.query#}}
|
||||
|
||||
|
||||
{{#sys.files#}}'
|
||||
role_prefix:
|
||||
assistant: ''
|
||||
user: ''
|
||||
window:
|
||||
enabled: false
|
||||
size: 10
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: gpt-4o-mini
|
||||
provider: langgenius/openai/openai
|
||||
prompt_template:
|
||||
- id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
|
||||
role: system
|
||||
text: ''
|
||||
selected: false
|
||||
structured_output:
|
||||
schema:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
image:
|
||||
description: File ID (UUID) of the selected image
|
||||
format: dify-file-ref
|
||||
type: string
|
||||
required:
|
||||
- image
|
||||
type: object
|
||||
structured_output_enabled: true
|
||||
title: LLM
|
||||
type: llm
|
||||
vision:
|
||||
configs:
|
||||
detail: high
|
||||
variable_selector:
|
||||
- sys
|
||||
- files
|
||||
enabled: true
|
||||
height: 88
|
||||
id: llm
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
answer: '{{#llm.structured_output.image#}}'
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
variables: []
|
||||
height: 103
|
||||
id: answer
|
||||
position:
|
||||
x: 680
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 680
|
||||
y: 282
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
viewport:
|
||||
x: -149
|
||||
y: 97.5
|
||||
zoom: 1
|
||||
rag_pipeline_variables: []
|
||||
api/tests/fixtures/pav-test-extraction.yml (new vendored file, 307 lines)
@@ -0,0 +1,307 @@
|
||||
app:
|
||||
description: Test for variable extraction feature
|
||||
icon: 🤖
|
||||
icon_background: '#FFEAD5'
|
||||
mode: advanced-chat
|
||||
name: pav-test-extraction
|
||||
use_icon_as_answer_icon: false
|
||||
dependencies:
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||
version: null
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
|
||||
version: null
|
||||
- current_identifier: null
|
||||
type: marketplace
|
||||
value:
|
||||
marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
|
||||
version: null
|
||||
kind: app
|
||||
version: 0.5.0
|
||||
workflow:
|
||||
conversation_variables: []
|
||||
environment_variables: []
|
||||
features:
|
||||
file_upload:
|
||||
allowed_file_extensions:
|
||||
- .JPG
|
||||
- .JPEG
|
||||
- .PNG
|
||||
- .GIF
|
||||
- .WEBP
|
||||
- .SVG
|
||||
allowed_file_types:
|
||||
- image
|
||||
allowed_file_upload_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
enabled: false
|
||||
image:
|
||||
enabled: false
|
||||
number_limits: 3
|
||||
transfer_methods:
|
||||
- local_file
|
||||
- remote_url
|
||||
number_limits: 3
|
||||
opening_statement: 你好!我是一个搜索助手,请告诉我你想搜索什么内容。
|
||||
retriever_resource:
|
||||
enabled: true
|
||||
sensitive_word_avoidance:
|
||||
enabled: false
|
||||
speech_to_text:
|
||||
enabled: false
|
||||
suggested_questions: []
|
||||
suggested_questions_after_answer:
|
||||
enabled: false
|
||||
text_to_speech:
|
||||
enabled: false
|
||||
language: ''
|
||||
voice: ''
|
||||
graph:
|
||||
edges:
|
||||
- data:
|
||||
sourceType: start
|
||||
targetType: llm
|
||||
id: 1767773675796-llm
|
||||
source: '1767773675796'
|
||||
sourceHandle: source
|
||||
target: llm
|
||||
targetHandle: target
|
||||
type: custom
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: llm
|
||||
targetType: tool
|
||||
id: llm-source-1767773709491-target
|
||||
source: llm
|
||||
sourceHandle: source
|
||||
target: '1767773709491'
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
- data:
|
||||
isInIteration: false
|
||||
isInLoop: false
|
||||
sourceType: tool
|
||||
targetType: answer
|
||||
id: tool-source-answer-target
|
||||
source: '1767773709491'
|
||||
sourceHandle: source
|
||||
target: answer
|
||||
targetHandle: target
|
||||
type: custom
|
||||
zIndex: 0
|
||||
nodes:
|
||||
- data:
|
||||
selected: false
|
||||
title: User Input
|
||||
type: start
|
||||
variables: []
|
||||
height: 73
|
||||
id: '1767773675796'
|
||||
position:
|
||||
x: 80
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 80
|
||||
y: 282
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
memory:
|
||||
mode: node
|
||||
query_prompt_template: '{{#sys.query#}}'
|
||||
role_prefix:
|
||||
assistant: ''
|
||||
user: ''
|
||||
window:
|
||||
enabled: true
|
||||
size: 10
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: qwen-max
|
||||
provider: langgenius/tongyi/tongyi
|
||||
prompt_template:
|
||||
- id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
|
||||
role: system
|
||||
text: '你是一个智能搜索助手。用户会告诉你他们想搜索的内容。
|
||||
|
||||
请与用户进行对话,了解他们的搜索需求。
|
||||
|
||||
当用户明确表达了想要搜索的内容后,你可以回复"好的,我来帮你搜索"。
|
||||
|
||||
'
|
||||
selected: false
|
||||
title: LLM
|
||||
type: llm
|
||||
vision:
|
||||
enabled: false
|
||||
height: 88
|
||||
id: llm
|
||||
position:
|
||||
x: 380
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 380
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
is_team_authorization: true
|
||||
paramSchemas:
|
||||
- auto_generate: null
|
||||
default: null
|
||||
form: llm
|
||||
human_description:
|
||||
en_US: used for searching
|
||||
ja_JP: used for searching
|
||||
pt_BR: used for searching
|
||||
zh_Hans: 用于搜索网页内容
|
||||
label:
|
||||
en_US: Query string
|
||||
ja_JP: Query string
|
||||
pt_BR: Query string
|
||||
zh_Hans: 查询语句
|
||||
llm_description: key words for searching
|
||||
max: null
|
||||
min: null
|
||||
name: query
|
||||
options: []
|
||||
placeholder: null
|
||||
precision: null
|
||||
required: true
|
||||
scope: null
|
||||
template: null
|
||||
type: string
|
||||
params:
|
||||
query: ''
|
||||
plugin_id: langgenius/google
|
||||
plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
|
||||
provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
|
||||
provider_id: langgenius/google/google
|
||||
provider_name: langgenius/google/google
|
||||
provider_type: builtin
|
||||
selected: false
|
||||
title: GoogleSearch
|
||||
tool_configurations: {}
|
||||
tool_description: A tool for performing a Google SERP search and extracting
|
||||
snippets and webpages.Input should be a search query.
|
||||
tool_label: GoogleSearch
|
||||
tool_name: google_search
|
||||
tool_node_version: '2'
|
||||
tool_parameters:
|
||||
query:
|
||||
type: mention
|
||||
value: '{{@llm.context@}}请从对话历史中提取用户想要搜索的关键词,只返回关键词本身'
|
||||
mention_config:
|
||||
extractor_node_id: 1767773709491_ext_query
|
||||
output_selector:
|
||||
- structured_output
|
||||
- query
|
||||
null_strategy: use_default
|
||||
default_value: ''
|
||||
type: tool
|
||||
height: 52
|
||||
id: '1767773709491'
|
||||
position:
|
||||
x: 682
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 682
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
context:
|
||||
enabled: false
|
||||
variable_selector: []
|
||||
model:
|
||||
completion_params:
|
||||
temperature: 0.7
|
||||
mode: chat
|
||||
name: gpt-4o-mini
|
||||
provider: langgenius/openai/openai
|
||||
parent_node_id: '1767773709491'
|
||||
prompt_template:
|
||||
- $context:
|
||||
- llm
|
||||
- context
|
||||
id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
|
||||
- id: 18ba6710-77f5-47f4-b144-9191833bb547
|
||||
role: user
|
||||
text: 请从对话历史中提取用户想要搜索的关键词,只返回关键词本身,不要返回其他内容
|
||||
selected: false
|
||||
structured_output:
|
||||
schema:
|
||||
additionalProperties: false
|
||||
properties:
|
||||
query:
|
||||
description: 搜索的关键词
|
||||
type: string
|
||||
required:
|
||||
- query
|
||||
type: object
|
||||
structured_output_enabled: true
|
||||
title: 提取搜索关键词
|
||||
type: llm
|
||||
vision:
|
||||
enabled: false
|
||||
height: 88
|
||||
id: 1767773709491_ext_query
|
||||
position:
|
||||
x: 531
|
||||
y: 382
|
||||
positionAbsolute:
|
||||
x: 531
|
||||
y: 382
|
||||
selected: true
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
- data:
|
||||
answer: '搜索结果:
|
||||
|
||||
{{#1767773709491.text#}}
|
||||
|
||||
'
|
||||
selected: false
|
||||
title: Answer
|
||||
type: answer
|
||||
height: 103
|
||||
id: answer
|
||||
position:
|
||||
x: 984
|
||||
y: 282
|
||||
positionAbsolute:
|
||||
x: 984
|
||||
y: 282
|
||||
selected: false
|
||||
sourcePosition: right
|
||||
targetPosition: left
|
||||
type: custom
|
||||
width: 242
|
||||
viewport:
|
||||
x: -151
|
||||
y: 123
|
||||
zoom: 1
|
||||
rag_pipeline_variables: []
|
||||
api/tests/unit_tests/core/file/test_file_manager.py (new file, 182 lines)
@@ -0,0 +1,182 @@
"""Tests for file_manager module, specifically multimodal content handling."""

from unittest.mock import patch

from core.file import File, FileTransferMethod, FileType
from core.file.file_manager import (
    _encode_file_ref,
    restore_multimodal_content,
    to_prompt_message_content,
)
from core.model_runtime.entities.message_entities import ImagePromptMessageContent


class TestEncodeFileRef:
    """Tests for _encode_file_ref function."""

    def test_encodes_local_file(self):
        """Local file should be encoded as 'local:id'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="abc123",
            storage_key="key",
        )
        assert _encode_file_ref(file) == "local:abc123"

    def test_encodes_tool_file(self):
        """Tool file should be encoded as 'tool:id'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.TOOL_FILE,
            related_id="xyz789",
            storage_key="key",
        )
        assert _encode_file_ref(file) == "tool:xyz789"

    def test_encodes_remote_url(self):
        """Remote URL should be encoded as 'remote:url'."""
        file = File(
            tenant_id="t",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.REMOTE_URL,
            remote_url="https://example.com/image.png",
            storage_key="",
        )
        assert _encode_file_ref(file) == "remote:https://example.com/image.png"


class TestToPromptMessageContent:
    """Tests for to_prompt_message_content function with file_ref field."""

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._get_encoded_string")
    def test_includes_file_ref(self, mock_get_encoded, mock_config):
        """Generated content should include file_ref field."""
        mock_config.MULTIMODAL_SEND_FORMAT = "base64"
        mock_get_encoded.return_value = "base64data"

        file = File(
            id="test-message-file-id",
            tenant_id="test-tenant",
            type=FileType.IMAGE,
            transfer_method=FileTransferMethod.LOCAL_FILE,
            related_id="test-related-id",
            remote_url=None,
            extension=".png",
            mime_type="image/png",
            filename="test.png",
            storage_key="test-key",
        )

        result = to_prompt_message_content(file)

        assert isinstance(result, ImagePromptMessageContent)
        assert result.file_ref == "local:test-related-id"
        assert result.base64_data == "base64data"


class TestRestoreMultimodalContent:
    """Tests for restore_multimodal_content function."""

    def test_returns_content_unchanged_when_no_file_ref(self):
        """Content without file_ref should pass through unchanged."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="existing-data",
            mime_type="image/png",
            file_ref=None,
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "existing-data"

    def test_returns_content_unchanged_when_already_has_data(self):
        """Content that already has base64_data should not be reloaded."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="existing-data",
            mime_type="image/png",
            file_ref="local:file-id",
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "existing-data"

    def test_returns_content_unchanged_when_already_has_url(self):
        """Content that already has url should not be reloaded."""
        content = ImagePromptMessageContent(
            format="png",
            url="https://example.com/image.png",
            mime_type="image/png",
            file_ref="local:file-id",
        )

        result = restore_multimodal_content(content)

        assert result.url == "https://example.com/image.png"

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._build_file_from_ref")
    @patch("core.file.file_manager._to_url")
    def test_restores_url_from_file_ref(self, mock_to_url, mock_build_file, mock_config):
        """Content should be restored from file_ref when url is empty (url mode)."""
        mock_config.MULTIMODAL_SEND_FORMAT = "url"
        mock_build_file.return_value = "mock_file"
        mock_to_url.return_value = "https://restored-url.com/image.png"

        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )

        result = restore_multimodal_content(content)

        assert result.url == "https://restored-url.com/image.png"
        mock_build_file.assert_called_once()

    @patch("core.file.file_manager.dify_config")
    @patch("core.file.file_manager._build_file_from_ref")
    @patch("core.file.file_manager._get_encoded_string")
    def test_restores_base64_from_file_ref(self, mock_get_encoded, mock_build_file, mock_config):
        """Content should be restored as base64 when in base64 mode."""
        mock_config.MULTIMODAL_SEND_FORMAT = "base64"
        mock_build_file.return_value = "mock_file"
        mock_get_encoded.return_value = "restored-base64-data"

        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )

        result = restore_multimodal_content(content)

        assert result.base64_data == "restored-base64-data"
        mock_build_file.assert_called_once()

    def test_handles_invalid_file_ref_gracefully(self):
        """Invalid file_ref format should be handled gracefully."""
        content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            url="",
            mime_type="image/png",
            file_ref="invalid_format_no_colon",
        )

        result = restore_multimodal_content(content)

        # Should return unchanged on error
        assert result.base64_data == ""
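Read together, these tests pin down a small contract: `to_prompt_message_content` stamps image content with a compact `file_ref` ("local:", "tool:" or "remote:" plus an identifier), and `restore_multimodal_content` rebuilds the heavy payload only when it is missing. A minimal sketch of that contract, using dict stand-ins and hypothetical loader callbacks rather than the real implementation:

# A minimal sketch of the round trip the tests above describe; not the project code.
# Encoding: File -> "local:<related_id>" / "tool:<related_id>" / "remote:<remote_url>".
# Restoring: when content only carries file_ref, rebuild either a URL or a base64 payload
# depending on MULTIMODAL_SEND_FORMAT; otherwise leave the content untouched.

def encode_file_ref_sketch(transfer_method: str, related_id: str | None, remote_url: str | None) -> str:
    if transfer_method == "local_file":
        return f"local:{related_id}"
    if transfer_method == "tool_file":
        return f"tool:{related_id}"
    return f"remote:{remote_url}"

def restore_sketch(content: dict, send_format: str, load_url, load_base64) -> dict:
    """`content` is a dict stand-in for ImagePromptMessageContent."""
    if not content.get("file_ref") or content.get("base64_data") or content.get("url"):
        return content  # nothing to restore, or payload already present
    prefix, _, ref = content["file_ref"].partition(":")
    if prefix not in ("local", "tool", "remote") or not ref:
        return content  # malformed ref: fail soft, as the last test expects
    if send_format == "url":
        content["url"] = load_url(prefix, ref)
    else:
        content["base64_data"] = load_base64(prefix, ref)
    return content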
@@ -0,0 +1,269 @@
|
||||
"""
|
||||
Unit tests for file reference detection and conversion.
|
||||
"""
|
||||
|
||||
import uuid
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
from core.llm_generator.output_parser.file_ref import (
|
||||
FILE_REF_FORMAT,
|
||||
convert_file_refs_in_output,
|
||||
detect_file_ref_fields,
|
||||
is_file_ref_property,
|
||||
)
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
|
||||
class TestIsFileRefProperty:
|
||||
"""Tests for is_file_ref_property function."""
|
||||
|
||||
def test_valid_file_ref(self):
|
||||
schema = {"type": "string", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is True
|
||||
|
||||
def test_invalid_type(self):
|
||||
schema = {"type": "number", "format": FILE_REF_FORMAT}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_missing_format(self):
|
||||
schema = {"type": "string"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
def test_wrong_format(self):
|
||||
schema = {"type": "string", "format": "uuid"}
|
||||
assert is_file_ref_property(schema) is False
|
||||
|
||||
|
||||
class TestDetectFileRefFields:
|
||||
"""Tests for detect_file_ref_fields function."""
|
||||
|
||||
def test_simple_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["image"]
|
||||
|
||||
def test_multiple_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"document": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"name": {"type": "string"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "document"}
|
||||
|
||||
def test_array_of_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["files[*]"]
|
||||
|
||||
def test_nested_file_ref(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"data": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == ["data.image"]
|
||||
|
||||
def test_no_file_refs(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_empty_schema(self):
|
||||
schema = {}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert paths == []
|
||||
|
||||
def test_mixed_schema(self):
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"documents": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
paths = detect_file_ref_fields(schema)
|
||||
assert set(paths) == {"image", "documents[*]"}
|
||||
|
||||
|
||||
class TestConvertFileRefsInOutput:
|
||||
"""Tests for convert_file_refs_in_output function."""
|
||||
|
||||
@pytest.fixture
|
||||
def mock_file(self):
|
||||
"""Create a mock File object with all required attributes."""
|
||||
file = MagicMock(spec=File)
|
||||
file.type = FileType.IMAGE
|
||||
file.transfer_method = FileTransferMethod.TOOL_FILE
|
||||
file.related_id = "test-related-id"
|
||||
file.remote_url = None
|
||||
file.tenant_id = "tenant_123"
|
||||
file.id = None
|
||||
file.filename = "test.png"
|
||||
file.extension = ".png"
|
||||
file.mime_type = "image/png"
|
||||
file.size = 1024
|
||||
file.dify_model_identity = "__dify__file__"
|
||||
return file
|
||||
|
||||
@pytest.fixture
|
||||
def mock_build_from_mapping(self, mock_file):
|
||||
"""Mock the build_from_mapping function."""
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.return_value = mock_file
|
||||
yield mock
|
||||
|
||||
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in FileSegment
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
mock_build_from_mapping.assert_called_once_with(
|
||||
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||
tenant_id="tenant_123",
|
||||
)
|
||||
|
||||
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
output = {"files": [file_id1, file_id2]}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in ArrayFileSegment
|
||||
assert isinstance(result["files"], ArrayFileSegment)
|
||||
assert list(result["files"].value) == [mock_file, mock_file]
|
||||
assert mock_build_from_mapping.call_count == 2
|
||||
|
||||
def test_no_conversion_without_file_refs(self):
|
||||
output = {"name": "test", "count": 5}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result == {"name": "test", "count": 5}
|
||||
|
||||
def test_invalid_uuid_returns_none(self):
|
||||
output = {"image": "not-a-valid-uuid"}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_file_not_found_returns_none(self):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.side_effect = ValueError("File not found")
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"query": "search term", "image": file_id, "count": 10}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["query"] == "search term"
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
assert result["count"] == 10
|
||||
|
||||
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
original = {"image": file_id}
|
||||
output = dict(original)
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Original should still contain the string ID
|
||||
assert original["image"] == file_id
|
||||
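The detection and conversion tests above describe the runtime half of the `dify-file-ref` feature: every schema property marked with that format is treated as a file UUID, resolved through `build_from_mapping`, and wrapped in `FileSegment` (or `ArrayFileSegment` for arrays), while invalid or unknown IDs collapse to `None` and the caller's output dict is left untouched. A stripped-down sketch of that behaviour for flat and array properties, with a stand-in loader and without the segment wrappers:

# Simplified sketch of the conversion behaviour the tests above assert; the real logic
# lives in core.llm_generator.output_parser.file_ref and additionally wraps results in
# FileSegment / ArrayFileSegment.
import uuid

FILE_REF_FORMAT = "dify-file-ref"  # assumed to match the constant used by the tests

def convert_sketch(output: dict, schema: dict, load_file) -> dict:
    """Return a copy of `output` with file-ref fields resolved via `load_file(file_id)`."""
    result = dict(output)  # never mutate the caller's dict
    for name, prop in schema.get("properties", {}).items():
        if prop.get("type") == "string" and prop.get("format") == FILE_REF_FORMAT:
            result[name] = _resolve(result.get(name), load_file)
        elif prop.get("type") == "array" and prop.get("items", {}).get("format") == FILE_REF_FORMAT:
            result[name] = [_resolve(v, load_file) for v in (result.get(name) or [])]
    return result

def _resolve(value, load_file):
    try:
        uuid.UUID(str(value))        # invalid UUIDs resolve to None
        return load_file(str(value))  # lookup failures (ValueError) also resolve to None
    except (ValueError, TypeError):
        return None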
@@ -25,6 +25,12 @@ class _StubErrorHandler:
     """Minimal error handler stub for tests."""
 
 
+class _StubNodeData:
+    """Simple node data stub with is_extractor_node property."""
+
+    is_extractor_node = False
+
+
 class _StubNode:
     """Simple node stub exposing the attributes needed by the state manager."""
 
@@ -36,6 +42,7 @@ class _StubNode:
         self.error_strategy = None
         self.retry_config = RetryConfig()
         self.retry = False
+        self.node_data = _StubNodeData()
 
 
 def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
api/tests/unit_tests/core/workflow/nodes/llm/test_llm_utils.py (new file, 174 lines)
@@ -0,0 +1,174 @@
"""Tests for llm_utils module, specifically multimodal content handling."""

import string
from unittest.mock import patch

from core.model_runtime.entities.message_entities import (
    ImagePromptMessageContent,
    TextPromptMessageContent,
    UserPromptMessage,
)
from core.workflow.nodes.llm.llm_utils import (
    _truncate_multimodal_content,
    build_context,
    restore_multimodal_content_in_messages,
)


class TestTruncateMultimodalContent:
    """Tests for _truncate_multimodal_content function."""

    def test_returns_message_unchanged_for_string_content(self):
        """String content should pass through unchanged."""
        message = UserPromptMessage(content="Hello, world!")
        result = _truncate_multimodal_content(message)
        assert result.content == "Hello, world!"

    def test_returns_message_unchanged_for_none_content(self):
        """None content should pass through unchanged."""
        message = UserPromptMessage(content=None)
        result = _truncate_multimodal_content(message)
        assert result.content is None

    def test_clears_base64_when_file_ref_present(self):
        """When file_ref is present, base64_data and url should be cleared."""
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data=string.ascii_lowercase,
            url="https://example.com/image.png",
            mime_type="image/png",
            filename="test.png",
            file_ref="local:test-file-id",
        )
        message = UserPromptMessage(content=[image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        assert len(result.content) == 1
        result_content = result.content[0]
        assert isinstance(result_content, ImagePromptMessageContent)
        assert result_content.base64_data == ""
        assert result_content.url == ""
        # file_ref should be preserved
        assert result_content.file_ref == "local:test-file-id"

    def test_truncates_base64_when_no_file_ref(self):
        """When file_ref is missing (legacy), base64_data should be truncated."""
        long_base64 = "a" * 100
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data=long_base64,
            mime_type="image/png",
            filename="test.png",
            file_ref=None,
        )
        message = UserPromptMessage(content=[image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        result_content = result.content[0]
        assert isinstance(result_content, ImagePromptMessageContent)
        # Should be truncated with marker
        assert "...[TRUNCATED]..." in result_content.base64_data
        assert len(result_content.base64_data) < len(long_base64)

    def test_preserves_text_content(self):
        """Text content should pass through unchanged."""
        text_content = TextPromptMessageContent(data="Hello!")
        image_content = ImagePromptMessageContent(
            format="png",
            base64_data="test123",
            mime_type="image/png",
            file_ref="local:file-id",
        )
        message = UserPromptMessage(content=[text_content, image_content])

        result = _truncate_multimodal_content(message)

        assert isinstance(result.content, list)
        assert len(result.content) == 2
        # Text content unchanged
        assert result.content[0].data == "Hello!"
        # Image content base64 cleared
        assert result.content[1].base64_data == ""


class TestBuildContext:
    """Tests for build_context function."""

    def test_excludes_system_messages(self):
        """System messages should be excluded from context."""
        from core.model_runtime.entities.message_entities import SystemPromptMessage

        messages = [
            SystemPromptMessage(content="You are a helpful assistant."),
            UserPromptMessage(content="Hello!"),
        ]

        context = build_context(messages, "Hi there!")

        # Should have user message + assistant response, no system message
        assert len(context) == 2
        assert context[0].content == "Hello!"
        assert context[1].content == "Hi there!"

    def test_appends_assistant_response(self):
        """Assistant response should be appended to context."""
        messages = [UserPromptMessage(content="What is 2+2?")]

        context = build_context(messages, "The answer is 4.")

        assert len(context) == 2
        assert context[1].content == "The answer is 4."


class TestRestoreMultimodalContentInMessages:
    """Tests for restore_multimodal_content_in_messages function."""

    @patch("core.file.file_manager.restore_multimodal_content")
    def test_restores_multimodal_content(self, mock_restore):
        """Should restore multimodal content in messages."""
        # Setup mock
        restored_content = ImagePromptMessageContent(
            format="png",
            base64_data="restored-base64",
            mime_type="image/png",
            file_ref="local:abc123",
        )
        mock_restore.return_value = restored_content

        # Create message with truncated content
        truncated_content = ImagePromptMessageContent(
            format="png",
            base64_data="",
            mime_type="image/png",
            file_ref="local:abc123",
        )
        message = UserPromptMessage(content=[truncated_content])

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content[0].base64_data == "restored-base64"
        mock_restore.assert_called_once()

    def test_passes_through_string_content(self):
        """String content should pass through unchanged."""
        message = UserPromptMessage(content="Hello!")

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content == "Hello!"

    def test_passes_through_text_content(self):
        """TextPromptMessageContent should pass through unchanged."""
        text_content = TextPromptMessageContent(data="Hello!")
        message = UserPromptMessage(content=[text_content])

        result = restore_multimodal_content_in_messages([message])

        assert len(result) == 1
        assert result[0].content[0].data == "Hello!"
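Together with the file_manager tests earlier in this compare, these tests imply a two-phase strategy for multimodal node memory: strip image payloads down to their `file_ref` before history is persisted, then rebuild the URL or base64 payload just before the next model call. A rough sketch of the storage-side half (dict stand-ins; the exact shape of the legacy truncation is an assumption beyond what the tests assert):

# Sketch of the save-side behaviour implied by TestTruncateMultimodalContent above;
# the concrete function is _truncate_multimodal_content in core.workflow.nodes.llm.llm_utils.
def truncate_for_storage_sketch(parts: list[dict]) -> list[dict]:
    """Drop heavy payloads before persisting conversation history."""
    slimmed = []
    for part in parts:
        part = dict(part)
        if part.get("kind") == "image":
            if part.get("file_ref"):
                # the payload can be rebuilt from file_ref later, so drop it entirely
                part["base64_data"] = ""
                part["url"] = ""
            elif part.get("base64_data"):
                # legacy content without file_ref: keep only a truncated marker
                # (head/tail slicing here is an assumption; the tests only require the marker)
                data = part["base64_data"]
                part["base64_data"] = data[:16] + "...[TRUNCATED]..." + data[-16:]
        slimmed.append(part)
    return slimmed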
@@ -14,7 +14,6 @@ import { BlockEnum } from '@/app/components/workflow/types'
 import { useAppContext } from '@/context/app-context'
 import { useDocLink } from '@/context/i18n'
 import {
   useAppTriggers,
   useInvalidateAppTriggers,
   useUpdateTriggerStatus,
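The frontend hunks further down widen the prompt-variable syntax so that `@`-delimited context mentions such as `{{@llm.context@}}` are recognised alongside the existing `#`-delimited references (the `getInputVars` regex changes from `\{\{#([^#]*)#\}\}` to `\{\{[@#]([^@#]*)[@#]\}\}`). A small hedged sketch of the widened pattern, written in Python as an illustration rather than the editor's actual TypeScript parser:

# Illustration only: match both '#' variable references and '@' context mentions.
import re

VAR_PATTERN = re.compile(r"\{\{[@#]([^@#]*)[@#]\}\}")

def extract_selectors(text: str) -> list[list[str]]:
    """Return value selectors such as ['llm', 'context'] from a prompt template."""
    return [match.split(".") for match in VAR_PATTERN.findall(text) if "." in match]

print(extract_selectors("search {{#1767773675796.query#}} using {{@llm.context@}}"))
# [['1767773675796', 'query'], ['llm', 'context']]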
@@ -0,0 +1,6 @@
|
||||
<svg xmlns="http://www.w3.org/2000/svg" width="12" height="12" viewBox="0 0 12 12" fill="none">
|
||||
<path d="M2.91992 1.6875C3.23055 1.68754 3.48242 1.93937 3.48242 2.25C3.48242 2.56063 3.23055 2.81246 2.91992 2.8125C2.63855 2.8125 2.41064 3.04041 2.41064 3.32178V5.46436C2.41061 5.61344 2.35148 5.75637 2.24609 5.86182L2.10791 6L2.24609 6.13818C2.35148 6.24363 2.41061 6.38656 2.41064 6.53564V8.67822C2.41064 8.95959 2.63855 9.1875 2.91992 9.1875C3.23055 9.18754 3.48242 9.43937 3.48242 9.75C3.48242 10.0606 3.23055 10.3125 2.91992 10.3125C2.01723 10.3125 1.28564 9.58091 1.28564 8.67822V6.76855L0.914551 6.39795C0.809062 6.29246 0.75 6.14918 0.75 6C0.75 5.85082 0.809062 5.70754 0.914551 5.60205L1.28564 5.23145V3.32178C1.28564 2.41909 2.01723 1.6875 2.91992 1.6875Z" fill="currentColor"/>
|
||||
<path d="M9.08008 1.6875C9.98276 1.68751 10.7144 2.41909 10.7144 3.32178V5.23145L11.085 5.60205C11.1904 5.70754 11.25 5.85082 11.25 6C11.25 6.14918 11.1904 6.29246 11.085 6.39795L10.7144 6.76855V8.67822C10.7144 9.58107 9.98213 10.3125 9.08008 10.3125C8.76942 10.3125 8.51758 10.0607 8.51758 9.75C8.51758 9.43934 8.76942 9.1875 9.08008 9.1875C9.36113 9.18749 9.58936 8.95943 9.58936 8.67822V6.53564C9.58939 6.38654 9.64849 6.24363 9.75391 6.13818L9.89209 6L9.75391 5.86182C9.64849 5.75637 9.58939 5.61346 9.58936 5.46436V3.32178C9.58936 3.04041 9.36144 2.81251 9.08008 2.8125C8.76942 2.8125 8.51758 2.56066 8.51758 2.25C8.51758 1.93934 8.76942 1.6875 9.08008 1.6875Z" fill="currentColor"/>
|
||||
<path d="M5.24707 5.07715C5.36302 5.07715 5.46712 5.14866 5.50879 5.25684L5.8335 6.10059C5.88932 6.24563 6.00388 6.36018 6.14893 6.41602L6.99268 6.74072C7.10086 6.78238 7.17236 6.88648 7.17236 7.00244C7.17229 7.11832 7.10078 7.22202 6.99268 7.26367L6.14893 7.58838C6.00378 7.64424 5.88929 7.75912 5.8335 7.9043L5.50879 8.74756C5.46715 8.8558 5.36307 8.92725 5.24707 8.92725C5.13116 8.92717 5.02746 8.85572 4.98584 8.74756L4.66113 7.9043C4.60526 7.75904 4.49046 7.6442 4.34521 7.58838L3.50195 7.26367C3.39378 7.22205 3.32234 7.11835 3.32227 7.00244C3.32227 6.88645 3.39371 6.78236 3.50195 6.74072L4.34521 6.41602C4.49039 6.36022 4.60523 6.24573 4.66113 6.10059L4.98584 5.25684C5.02749 5.14874 5.13121 5.07723 5.24707 5.07715Z" fill="currentColor"/>
|
||||
<path d="M6.89746 2.87744C6.98013 2.87754 7.05427 2.92822 7.08398 3.00537L7.29053 3.54297C7.34635 3.68816 7.46125 3.80302 7.60645 3.85889L8.14404 4.06543C8.22123 4.0952 8.27246 4.16966 8.27246 4.25244C8.27236 4.33513 8.22116 4.40922 8.14404 4.43896L7.60645 4.64551C7.46125 4.70138 7.34635 4.81624 7.29053 4.96143L7.08398 5.49902C7.05428 5.57614 6.98014 5.62734 6.89746 5.62744C6.81468 5.62744 6.74019 5.57622 6.71045 5.49902L6.50391 4.96143C6.44808 4.81624 6.33318 4.70138 6.18799 4.64551L5.65039 4.43896C5.57328 4.40922 5.52256 4.33513 5.52246 4.25244C5.52246 4.16966 5.5732 4.0952 5.65039 4.06543L6.18799 3.85889C6.33318 3.80302 6.44808 3.68816 6.50391 3.54297L6.71045 3.00537C6.74019 2.92814 6.81469 2.87744 6.89746 2.87744Z" fill="currentColor"/>
|
||||
</svg>
@@ -0,0 +1,53 @@
|
||||
{
|
||||
"icon": {
|
||||
"type": "element",
|
||||
"isRootNode": true,
|
||||
"name": "svg",
|
||||
"attributes": {
|
||||
"width": "12",
|
||||
"height": "12",
|
||||
"viewBox": "0 0 12 12",
|
||||
"fill": "none",
|
||||
"xmlns": "http://www.w3.org/2000/svg"
|
||||
},
|
||||
"children": [
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M2.91992 1.6875C3.23055 1.68754 3.48242 1.93937 3.48242 2.25C3.48242 2.56063 3.23055 2.81246 2.91992 2.8125C2.63855 2.8125 2.41064 3.04041 2.41064 3.32178V5.46436C2.41061 5.61344 2.35148 5.75637 2.24609 5.86182L2.10791 6L2.24609 6.13818C2.35148 6.24363 2.41061 6.38656 2.41064 6.53564V8.67822C2.41064 8.95959 2.63855 9.1875 2.91992 9.1875C3.23055 9.18754 3.48242 9.43937 3.48242 9.75C3.48242 10.0606 3.23055 10.3125 2.91992 10.3125C2.01723 10.3125 1.28564 9.58091 1.28564 8.67822V6.76855L0.914551 6.39795C0.809062 6.29246 0.75 6.14918 0.75 6C0.75 5.85082 0.809062 5.70754 0.914551 5.60205L1.28564 5.23145V3.32178C1.28564 2.41909 2.01723 1.6875 2.91992 1.6875Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M9.08008 1.6875C9.98276 1.68751 10.7144 2.41909 10.7144 3.32178V5.23145L11.085 5.60205C11.1904 5.70754 11.25 5.85082 11.25 6C11.25 6.14918 11.1904 6.29246 11.085 6.39795L10.7144 6.76855V8.67822C10.7144 9.58107 9.98213 10.3125 9.08008 10.3125C8.76942 10.3125 8.51758 10.0607 8.51758 9.75C8.51758 9.43934 8.76942 9.1875 9.08008 9.1875C9.36113 9.18749 9.58936 8.95943 9.58936 8.67822V6.53564C9.58939 6.38654 9.64849 6.24363 9.75391 6.13818L9.89209 6L9.75391 5.86182C9.64849 5.75637 9.58939 5.61346 9.58936 5.46436V3.32178C9.58936 3.04041 9.36144 2.81251 9.08008 2.8125C8.76942 2.8125 8.51758 2.56066 8.51758 2.25C8.51758 1.93934 8.76942 1.6875 9.08008 1.6875Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M5.24707 5.07715C5.36302 5.07715 5.46712 5.14866 5.50879 5.25684L5.8335 6.10059C5.88932 6.24563 6.00388 6.36018 6.14893 6.41602L6.99268 6.74072C7.10086 6.78238 7.17236 6.88648 7.17236 7.00244C7.17229 7.11832 7.10078 7.22202 6.99268 7.26367L6.14893 7.58838C6.00378 7.64424 5.88929 7.75912 5.8335 7.9043L5.50879 8.74756C5.46715 8.8558 5.36307 8.92725 5.24707 8.92725C5.13116 8.92717 5.02746 8.85572 4.98584 8.74756L4.66113 7.9043C4.60526 7.75904 4.49046 7.6442 4.34521 7.58838L3.50195 7.26367C3.39378 7.22205 3.32234 7.11835 3.32227 7.00244C3.32227 6.88645 3.39371 6.78236 3.50195 6.74072L4.34521 6.41602C4.49039 6.36022 4.60523 6.24573 4.66113 6.10059L4.98584 5.25684C5.02749 5.14874 5.13121 5.07723 5.24707 5.07715Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
},
|
||||
{
|
||||
"type": "element",
|
||||
"name": "path",
|
||||
"attributes": {
|
||||
"d": "M6.89746 2.87744C6.98013 2.87754 7.05427 2.92822 7.08398 3.00537L7.29053 3.54297C7.34635 3.68816 7.46125 3.80302 7.60645 3.85889L8.14404 4.06543C8.22123 4.0952 8.27246 4.16966 8.27246 4.25244C8.27236 4.33513 8.22116 4.40922 8.14404 4.43896L7.60645 4.64551C7.46125 4.70138 7.34635 4.81624 7.29053 4.96143L7.08398 5.49902C7.05428 5.57614 6.98014 5.62734 6.89746 5.62744C6.81468 5.62744 6.74019 5.57622 6.71045 5.49902L6.50391 4.96143C6.44808 4.81624 6.33318 4.70138 6.18799 4.64551L5.65039 4.43896C5.57328 4.40922 5.52256 4.33513 5.52246 4.25244C5.52246 4.16966 5.5732 4.0952 5.65039 4.06543L6.18799 3.85889C6.33318 3.80302 6.44808 3.68816 6.50391 3.54297L6.71045 3.00537C6.74019 2.92814 6.81469 2.87744 6.89746 2.87744Z",
|
||||
"fill": "currentColor"
|
||||
},
|
||||
"children": []
|
||||
}
|
||||
]
|
||||
},
|
||||
"name": "AssembleVariables"
|
||||
}
|
||||
@@ -0,0 +1,20 @@
|
||||
// GENERATE BY script
|
||||
// DON NOT EDIT IT MANUALLY
|
||||
|
||||
import type { IconData } from '@/app/components/base/icons/IconBase'
|
||||
import * as React from 'react'
|
||||
import IconBase from '@/app/components/base/icons/IconBase'
|
||||
import data from './AssembleVariables.json'
|
||||
|
||||
const Icon = (
|
||||
{
|
||||
ref,
|
||||
...props
|
||||
}: React.SVGProps<SVGSVGElement> & {
|
||||
ref?: React.RefObject<React.RefObject<HTMLOrSVGElement>>
|
||||
},
|
||||
) => <IconBase {...props} ref={ref} data={data as IconData} />
|
||||
|
||||
Icon.displayName = 'AssembleVariables'
|
||||
|
||||
export default Icon
|
||||
@@ -1,4 +1,5 @@
|
||||
export { default as AtSign } from './AtSign'
|
||||
export { default as AssembleVariables } from './AssembleVariables'
|
||||
export { default as Bookmark } from './Bookmark'
|
||||
export { default as Check } from './Check'
|
||||
export { default as CheckDone01 } from './CheckDone01'
|
||||
|
||||
@@ -38,13 +38,16 @@ export const getInputVars = (text: string): ValueSelector[] => {
|
||||
if (!text || typeof text !== 'string')
|
||||
return []
|
||||
|
||||
const allVars = text.match(/\{\{#([^#]*)#\}\}/g)
|
||||
const allVars = text.match(/\{\{[@#]([^@#]*)[@#]\}\}/g)
|
||||
if (allVars && allVars?.length > 0) {
|
||||
// {{#context#}}, {{#query#}} is not input vars
|
||||
const inputVars = allVars
|
||||
.filter(item => item.includes('.'))
|
||||
.map((item) => {
|
||||
const valueSelector = item.replace('{{#', '').replace('#}}', '').split('.')
|
||||
const valueSelector = item
|
||||
.replace(/^\{\{[@#]/, '')
|
||||
.replace(/[@#]\}\}$/, '')
|
||||
.split('.')
|
||||
if (valueSelector[1] === 'sys' && /^\d+$/.test(valueSelector[0]))
|
||||
return valueSelector.slice(1)
|
||||
|
||||
|
||||
@@ -5,6 +5,7 @@ import type {
|
||||
} from 'lexical'
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
AgentBlockType,
|
||||
ContextBlockType,
|
||||
CurrentBlockType,
|
||||
ErrorMessageBlockType,
|
||||
@@ -103,6 +104,7 @@ export type PromptEditorProps = {
|
||||
currentBlock?: CurrentBlockType
|
||||
errorMessageBlock?: ErrorMessageBlockType
|
||||
lastRunBlock?: LastRunBlockType
|
||||
agentBlock?: AgentBlockType
|
||||
isSupportFileVar?: boolean
|
||||
}
|
||||
|
||||
@@ -128,6 +130,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
currentBlock,
|
||||
errorMessageBlock,
|
||||
lastRunBlock,
|
||||
agentBlock,
|
||||
isSupportFileVar,
|
||||
}) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
@@ -139,6 +142,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
{
|
||||
replace: TextNode,
|
||||
with: (node: TextNode) => new CustomTextNode(node.__text),
|
||||
withKlass: CustomTextNode,
|
||||
},
|
||||
ContextBlockNode,
|
||||
HistoryBlockNode,
|
||||
@@ -212,6 +216,22 @@ const PromptEditor: FC<PromptEditorProps> = ({
|
||||
lastRunBlock={lastRunBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
{(!agentBlock || agentBlock.show) && (
|
||||
<ComponentPickerBlock
|
||||
triggerString="@"
|
||||
contextBlock={contextBlock}
|
||||
historyBlock={historyBlock}
|
||||
queryBlock={queryBlock}
|
||||
variableBlock={variableBlock}
|
||||
externalToolBlock={externalToolBlock}
|
||||
workflowVariableBlock={workflowVariableBlock}
|
||||
currentBlock={currentBlock}
|
||||
errorMessageBlock={errorMessageBlock}
|
||||
lastRunBlock={lastRunBlock}
|
||||
agentBlock={agentBlock}
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
/>
|
||||
)}
|
||||
<ComponentPickerBlock
|
||||
triggerString="{"
|
||||
contextBlock={contextBlock}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { MenuRenderFn } from '@lexical/react/LexicalTypeaheadMenuPlugin'
|
||||
import type { TextNode } from 'lexical'
|
||||
import type {
|
||||
AgentBlockType,
|
||||
ContextBlockType,
|
||||
CurrentBlockType,
|
||||
ErrorMessageBlockType,
|
||||
@@ -20,7 +21,11 @@ import {
|
||||
} from '@floating-ui/react'
|
||||
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
|
||||
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
|
||||
import { KEY_ESCAPE_COMMAND } from 'lexical'
|
||||
import {
|
||||
$getRoot,
|
||||
$insertNodes,
|
||||
KEY_ESCAPE_COMMAND,
|
||||
} from 'lexical'
|
||||
import {
|
||||
Fragment,
|
||||
memo,
|
||||
@@ -29,7 +34,9 @@ import {
|
||||
} from 'react'
|
||||
import ReactDOM from 'react-dom'
|
||||
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
|
||||
import AgentNodeList from '@/app/components/workflow/nodes/_base/components/agent-node-list'
|
||||
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { useEventEmitterContextContext } from '@/context/event-emitter'
|
||||
import { useBasicTypeaheadTriggerMatch } from '../../hooks'
|
||||
import { $splitNodeContainingQuery } from '../../utils'
|
||||
@@ -38,6 +45,7 @@ import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../error-message-block'
|
||||
import { INSERT_LAST_RUN_BLOCK_COMMAND } from '../last-run-block'
|
||||
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '../variable-block'
|
||||
import { INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND } from '../workflow-variable-block'
|
||||
import { $createWorkflowVariableBlockNode } from '../workflow-variable-block/node'
|
||||
import { useOptions } from './hooks'
|
||||
|
||||
type ComponentPickerProps = {
|
||||
@@ -51,6 +59,7 @@ type ComponentPickerProps = {
|
||||
currentBlock?: CurrentBlockType
|
||||
errorMessageBlock?: ErrorMessageBlockType
|
||||
lastRunBlock?: LastRunBlockType
|
||||
agentBlock?: AgentBlockType
|
||||
isSupportFileVar?: boolean
|
||||
}
|
||||
const ComponentPicker = ({
|
||||
@@ -64,6 +73,7 @@ const ComponentPicker = ({
|
||||
currentBlock,
|
||||
errorMessageBlock,
|
||||
lastRunBlock,
|
||||
agentBlock,
|
||||
isSupportFileVar,
|
||||
}: ComponentPickerProps) => {
|
||||
const { eventEmitter } = useEventEmitterContextContext()
|
||||
@@ -151,12 +161,55 @@ const ComponentPicker = ({
|
||||
editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
|
||||
}, [editor])
|
||||
|
||||
const handleSelectAssembleVariables = useCallback(() => {
|
||||
editor.update(() => {
|
||||
const match = checkForTriggerMatch(triggerString, editor)
|
||||
if (!match)
|
||||
return
|
||||
const needRemove = $splitNodeContainingQuery(match)
|
||||
if (needRemove)
|
||||
needRemove.remove()
|
||||
})
|
||||
workflowVariableBlock?.onAssembleVariables?.()
|
||||
handleClose()
|
||||
}, [editor, checkForTriggerMatch, triggerString, workflowVariableBlock, handleClose])
|
||||
|
||||
const handleSelectAgent = useCallback((agent: { id: string, title: string }) => {
|
||||
editor.update(() => {
|
||||
const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!)
|
||||
if (needRemove)
|
||||
needRemove.remove()
|
||||
|
||||
const root = $getRoot()
|
||||
const firstChild = root.getFirstChild()
|
||||
if (firstChild) {
|
||||
const selection = firstChild.selectStart()
|
||||
if (selection) {
|
||||
const workflowVariableBlockNode = $createWorkflowVariableBlockNode([agent.id, 'text'], {}, undefined)
|
||||
$insertNodes([workflowVariableBlockNode])
|
||||
}
|
||||
}
|
||||
})
|
||||
agentBlock?.onSelect?.(agent)
|
||||
handleClose()
|
||||
}, [editor, checkForTriggerMatch, triggerString, agentBlock, handleClose])
|
||||
|
||||
const isAgentTrigger = triggerString === '@' && agentBlock?.show
|
||||
const showAssembleVariables = triggerString === '/' && workflowVariableBlock?.showAssembleVariables
|
||||
const agentNodes = agentBlock?.agentNodes || []
|
||||
|
||||
const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
|
||||
anchorElementRef,
|
||||
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
|
||||
) => {
|
||||
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
|
||||
return null
|
||||
if (isAgentTrigger) {
|
||||
if (!(anchorElementRef.current && agentNodes.length))
|
||||
return null
|
||||
}
|
||||
else {
|
||||
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
|
||||
return null
|
||||
}
|
||||
|
||||
setTimeout(() => {
|
||||
if (anchorElementRef.current)
|
||||
@@ -167,9 +220,6 @@ const ComponentPicker = ({
|
||||
<>
|
||||
{
|
||||
ReactDOM.createPortal(
|
||||
// The `LexicalMenu` will try to calculate the position of the floating menu based on the first child.
|
||||
// Since we use floating ui, we need to wrap it with a div to prevent the position calculation being affected.
|
||||
// See https://github.com/facebook/lexical/blob/ac97dfa9e14a73ea2d6934ff566282d7f758e8bb/packages/lexical-react/src/shared/LexicalMenu.ts#L493
|
||||
<div className="h-0 w-0">
|
||||
<div
|
||||
className="w-[260px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
|
||||
@@ -179,56 +229,75 @@ const ComponentPicker = ({
|
||||
}}
|
||||
ref={refs.setFloating}
|
||||
>
|
||||
{
|
||||
workflowVariableBlock?.show && (
|
||||
<div className="p-1">
|
||||
<VarReferenceVars
|
||||
searchBoxClassName="mt-1"
|
||||
vars={workflowVariableOptions}
|
||||
onChange={(variables: string[]) => {
|
||||
handleSelectWorkflowVariable(variables)
|
||||
}}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
{isAgentTrigger
|
||||
? (
|
||||
<AgentNodeList
|
||||
nodes={agentNodes.map(node => ({
|
||||
...node,
|
||||
type: BlockEnum.Agent || BlockEnum.LLM,
|
||||
}))}
|
||||
onSelect={handleSelectAgent}
|
||||
onClose={handleClose}
|
||||
onBlur={handleClose}
|
||||
showManageInputField={workflowVariableBlock.showManageInputField}
|
||||
onManageInputField={workflowVariableBlock.onManageInputField}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
autoFocus={false}
|
||||
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
workflowVariableBlock?.show && !!options.length && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
<div>
|
||||
{
|
||||
options.map((option, index) => (
|
||||
<Fragment key={option.key}>
|
||||
)
|
||||
: (
|
||||
<>
|
||||
{
|
||||
// Divider
|
||||
index !== 0 && options.at(index - 1)?.group !== option.group && (
|
||||
workflowVariableBlock?.show && (
|
||||
<div className="p-1">
|
||||
<VarReferenceVars
|
||||
searchBoxClassName="mt-1"
|
||||
vars={workflowVariableOptions}
|
||||
onChange={(variables: string[]) => {
|
||||
handleSelectWorkflowVariable(variables)
|
||||
}}
|
||||
maxHeightClass="max-h-[34vh]"
|
||||
isSupportFileVar={isSupportFileVar}
|
||||
onClose={handleClose}
|
||||
onBlur={handleClose}
|
||||
showManageInputField={workflowVariableBlock.showManageInputField}
|
||||
onManageInputField={workflowVariableBlock.onManageInputField}
|
||||
showAssembleVariables={showAssembleVariables}
|
||||
onAssembleVariables={showAssembleVariables ? handleSelectAssembleVariables : undefined}
|
||||
autoFocus={false}
|
||||
isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
|
||||
/>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
{
|
||||
workflowVariableBlock?.show && !!options.length && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
{option.renderMenuOption({
|
||||
queryString,
|
||||
isSelected: selectedIndex === index,
|
||||
onSelect: () => {
|
||||
selectOptionAndCleanUp(option)
|
||||
},
|
||||
onSetHighlight: () => {
|
||||
setHighlightedIndex(index)
|
||||
},
|
||||
})}
|
||||
</Fragment>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
<div>
|
||||
{
|
||||
options.map((option, index) => (
|
||||
<Fragment key={option.key}>
|
||||
{
|
||||
index !== 0 && options.at(index - 1)?.group !== option.group && (
|
||||
<div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
|
||||
)
|
||||
}
|
||||
{option.renderMenuOption({
|
||||
queryString,
|
||||
isSelected: selectedIndex === index,
|
||||
onSelect: () => {
|
||||
selectOptionAndCleanUp(option)
|
||||
},
|
||||
onSetHighlight: () => {
|
||||
setHighlightedIndex(index)
|
||||
},
|
||||
})}
|
||||
</Fragment>
|
||||
))
|
||||
}
|
||||
</div>
|
||||
</>
|
||||
)}
|
||||
</div>
|
||||
</div>,
|
||||
anchorElementRef.current,
|
||||
@@ -236,7 +305,7 @@ const ComponentPicker = ({
|
||||
}
|
||||
</>
|
||||
)
|
||||
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
|
||||
}, [isAgentTrigger, agentNodes, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, handleSelectAgent, handleClose, workflowVariableOptions, isSupportFileVar, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField, showAssembleVariables, handleSelectAssembleVariables])
|
||||
|
||||
return (
|
||||
<LexicalTypeaheadMenuPlugin
|
||||
|
||||
@@ -21,6 +21,7 @@ import {
|
||||
VariableLabelInEditor,
|
||||
} from '@/app/components/workflow/nodes/_base/components/variable/variable-label'
|
||||
import { Type } from '@/app/components/workflow/nodes/llm/types'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { isExceptionVariable } from '@/app/components/workflow/utils'
|
||||
import { useSelectOrDelete } from '../../hooks'
|
||||
import {
|
||||
@@ -66,6 +67,8 @@ const WorkflowVariableBlockComponent = ({
|
||||
)()
|
||||
const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState<WorkflowNodesMap>(workflowNodesMap)
|
||||
const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
|
||||
const isContextVariable = (node?.type === BlockEnum.Agent || node?.type === BlockEnum.LLM)
|
||||
&& variables[variablesLength - 1] === 'context'
|
||||
|
||||
const isException = isExceptionVariable(varName, node?.type)
|
||||
const variableValid = useMemo(() => {
|
||||
@@ -134,6 +137,9 @@ const WorkflowVariableBlockComponent = ({
|
||||
})
|
||||
}, [node, reactflow, store])
|
||||
|
||||
if (isContextVariable)
|
||||
return <span className="hidden" ref={ref} />
|
||||
|
||||
const Item = (
|
||||
<VariableLabelInEditor
|
||||
nodeType={node?.type}
|
||||
|
||||
@@ -2,6 +2,7 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical'
|
||||
import type { GetVarType, WorkflowVariableBlockType } from '../../types'
|
||||
import type { Var } from '@/app/components/workflow/types'
|
||||
import { DecoratorNode } from 'lexical'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import WorkflowVariableBlockComponent from './component'
|
||||
|
||||
export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap']
|
||||
@@ -120,7 +121,12 @@ export class WorkflowVariableBlockNode extends DecoratorNode<React.JSX.Element>
|
||||
}
|
||||
|
||||
getTextContent(): string {
|
||||
return `{{#${this.getVariables().join('.')}#}}`
|
||||
const variables = this.getVariables()
|
||||
const node = this.getWorkflowNodesMap()?.[variables[0]]
|
||||
const isContextVariable = (node?.type === BlockEnum.Agent || node?.type === BlockEnum.LLM)
|
||||
&& variables[variables.length - 1] === 'context'
|
||||
const marker = isContextVariable ? '@' : '#'
|
||||
return `{{${marker}${variables.join('.')}${marker}}}`
|
||||
}
|
||||
}
|
||||
export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType, environmentVariables?: Var[], conversationVariables?: Var[], ragVariables?: Var[]): WorkflowVariableBlockNode {
|
||||
|
||||
@@ -71,6 +71,19 @@ export type WorkflowVariableBlockType = {
|
||||
getVarType?: GetVarType
|
||||
showManageInputField?: boolean
|
||||
onManageInputField?: () => void
|
||||
showAssembleVariables?: boolean
|
||||
onAssembleVariables?: () => void
|
||||
}
|
||||
|
||||
export type AgentNode = {
|
||||
id: string
|
||||
title: string
|
||||
}
|
||||
|
||||
export type AgentBlockType = {
|
||||
show?: boolean
|
||||
agentNodes?: AgentNode[]
|
||||
onSelect?: (agent: AgentNode) => void
|
||||
}
|
||||
|
||||
export type MenuTextMatch = {
|
||||
|
||||
web/app/components/sub-graph/components/config-panel.tsx (new file, 194 lines)
@@ -0,0 +1,194 @@
|
||||
'use client'
|
||||
import type { FC } from 'react'
|
||||
import type { Item } from '@/app/components/base/select'
|
||||
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
|
||||
import type { Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types'
|
||||
import { RiCheckLine } from '@remixicon/react'
|
||||
import { memo, useCallback, useMemo, useState } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { SimpleSelect } from '@/app/components/base/select'
|
||||
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
|
||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
|
||||
import Tab, { TabType } from '@/app/components/workflow/nodes/_base/components/workflow-panel/tab'
|
||||
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type ConfigPanelProps = {
|
||||
agentName: string
|
||||
extractorNodeId: string
|
||||
mentionConfig: MentionConfig
|
||||
availableNodes: Node[]
|
||||
availableVars: NodeOutPutVar[]
|
||||
onMentionConfigChange: (config: MentionConfig) => void
|
||||
}
|
||||
|
||||
const ConfigPanel: FC<ConfigPanelProps> = ({
|
||||
agentName,
|
||||
extractorNodeId,
|
||||
mentionConfig,
|
||||
availableNodes,
|
||||
availableVars,
|
||||
onMentionConfigChange,
|
||||
}) => {
|
||||
const { t } = useTranslation()
|
||||
const [tabType, setTabType] = useState<TabType>(TabType.settings)
|
||||
|
||||
const resolvedExtractorId = mentionConfig.extractor_node_id || extractorNodeId
|
||||
|
||||
const selectedOutput = useMemo<ValueSelector>(() => {
|
||||
if (!resolvedExtractorId || !mentionConfig.output_selector?.length)
|
||||
return []
|
||||
|
||||
return [resolvedExtractorId, ...(mentionConfig.output_selector || [])]
|
||||
}, [mentionConfig.output_selector, resolvedExtractorId])
|
||||
|
||||
const handleOutputVarChange = useCallback((value: ValueSelector | string) => {
|
||||
const selector = Array.isArray(value) ? value : []
|
||||
const nextExtractorId = selector[0] || resolvedExtractorId
|
||||
const nextOutputSelector = selector.length > 1 ? selector.slice(1) : []
|
||||
|
||||
onMentionConfigChange({
|
||||
...mentionConfig,
|
||||
extractor_node_id: nextExtractorId,
|
||||
output_selector: nextOutputSelector,
|
||||
})
|
||||
}, [mentionConfig, onMentionConfigChange, resolvedExtractorId])
|
||||
|
||||
const whenOutputNoneOptions = useMemo(() => ([
|
||||
{
|
||||
value: 'raise_error',
|
||||
name: t('subGraphModal.whenOutputNone.error', { ns: 'workflow' }),
|
||||
description: t('subGraphModal.whenOutputNone.errorDesc', { ns: 'workflow' }),
|
||||
},
|
||||
{
|
||||
value: 'use_default',
|
||||
name: t('subGraphModal.whenOutputNone.default', { ns: 'workflow' }),
|
||||
description: t('subGraphModal.whenOutputNone.defaultDesc', { ns: 'workflow' }),
|
||||
},
|
||||
]), [t])
|
||||
const selectedWhenOutputNoneOption = useMemo(() => (
|
||||
whenOutputNoneOptions.find(item => item.value === mentionConfig.null_strategy) ?? whenOutputNoneOptions[0]
|
||||
), [mentionConfig.null_strategy, whenOutputNoneOptions])
|
||||
|
||||
const handleNullStrategyChange = useCallback((item: Item) => {
|
||||
if (typeof item.value !== 'string')
|
||||
return
|
||||
onMentionConfigChange({
|
||||
...mentionConfig,
|
||||
null_strategy: item.value as MentionConfig['null_strategy'],
|
||||
})
|
||||
}, [mentionConfig, onMentionConfigChange])
|
||||
|
||||
const handleDefaultValueChange = useCallback((value: string) => {
|
||||
const trimmed = value.trim()
|
||||
let nextValue: unknown = value
|
||||
if ((trimmed.startsWith('{') && trimmed.endsWith('}')) || (trimmed.startsWith('[') && trimmed.endsWith(']'))) {
|
||||
try {
|
||||
nextValue = JSON.parse(trimmed)
|
||||
}
|
||||
catch {
|
||||
nextValue = value
|
||||
}
|
||||
}
|
||||
|
||||
onMentionConfigChange({
|
||||
...mentionConfig,
|
||||
default_value: nextValue,
|
||||
})
|
||||
}, [mentionConfig, onMentionConfigChange])
|
||||
const defaultValue = mentionConfig.default_value ?? ''
|
||||
const shouldFormatDefaultValue = typeof defaultValue !== 'string'
|
||||
|
||||
return (
|
||||
<div className="flex h-full flex-col">
|
||||
<div className="px-4 pb-2 pt-4">
|
||||
<div className="system-lg-semibold text-text-primary">
|
||||
{t('subGraphModal.internalStructure', { ns: 'workflow' })}
|
||||
</div>
|
||||
<div className="system-sm-regular text-text-tertiary">
|
||||
{t('subGraphModal.internalStructureDesc', { ns: 'workflow', name: agentName })}
|
||||
</div>
|
||||
</div>
|
||||
<div className="px-4 pb-2">
|
||||
<Tab value={tabType} onChange={setTabType} />
|
||||
</div>
|
||||
{tabType === TabType.lastRun && (
|
||||
<div className="flex flex-1 items-center justify-center p-4">
|
||||
<p className="system-sm-regular text-text-tertiary">
|
||||
{t('subGraphModal.noRunHistory', { ns: 'workflow' })}
|
||||
</p>
|
||||
</div>
|
||||
)}
|
||||
{tabType === TabType.settings && (
|
||||
<div className="flex-1 overflow-y-auto">
|
||||
<div className="space-y-4 px-4 py-4">
|
||||
<Field title={t('subGraphModal.outputVariables', { ns: 'workflow' })}>
|
||||
<VarReferencePicker
|
||||
nodeId={extractorNodeId}
|
||||
readonly={false}
|
||||
isShowNodeName
|
||||
value={selectedOutput}
|
||||
onChange={handleOutputVarChange}
|
||||
availableNodes={availableNodes}
|
||||
availableVars={availableVars}
|
||||
/>
|
||||
</Field>
|
||||
</div>
|
||||
<div className="space-y-4 px-4 py-4">
|
||||
<Field
|
||||
title={t('subGraphModal.whenOutputIsNone', { ns: 'workflow' })}
|
||||
operations={(
|
||||
<div className="flex items-center">
|
||||
<SimpleSelect
|
||||
items={whenOutputNoneOptions}
|
||||
defaultValue={mentionConfig.null_strategy}
|
||||
allowSearch={false}
|
||||
notClearable
|
||||
wrapperClassName="min-w-[160px]"
|
||||
onSelect={handleNullStrategyChange}
|
||||
renderOption={({ item, selected }) => (
|
||||
<div className="flex items-start gap-2">
|
||||
<div className="mt-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
|
||||
{selected && (
|
||||
<RiCheckLine className="h-4 w-4 text-[14px] text-text-accent" />
|
||||
)}
|
||||
</div>
|
||||
<div className="min-w-0">
|
||||
<div className="system-sm-medium text-text-secondary">{item.name}</div>
|
||||
<div className="system-xs-regular mt-0.5 text-text-tertiary">{item.description}</div>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
>
|
||||
<div className="space-y-2">
|
||||
{selectedWhenOutputNoneOption?.description && (
|
||||
<div className="system-xs-regular text-text-tertiary">
|
||||
{selectedWhenOutputNoneOption.description}
|
||||
</div>
|
||||
)}
|
||||
{mentionConfig.null_strategy === 'use_default' && (
|
||||
<div className={cn('overflow-hidden rounded-lg border border-components-input-border-active bg-components-input-bg-normal p-1')}>
|
||||
<CodeEditor
|
||||
noWrapper
|
||||
language={CodeLanguage.json}
|
||||
value={defaultValue}
|
||||
onChange={handleDefaultValueChange}
|
||||
isJSONStringifyBeauty={shouldFormatDefaultValue}
|
||||
className="min-h-[160px]"
|
||||
/>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Field>
|
||||
</div>
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(ConfigPanel)
|
||||
@@ -0,0 +1,87 @@
|
||||
import type { FC } from 'react'
|
||||
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
|
||||
import type { NodeOutPutVar } from '@/app/components/workflow/types'
|
||||
import { memo, useMemo } from 'react'
|
||||
import { useStore as useReactFlowStore } from 'reactflow'
|
||||
import { useShallow } from 'zustand/react/shallow'
|
||||
import { useIsChatMode, useWorkflowVariables } from '@/app/components/workflow/hooks'
|
||||
import Panel from '@/app/components/workflow/panel'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import ConfigPanel from './config-panel'
|
||||
|
||||
type SubGraphChildrenProps = {
|
||||
agentName: string
|
||||
extractorNodeId: string
|
||||
mentionConfig: MentionConfig
|
||||
onMentionConfigChange: (config: MentionConfig) => void
|
||||
}
|
||||
|
||||
const SubGraphChildren: FC<SubGraphChildrenProps> = ({
|
||||
agentName,
|
||||
extractorNodeId,
|
||||
mentionConfig,
|
||||
onMentionConfigChange,
|
||||
}) => {
|
||||
const { getNodeAvailableVars } = useWorkflowVariables()
|
||||
const isChatMode = useIsChatMode()
|
||||
const nodePanelWidth = useStore(s => s.nodePanelWidth)
|
||||
|
||||
const selectedNode = useReactFlowStore(useShallow((s) => {
|
||||
return s.getNodes().find(node => node.data.selected)
|
||||
}))
|
||||
|
||||
const extractorNode = useReactFlowStore(useShallow((s) => {
|
||||
return s.getNodes().find(node => node.data.type === BlockEnum.LLM)
|
||||
}))
|
||||
|
||||
const availableNodes = useMemo(() => {
|
||||
return extractorNode ? [extractorNode] : []
|
||||
}, [extractorNode])
|
||||
|
||||
const availableVars = useMemo<NodeOutPutVar[]>(() => {
|
||||
if (!extractorNode)
|
||||
return []
|
||||
|
||||
const vars = getNodeAvailableVars({
|
||||
beforeNodes: [extractorNode],
|
||||
isChatMode,
|
||||
filterVar: () => true,
|
||||
})
|
||||
return vars.filter(item => item.nodeId === extractorNode.id)
|
||||
}, [extractorNode, getNodeAvailableVars, isChatMode])
|
||||
|
||||
const panelRight = useMemo(() => {
|
||||
if (selectedNode)
|
||||
return null
|
||||
|
||||
return (
|
||||
<div className="relative mr-1 h-full">
|
||||
<div
|
||||
className="flex h-full flex-col rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg"
|
||||
style={{ width: `${nodePanelWidth}px` }}
|
||||
>
|
||||
<ConfigPanel
|
||||
agentName={agentName}
|
||||
extractorNodeId={extractorNodeId}
|
||||
mentionConfig={mentionConfig}
|
||||
availableNodes={availableNodes}
|
||||
availableVars={availableVars}
|
||||
onMentionConfigChange={onMentionConfigChange}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}, [agentName, availableNodes, availableVars, extractorNodeId, mentionConfig, nodePanelWidth, onMentionConfigChange, selectedNode])
|
||||
|
||||
return (
|
||||
<Panel
|
||||
withHeader={false}
|
||||
components={{
|
||||
right: panelRight,
|
||||
}}
|
||||
/>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SubGraphChildren)
|
||||
web/app/components/sub-graph/components/sub-graph-main.tsx (new file, 109 lines)
@@ -0,0 +1,109 @@
|
||||
import type { FC } from 'react'
|
||||
import type { Viewport } from 'reactflow'
|
||||
import type { SyncWorkflowDraft, SyncWorkflowDraftCallback } from '../types'
|
||||
import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store'
|
||||
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
|
||||
import type { Edge, Node } from '@/app/components/workflow/types'
|
||||
import { useCallback, useMemo } from 'react'
|
||||
import { useStoreApi } from 'reactflow'
|
||||
import { WorkflowWithInnerContext } from '@/app/components/workflow'
|
||||
import { useSetWorkflowVarsWithValue } from '@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars'
|
||||
import { useInspectVarsCrudCommon } from '@/app/components/workflow/hooks/use-inspect-vars-crud-common'
|
||||
import { FlowType } from '@/types/common'
|
||||
import { useAvailableNodesMetaData } from '../hooks'
|
||||
import SubGraphChildren from './sub-graph-children'
|
||||
|
||||
type SubGraphMainProps = {
|
||||
nodes: Node[]
|
||||
edges: Edge[]
|
||||
viewport: Viewport
|
||||
agentName: string
|
||||
extractorNodeId: string
|
||||
configsMap?: HooksStoreShape['configsMap']
|
||||
mentionConfig: MentionConfig
|
||||
onMentionConfigChange: (config: MentionConfig) => void
|
||||
onSave?: (nodes: Node[], edges: Edge[]) => void
|
||||
onSyncWorkflowDraft?: SyncWorkflowDraft
|
||||
}
|
||||
|
||||
const SubGraphMain: FC<SubGraphMainProps> = ({
|
||||
nodes,
|
||||
edges,
|
||||
viewport,
|
||||
agentName,
|
||||
extractorNodeId,
|
||||
configsMap,
|
||||
mentionConfig,
|
||||
onMentionConfigChange,
|
||||
onSave,
|
||||
onSyncWorkflowDraft,
|
||||
}) => {
|
||||
const reactFlowStore = useStoreApi()
|
||||
const availableNodesMetaData = useAvailableNodesMetaData()
|
||||
const flowType = configsMap?.flowType ?? FlowType.appFlow
|
||||
const flowId = configsMap?.flowId ?? ''
|
||||
const { fetchInspectVars } = useSetWorkflowVarsWithValue({
|
||||
flowType,
|
||||
flowId,
|
||||
})
|
||||
const inspectVarsCrud = useInspectVarsCrudCommon({
|
||||
flowType,
|
||||
flowId,
|
||||
})
|
||||
|
||||
const handleSyncSubGraphDraft = useCallback(async () => {
|
||||
const { getNodes, edges } = reactFlowStore.getState()
|
||||
await onSave?.(getNodes() as Node[], edges as Edge[])
|
||||
}, [onSave, reactFlowStore])
|
||||
|
||||
const handleSyncWorkflowDraft = useCallback(async (
|
||||
notRefreshWhenSyncError?: boolean,
|
||||
callback?: SyncWorkflowDraftCallback,
|
||||
) => {
|
||||
try {
|
||||
await handleSyncSubGraphDraft()
|
||||
if (onSyncWorkflowDraft) {
|
||||
await onSyncWorkflowDraft(notRefreshWhenSyncError, callback)
|
||||
return
|
||||
}
|
||||
callback?.onSuccess?.()
|
||||
}
|
||||
catch {
|
||||
callback?.onError?.()
|
||||
}
|
||||
finally {
|
||||
callback?.onSettled?.()
|
||||
}
|
||||
}, [handleSyncSubGraphDraft, onSyncWorkflowDraft])
|
||||
|
||||
const hooksStore = useMemo(() => ({
|
||||
interactionMode: 'subgraph',
|
||||
availableNodesMetaData,
|
||||
configsMap,
|
||||
fetchInspectVars,
|
||||
...inspectVarsCrud,
|
||||
doSyncWorkflowDraft: handleSyncWorkflowDraft,
|
||||
syncWorkflowDraftWhenPageClose: handleSyncSubGraphDraft,
|
||||
}), [availableNodesMetaData, configsMap, fetchInspectVars, handleSyncSubGraphDraft, handleSyncWorkflowDraft, inspectVarsCrud])
|
||||
|
||||
return (
|
||||
<WorkflowWithInnerContext
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
viewport={viewport}
|
||||
hooksStore={hooksStore as any}
|
||||
allowSelectionWhenReadOnly
|
||||
canvasReadOnly
|
||||
interactionMode="subgraph"
|
||||
>
|
||||
<SubGraphChildren
|
||||
agentName={agentName}
|
||||
extractorNodeId={extractorNodeId}
|
||||
mentionConfig={mentionConfig}
|
||||
onMentionConfigChange={onMentionConfigChange}
|
||||
/>
|
||||
</WorkflowWithInnerContext>
|
||||
)
|
||||
}
|
||||
|
||||
export default SubGraphMain
|
||||
web/app/components/sub-graph/hooks/index.ts (new file, 2 lines)
@@ -0,0 +1,2 @@
|
||||
export { useAvailableNodesMetaData } from './use-available-nodes-meta-data'
|
||||
export { useSubGraphNodes } from './use-sub-graph-nodes'
|
||||
@@ -0,0 +1,43 @@
|
||||
import type { AvailableNodesMetaData } from '@/app/components/workflow/hooks-store/store'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { WORKFLOW_COMMON_NODES } from '@/app/components/workflow/constants/node'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
|
||||
export const useAvailableNodesMetaData = () => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const availableNodesMetaData = useMemo(() => WORKFLOW_COMMON_NODES.map((node) => {
|
||||
const { metaData } = node
|
||||
const title = t(`blocks.${metaData.type}`, { ns: 'workflow' })
|
||||
const description = t(`blocksAbout.${metaData.type}`, { ns: 'workflow' })
|
||||
return {
|
||||
...node,
|
||||
metaData: {
|
||||
...metaData,
|
||||
title,
|
||||
description,
|
||||
},
|
||||
defaultValue: {
|
||||
...node.defaultValue,
|
||||
type: metaData.type,
|
||||
title,
|
||||
},
|
||||
}
|
||||
}), [t])
|
||||
|
||||
const availableNodesMetaDataMap = useMemo(() => availableNodesMetaData.reduce((acc, node) => {
|
||||
acc![node.metaData.type] = node
|
||||
return acc
|
||||
}, {} as AvailableNodesMetaData['nodesMap']), [availableNodesMetaData])
|
||||
|
||||
return useMemo(() => {
|
||||
return {
|
||||
nodes: availableNodesMetaData,
|
||||
nodesMap: {
|
||||
...availableNodesMetaDataMap,
|
||||
[BlockEnum.VariableAssigner]: availableNodesMetaDataMap?.[BlockEnum.VariableAggregator],
|
||||
},
|
||||
}
|
||||
}, [availableNodesMetaData, availableNodesMetaDataMap])
|
||||
}
|
||||
web/app/components/sub-graph/hooks/use-sub-graph-nodes.ts (new file, 20 lines)
@@ -0,0 +1,20 @@
|
||||
import type { Edge, Node } from '@/app/components/workflow/types'
|
||||
import { useMemo } from 'react'
|
||||
import { initialEdges, initialNodes } from '@/app/components/workflow/utils'
|
||||
|
||||
export const useSubGraphNodes = (nodes: Node[], edges: Edge[]) => {
|
||||
const processedNodes = useMemo(
|
||||
() => initialNodes(nodes, edges),
|
||||
[nodes, edges],
|
||||
)
|
||||
|
||||
const processedEdges = useMemo(
|
||||
() => initialEdges(edges, nodes),
|
||||
[edges, nodes],
|
||||
)
|
||||
|
||||
return {
|
||||
nodes: processedNodes,
|
||||
edges: processedEdges,
|
||||
}
|
||||
}
|
||||
web/app/components/sub-graph/index.tsx (new file, 212 lines)
@@ -0,0 +1,212 @@
|
||||
import type { FC } from 'react'
|
||||
import type { Viewport } from 'reactflow'
|
||||
import type { SubGraphProps } from './types'
|
||||
import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store'
|
||||
import type { PromptItem, PromptTemplateItem } from '@/app/components/workflow/types'
|
||||
import { memo, useEffect, useMemo } from 'react'
|
||||
import WorkflowWithDefaultContext from '@/app/components/workflow'
|
||||
import { NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION } from '@/app/components/workflow/constants'
|
||||
import { WorkflowContextProvider } from '@/app/components/workflow/context'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { BlockEnum, EditionType, isPromptMessageContext, PromptRole } from '@/app/components/workflow/types'
|
||||
import SubGraphMain from './components/sub-graph-main'
|
||||
import { useSubGraphNodes } from './hooks'
|
||||
import { createSubGraphSlice } from './store'
|
||||
|
||||
const SUB_GRAPH_EDGE_GAP = 160
|
||||
const SUB_GRAPH_ENTRY_POSITION = {
|
||||
x: START_INITIAL_POSITION.x,
|
||||
y: 150,
|
||||
}
|
||||
const SUB_GRAPH_LLM_POSITION = {
|
||||
x: SUB_GRAPH_ENTRY_POSITION.x + NODE_WIDTH_X_OFFSET - SUB_GRAPH_EDGE_GAP,
|
||||
y: SUB_GRAPH_ENTRY_POSITION.y,
|
||||
}
|
||||
|
||||
const defaultViewport: Viewport = {
|
||||
x: SUB_GRAPH_EDGE_GAP,
|
||||
y: 50,
|
||||
zoom: 1.3,
|
||||
}
|
||||
|
||||
const SubGraphContent: FC<SubGraphProps> = (props) => {
|
||||
const {
|
||||
toolNodeId,
|
||||
paramKey,
|
||||
agentName,
|
||||
agentNodeId,
|
||||
mentionConfig,
|
||||
onMentionConfigChange,
|
||||
extractorNode,
|
||||
toolParamValue,
|
||||
parentAvailableNodes,
|
||||
parentAvailableVars,
|
||||
configsMap,
|
||||
onSave,
|
||||
onSyncWorkflowDraft,
|
||||
} = props
|
||||
|
||||
const setParentAvailableVars = useStore(state => state.setParentAvailableVars)
|
||||
const setParentAvailableNodes = useStore(state => state.setParentAvailableNodes)
|
||||
|
||||
useEffect(() => {
|
||||
setParentAvailableVars?.(parentAvailableVars || [])
|
||||
setParentAvailableNodes?.(parentAvailableNodes || [])
|
||||
}, [parentAvailableNodes, parentAvailableVars, setParentAvailableNodes, setParentAvailableVars])
|
||||
|
||||
const promptText = useMemo(() => {
|
||||
if (!toolParamValue)
|
||||
return ''
|
||||
// Reason: escape agent id before building a regex pattern.
|
||||
const escapedAgentId = agentNodeId.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
|
||||
const leadingPattern = new RegExp(`^\\{\\{[@#]${escapedAgentId}\\.context[@#]\\}\\}`)
|
||||
return toolParamValue.replace(leadingPattern, '')
|
||||
}, [agentNodeId, toolParamValue])
|
||||
|
||||
const startNode = useMemo(() => {
|
||||
return {
|
||||
id: 'subgraph-source',
|
||||
type: 'custom',
|
||||
position: SUB_GRAPH_ENTRY_POSITION,
|
||||
data: {
|
||||
type: BlockEnum.Start,
|
||||
title: agentName,
|
||||
desc: '',
|
||||
_connectedSourceHandleIds: ['source'],
|
||||
_connectedTargetHandleIds: [],
|
||||
_subGraphEntry: true,
|
||||
_iconTypeOverride: BlockEnum.Agent,
|
||||
selected: false,
|
||||
variables: [],
|
||||
},
|
||||
selected: false,
|
||||
selectable: false,
|
||||
draggable: false,
|
||||
connectable: false,
|
||||
focusable: false,
|
||||
deletable: false,
|
||||
}
|
||||
}, [agentName])
|
||||
|
||||
const extractorDisplayNode = useMemo(() => {
|
||||
if (!extractorNode)
|
||||
return null
|
||||
|
||||
const applyPromptText = (item: PromptItem) => {
|
||||
if (item.edition_type === EditionType.jinja2) {
|
||||
return {
|
||||
...item,
|
||||
text: promptText,
|
||||
jinja2_text: promptText,
|
||||
}
|
||||
}
|
||||
return { ...item, text: promptText }
|
||||
}
|
||||
|
||||
const nextPromptTemplate = (() => {
|
||||
const template = extractorNode.data.prompt_template
|
||||
if (!Array.isArray(template))
|
||||
return applyPromptText(template as PromptItem)
|
||||
|
||||
const userIndex = template.findIndex(
|
||||
item => !isPromptMessageContext(item) && (item as PromptItem).role === PromptRole.user,
|
||||
)
|
||||
if (userIndex >= 0) {
|
||||
return template.map((item, index) => {
|
||||
if (index !== userIndex)
|
||||
return item
|
||||
return applyPromptText(item as PromptItem)
|
||||
}) as PromptTemplateItem[]
|
||||
}
|
||||
|
||||
const useJinja = template.some(
|
||||
item => !isPromptMessageContext(item) && (item as PromptItem).edition_type === EditionType.jinja2,
|
||||
)
|
||||
const defaultUserPrompt: PromptItem = useJinja
|
||||
? {
|
||||
role: PromptRole.user,
|
||||
text: promptText,
|
||||
jinja2_text: promptText,
|
||||
edition_type: EditionType.jinja2,
|
||||
}
|
||||
: { role: PromptRole.user, text: promptText }
|
||||
return [...template, defaultUserPrompt] as PromptTemplateItem[]
|
||||
})()
|
||||
|
||||
return {
|
||||
...extractorNode,
|
||||
hidden: false,
|
||||
selected: false,
|
||||
position: SUB_GRAPH_LLM_POSITION,
|
||||
data: {
|
||||
...extractorNode.data,
|
||||
selected: false,
|
||||
prompt_template: nextPromptTemplate,
|
||||
},
|
||||
}
|
||||
}, [extractorNode, promptText])
|
||||
|
||||
const nodesSource = useMemo(() => {
|
||||
if (!extractorDisplayNode)
|
||||
return [startNode]
|
||||
|
||||
return [startNode, extractorDisplayNode]
|
||||
}, [extractorDisplayNode, startNode])
|
||||
|
||||
const edgesSource = useMemo(() => {
|
||||
if (!extractorDisplayNode)
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
id: `${startNode.id}-${extractorDisplayNode.id}`,
|
||||
source: startNode.id,
|
||||
sourceHandle: 'source',
|
||||
target: extractorDisplayNode.id,
|
||||
targetHandle: 'target',
|
||||
type: 'custom',
|
||||
selectable: false,
|
||||
data: {
|
||||
sourceType: BlockEnum.Start,
|
||||
targetType: BlockEnum.LLM,
|
||||
_isTemp: true,
|
||||
_isSubGraphTemp: true,
|
||||
},
|
||||
},
|
||||
]
|
||||
}, [extractorDisplayNode, startNode])
|
||||
|
||||
const { nodes, edges } = useSubGraphNodes(nodesSource, edgesSource)
|
||||
|
||||
return (
|
||||
<WorkflowWithDefaultContext
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
>
|
||||
<SubGraphMain
|
||||
nodes={nodes}
|
||||
edges={edges}
|
||||
viewport={defaultViewport}
|
||||
agentName={agentName}
|
||||
extractorNodeId={`${toolNodeId}_ext_${paramKey}`}
|
||||
configsMap={configsMap}
|
||||
mentionConfig={mentionConfig}
|
||||
onMentionConfigChange={onMentionConfigChange}
|
||||
onSave={onSave}
|
||||
onSyncWorkflowDraft={onSyncWorkflowDraft}
|
||||
/>
|
||||
</WorkflowWithDefaultContext>
|
||||
)
|
||||
}
|
||||
|
||||
const SubGraph: FC<SubGraphProps> = (props) => {
|
||||
return (
|
||||
<WorkflowContextProvider
|
||||
injectWorkflowStoreSliceFn={createSubGraphSlice as InjectWorkflowStoreSliceFn}
|
||||
>
|
||||
<SubGraphContent {...props} />
|
||||
</WorkflowContextProvider>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(SubGraph)
|
||||
web/app/components/sub-graph/store/index.ts (new file, 12 lines)
@@ -0,0 +1,12 @@
|
||||
import type { CreateSubGraphSlice, SubGraphSliceShape } from '../types'
|
||||
|
||||
const initialState: Omit<SubGraphSliceShape, 'setParentAvailableVars' | 'setParentAvailableNodes'> = {
|
||||
parentAvailableVars: [],
|
||||
parentAvailableNodes: [],
|
||||
}
|
||||
|
||||
export const createSubGraphSlice: CreateSubGraphSlice = set => ({
|
||||
...initialState,
|
||||
setParentAvailableVars: vars => set(() => ({ parentAvailableVars: vars })),
|
||||
setParentAvailableNodes: nodes => set(() => ({ parentAvailableNodes: nodes })),
|
||||
})
|
||||
web/app/components/sub-graph/types.ts (new file, 42 lines)
@@ -0,0 +1,42 @@
|
||||
import type { StateCreator } from 'zustand'
|
||||
import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store'
|
||||
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
|
||||
import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
|
||||
import type { Edge, Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types'
|
||||
|
||||
export type SyncWorkflowDraftCallback = {
|
||||
onSuccess?: () => void
|
||||
onError?: () => void
|
||||
onSettled?: () => void
|
||||
}
|
||||
|
||||
export type SyncWorkflowDraft = (
|
||||
notRefreshWhenSyncError?: boolean,
|
||||
callback?: SyncWorkflowDraftCallback,
|
||||
) => Promise<void>
|
||||
|
||||
export type SubGraphProps = {
|
||||
toolNodeId: string
|
||||
paramKey: string
|
||||
sourceVariable: ValueSelector
|
||||
agentNodeId: string
|
||||
agentName: string
|
||||
configsMap?: HooksStoreShape['configsMap']
|
||||
mentionConfig: MentionConfig
|
||||
onMentionConfigChange: (config: MentionConfig) => void
|
||||
extractorNode?: Node<LLMNodeType>
|
||||
toolParamValue?: string
|
||||
parentAvailableNodes?: Node[]
|
||||
parentAvailableVars?: NodeOutPutVar[]
|
||||
onSave?: (nodes: Node[], edges: Edge[]) => void
|
||||
onSyncWorkflowDraft?: SyncWorkflowDraft
|
||||
}
|
||||
|
||||
export type SubGraphSliceShape = {
|
||||
parentAvailableVars: NodeOutPutVar[]
|
||||
parentAvailableNodes: Node[]
|
||||
setParentAvailableVars: (vars: NodeOutPutVar[]) => void
|
||||
setParentAvailableNodes: (nodes: Node[]) => void
|
||||
}
|
||||
|
||||
export type CreateSubGraphSlice = StateCreator<SubGraphSliceShape>
|
||||
@@ -1,6 +1,7 @@
|
||||
import type { FC } from 'react'
|
||||
import { memo } from 'react'
|
||||
import AppIcon from '@/app/components/base/app-icon'
|
||||
import { Folder as FolderLine } from '@/app/components/base/icons/src/vender/line/files'
|
||||
import {
|
||||
Agent,
|
||||
Answer,
|
||||
@@ -54,6 +55,7 @@ const DEFAULT_ICON_MAP: Record<BlockEnum, React.ComponentType<{ className: strin
|
||||
[BlockEnum.TemplateTransform]: TemplatingTransform,
|
||||
[BlockEnum.VariableAssigner]: VariableX,
|
||||
[BlockEnum.VariableAggregator]: VariableX,
|
||||
[BlockEnum.Group]: FolderLine,
|
||||
[BlockEnum.Assigner]: Assigner,
|
||||
[BlockEnum.Tool]: VariableX,
|
||||
[BlockEnum.IterationStart]: VariableX,
|
||||
@@ -97,6 +99,7 @@ const ICON_CONTAINER_BG_COLOR_MAP: Record<string, string> = {
|
||||
[BlockEnum.VariableAssigner]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.VariableAggregator]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.Tool]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.Group]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.Assigner]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.ParameterExtractor]: 'bg-util-colors-blue-blue-500',
|
||||
[BlockEnum.DocExtractor]: 'bg-util-colors-green-green-500',
|
||||
|
||||
@@ -131,6 +131,11 @@ export const SUPPORT_OUTPUT_VARS_NODE = [
|
||||
]
|
||||
|
||||
export const AGENT_OUTPUT_STRUCT: Var[] = [
|
||||
{
|
||||
variable: 'context',
|
||||
type: VarType.arrayObject,
|
||||
schemaType: 'List[promptMessage]',
|
||||
},
|
||||
{
|
||||
variable: 'usage',
|
||||
type: VarType.object,
|
||||
@@ -142,6 +147,11 @@ export const LLM_OUTPUT_STRUCT: Var[] = [
|
||||
variable: 'text',
|
||||
type: VarType.string,
|
||||
},
|
||||
{
|
||||
variable: 'context',
|
||||
type: VarType.arrayObject,
|
||||
schemaType: 'List[promptMessage]',
|
||||
},
|
||||
{
|
||||
variable: 'reasoning_content',
|
||||
type: VarType.string,
|
||||
|
||||
@@ -25,7 +25,8 @@ import {
|
||||
useAvailableBlocks,
|
||||
useNodesInteractions,
|
||||
} from './hooks'
|
||||
import { NodeRunningStatus } from './types'
|
||||
import { useHooksStore } from './hooks-store'
|
||||
import { BlockEnum, NodeRunningStatus } from './types'
|
||||
import { getEdgeColor } from './utils'
|
||||
|
||||
const CustomEdge = ({
|
||||
@@ -56,6 +57,8 @@ const CustomEdge = ({
|
||||
})
|
||||
const [open, setOpen] = useState(false)
|
||||
const { handleNodeAdd } = useNodesInteractions()
|
||||
const interactionMode = useHooksStore(s => s.interactionMode)
|
||||
const allowGraphActions = interactionMode !== 'subgraph'
|
||||
const { availablePrevBlocks } = useAvailableBlocks((data as Edge['data'])!.targetType, (data as Edge['data'])?.isInIteration || (data as Edge['data'])?.isInLoop)
|
||||
const { availableNextBlocks } = useAvailableBlocks((data as Edge['data'])!.sourceType, (data as Edge['data'])?.isInIteration || (data as Edge['data'])?.isInLoop)
|
||||
const {
|
||||
@@ -136,35 +139,37 @@ const CustomEdge = ({
|
||||
stroke,
|
||||
strokeWidth: 2,
|
||||
opacity: data._dimmed ? 0.3 : (data._waitingRun ? 0.7 : 1),
|
||||
strokeDasharray: data._isTemp ? '8 8' : undefined,
|
||||
strokeDasharray: (data._isTemp && !data._isSubGraphTemp && data.sourceType !== BlockEnum.Group && data.targetType !== BlockEnum.Group) ? '8 8' : undefined,
|
||||
}}
|
||||
/>
|
||||
<EdgeLabelRenderer>
|
||||
<div
|
||||
className={cn(
|
||||
'nopan nodrag hover:scale-125',
|
||||
data?._hovering ? 'block' : 'hidden',
|
||||
open && '!block',
|
||||
data.isInIteration && `z-[${ITERATION_CHILDREN_Z_INDEX}]`,
|
||||
data.isInLoop && `z-[${LOOP_CHILDREN_Z_INDEX}]`,
|
||||
)}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: 'all',
|
||||
opacity: data._waitingRun ? 0.7 : 1,
|
||||
}}
|
||||
>
|
||||
<BlockSelector
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
asChild
|
||||
onSelect={handleInsert}
|
||||
availableBlocksTypes={intersection(availablePrevBlocks, availableNextBlocks)}
|
||||
triggerClassName={() => 'hover:scale-150 transition-all'}
|
||||
/>
|
||||
</div>
|
||||
</EdgeLabelRenderer>
|
||||
{allowGraphActions && (
|
||||
<EdgeLabelRenderer>
|
||||
<div
|
||||
className={cn(
|
||||
'nopan nodrag hover:scale-125',
|
||||
data?._hovering ? 'block' : 'hidden',
|
||||
open && '!block',
|
||||
data.isInIteration && `z-[${ITERATION_CHILDREN_Z_INDEX}]`,
|
||||
data.isInLoop && `z-[${LOOP_CHILDREN_Z_INDEX}]`,
|
||||
)}
|
||||
style={{
|
||||
position: 'absolute',
|
||||
transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
|
||||
pointerEvents: 'all',
|
||||
opacity: data._waitingRun ? 0.7 : 1,
|
||||
}}
|
||||
>
|
||||
<BlockSelector
|
||||
open={open}
|
||||
onOpenChange={handleOpenChange}
|
||||
asChild
|
||||
onSelect={handleInsert}
|
||||
availableBlocksTypes={intersection(availablePrevBlocks, availableNextBlocks)}
|
||||
triggerClassName={() => 'hover:scale-150 transition-all'}
|
||||
/>
|
||||
</div>
|
||||
</EdgeLabelRenderer>
|
||||
)}
|
||||
</>
|
||||
)
|
||||
}
|
||||
|
||||
web/app/components/workflow/custom-group-node/constants.ts (new file, 11 lines)
@@ -0,0 +1,11 @@
|
||||
export const CUSTOM_GROUP_NODE = 'custom-group'
|
||||
export const CUSTOM_GROUP_INPUT_NODE = 'custom-group-input'
|
||||
export const CUSTOM_GROUP_EXIT_PORT_NODE = 'custom-group-exit-port'
|
||||
|
||||
export const GROUP_CHILDREN_Z_INDEX = 1002
|
||||
|
||||
export const UI_ONLY_GROUP_NODE_TYPES = new Set([
|
||||
CUSTOM_GROUP_NODE,
|
||||
CUSTOM_GROUP_INPUT_NODE,
|
||||
CUSTOM_GROUP_EXIT_PORT_NODE,
|
||||
])
|
||||
@@ -0,0 +1,54 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { CustomGroupExitPortNodeData } from './types'
|
||||
import { memo } from 'react'
|
||||
import { Handle, Position } from 'reactflow'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type CustomGroupExitPortNodeProps = {
|
||||
id: string
|
||||
data: CustomGroupExitPortNodeData
|
||||
}
|
||||
|
||||
const CustomGroupExitPortNode: FC<CustomGroupExitPortNodeProps> = ({ id: _id, data }) => {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center justify-center',
|
||||
'h-8 w-8 rounded-full',
|
||||
'bg-util-colors-green-green-500 shadow-md',
|
||||
data.selected && 'ring-2 ring-primary-400',
|
||||
)}
|
||||
>
|
||||
{/* Target handle - receives internal connections from leaf nodes */}
|
||||
<Handle
|
||||
id="target"
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
className="!h-2 !w-2 !border-0 !bg-white"
|
||||
/>
|
||||
|
||||
{/* Source handle - connects to external nodes */}
|
||||
<Handle
|
||||
id="source"
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
className="!h-2 !w-2 !border-0 !bg-white"
|
||||
/>
|
||||
|
||||
{/* Icon */}
|
||||
<svg
|
||||
className="h-4 w-4 text-white"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth={2}
|
||||
>
|
||||
<path d="M5 12h14M12 5l7 7-7 7" />
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(CustomGroupExitPortNode)
|
||||
@@ -0,0 +1,55 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { CustomGroupInputNodeData } from './types'
|
||||
import { memo } from 'react'
|
||||
import { Handle, Position } from 'reactflow'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type CustomGroupInputNodeProps = {
|
||||
id: string
|
||||
data: CustomGroupInputNodeData
|
||||
}
|
||||
|
||||
const CustomGroupInputNode: FC<CustomGroupInputNodeProps> = ({ id: _id, data }) => {
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'flex items-center justify-center',
|
||||
'h-8 w-8 rounded-full',
|
||||
'bg-util-colors-blue-blue-500 shadow-md',
|
||||
data.selected && 'ring-2 ring-primary-400',
|
||||
)}
|
||||
>
|
||||
{/* Target handle - receives external connections */}
|
||||
<Handle
|
||||
id="target"
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
className="!h-2 !w-2 !border-0 !bg-white"
|
||||
/>
|
||||
|
||||
{/* Source handle - connects to entry nodes */}
|
||||
<Handle
|
||||
id="source"
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
className="!h-2 !w-2 !border-0 !bg-white"
|
||||
/>
|
||||
|
||||
{/* Icon */}
|
||||
<svg
|
||||
className="h-4 w-4 text-white"
|
||||
viewBox="0 0 24 24"
|
||||
fill="none"
|
||||
stroke="currentColor"
|
||||
strokeWidth={2}
|
||||
>
|
||||
<path d="M9 12l2 2 4-4" />
|
||||
<circle cx="12" cy="12" r="10" />
|
||||
</svg>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(CustomGroupInputNode)
|
||||
@@ -0,0 +1,94 @@
|
||||
'use client'
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type { CustomGroupNodeData } from './types'
|
||||
import { memo } from 'react'
|
||||
import { Handle, Position } from 'reactflow'
|
||||
import { Plus02 } from '@/app/components/base/icons/src/vender/line/general'
|
||||
import { cn } from '@/utils/classnames'
|
||||
|
||||
type CustomGroupNodeProps = {
|
||||
id: string
|
||||
data: CustomGroupNodeData
|
||||
}
|
||||
|
||||
const CustomGroupNode: FC<CustomGroupNodeProps> = ({ data }) => {
|
||||
const { group } = data
|
||||
const exitPorts = group.exitPorts ?? []
|
||||
const connectedSourceHandleIds = data._connectedSourceHandleIds ?? []
|
||||
|
||||
return (
|
||||
<div
|
||||
className={cn(
|
||||
'bg-workflow-block-parma-bg/50 group relative rounded-2xl border-2 border-dashed border-components-panel-border',
|
||||
data.selected && 'border-primary-400',
|
||||
)}
|
||||
style={{
|
||||
width: data.width || 280,
|
||||
height: data.height || 200,
|
||||
}}
|
||||
>
|
||||
{/* Group Header */}
|
||||
<div className="absolute -top-7 left-0 flex items-center gap-1 px-2">
|
||||
<span className="text-xs font-medium text-text-tertiary">
|
||||
{group.title}
|
||||
</span>
|
||||
</div>
|
||||
|
||||
{/* Target handle for incoming connections */}
|
||||
<Handle
|
||||
id="target"
|
||||
type="target"
|
||||
position={Position.Left}
|
||||
className={cn(
|
||||
'!h-4 !w-4 !rounded-none !border-none !bg-transparent !outline-none',
|
||||
'after:absolute after:left-1.5 after:top-1 after:h-2 after:w-0.5 after:bg-workflow-link-line-handle',
|
||||
'transition-all hover:scale-125',
|
||||
)}
|
||||
style={{ top: '50%' }}
|
||||
/>
|
||||
|
||||
<div className="px-3 pt-3">
|
||||
{exitPorts.map((port, index) => {
|
||||
const connected = connectedSourceHandleIds.includes(port.portNodeId)
|
||||
|
||||
return (
|
||||
<div key={port.portNodeId} className="relative flex h-6 items-center px-1">
|
||||
<div className="w-full text-right text-xs font-semibold text-text-secondary">
|
||||
{port.name}
|
||||
</div>
|
||||
|
||||
<Handle
|
||||
id={port.portNodeId}
|
||||
type="source"
|
||||
position={Position.Right}
|
||||
className={cn(
|
||||
'group/handle z-[1] !h-4 !w-4 !rounded-none !border-none !bg-transparent !outline-none',
|
||||
'after:absolute after:right-1.5 after:top-1 after:h-2 after:w-0.5 after:bg-workflow-link-line-handle',
|
||||
'transition-all hover:scale-125',
|
||||
!connected && 'after:opacity-0',
|
||||
'!-right-[21px] !top-1/2 !-translate-y-1/2',
|
||||
)}
|
||||
isConnectable
|
||||
/>
|
||||
|
||||
{/* Visual "+" indicator (styling aligned with existing branch handles) */}
|
||||
<div
|
||||
className={cn(
|
||||
'pointer-events-none absolute z-10 hidden h-4 w-4 items-center justify-center rounded-full bg-components-button-primary-bg text-text-primary-on-surface',
|
||||
'-right-[21px] top-1/2 -translate-y-1/2',
|
||||
'group-hover:flex',
|
||||
data.selected && '!flex',
|
||||
)}
|
||||
>
|
||||
<Plus02 className="h-2.5 w-2.5" />
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
})}
|
||||
</div>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
export default memo(CustomGroupNode)
|
||||
web/app/components/workflow/custom-group-node/index.ts (new file, 19 lines)
@@ -0,0 +1,19 @@
|
||||
export {
|
||||
CUSTOM_GROUP_EXIT_PORT_NODE,
|
||||
CUSTOM_GROUP_INPUT_NODE,
|
||||
CUSTOM_GROUP_NODE,
|
||||
GROUP_CHILDREN_Z_INDEX,
|
||||
UI_ONLY_GROUP_NODE_TYPES,
|
||||
} from './constants'
|
||||
|
||||
export { default as CustomGroupExitPortNode } from './custom-group-exit-port-node'
|
||||
|
||||
export { default as CustomGroupInputNode } from './custom-group-input-node'
|
||||
export { default as CustomGroupNode } from './custom-group-node'
|
||||
export type {
|
||||
CustomGroupExitPortNodeData,
|
||||
CustomGroupInputNodeData,
|
||||
CustomGroupNodeData,
|
||||
ExitPortInfo,
|
||||
GroupMember,
|
||||
} from './types'
|
||||
web/app/components/workflow/custom-group-node/types.ts (new file, 82 lines)
@@ -0,0 +1,82 @@
|
||||
import type { BlockEnum } from '../types'
|
||||
|
||||
/**
|
||||
* Exit port info stored in Group node
|
||||
*/
|
||||
export type ExitPortInfo = {
|
||||
portNodeId: string
|
||||
leafNodeId: string
|
||||
sourceHandle: string
|
||||
name: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Group node data structure
|
||||
* node.type = 'custom-group'
|
||||
* node.data.type = '' (empty string to bypass backend NodeType validation)
|
||||
*/
|
||||
export type CustomGroupNodeData = {
|
||||
type: '' // Empty string bypasses backend NodeType validation
|
||||
title: string
|
||||
desc?: string
|
||||
_connectedSourceHandleIds?: string[]
|
||||
_connectedTargetHandleIds?: string[]
|
||||
group: {
|
||||
groupId: string
|
||||
title: string
|
||||
memberNodeIds: string[]
|
||||
entryNodeIds: string[]
|
||||
inputNodeId: string
|
||||
exitPorts: ExitPortInfo[]
|
||||
collapsed: boolean
|
||||
}
|
||||
width?: number
|
||||
height?: number
|
||||
selected?: boolean
|
||||
_isTempNode?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Group Input node data structure
|
||||
* node.type = 'custom-group-input'
|
||||
* node.data.type = ''
|
||||
*/
|
||||
export type CustomGroupInputNodeData = {
|
||||
type: ''
|
||||
title: string
|
||||
desc?: string
|
||||
groupInput: {
|
||||
groupId: string
|
||||
title: string
|
||||
}
|
||||
selected?: boolean
|
||||
_isTempNode?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Exit Port node data structure
|
||||
* node.type = 'custom-group-exit-port'
|
||||
* node.data.type = ''
|
||||
*/
|
||||
export type CustomGroupExitPortNodeData = {
|
||||
type: ''
|
||||
title: string
|
||||
desc?: string
|
||||
exitPort: {
|
||||
groupId: string
|
||||
leafNodeId: string
|
||||
sourceHandle: string
|
||||
name: string
|
||||
}
|
||||
selected?: boolean
|
||||
_isTempNode?: boolean
|
||||
}
|
||||
|
||||
/**
|
||||
* Member node info for display
|
||||
*/
|
||||
export type GroupMember = {
|
||||
id: string
|
||||
type: BlockEnum
|
||||
label?: string
|
||||
}
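A minimal sketch (illustration only, not part of the diff) of a CustomGroupNodeData value wired up per the types above; every id and title below is a made-up placeholder:

const groupNodeData: CustomGroupNodeData = {
  type: '', // empty string so the backend NodeType validation is bypassed, as noted above
  title: 'Group',
  group: {
    groupId: 'group-1',
    title: 'Group',
    memberNodeIds: ['llm-a', 'code-b'],
    entryNodeIds: ['llm-a'],
    inputNodeId: 'group-1-input',
    exitPorts: [
      { portNodeId: 'group-1-port-0', leafNodeId: 'code-b', sourceHandle: 'source', name: 'code-b' },
    ],
    collapsed: false,
  },
}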
|
||||
@@ -23,6 +23,7 @@ export type AvailableNodesMetaData = {
|
||||
nodesMap?: Record<BlockEnum, NodeDefault<any>>
|
||||
}
|
||||
export type CommonHooksFnMap = {
|
||||
interactionMode?: 'default' | 'subgraph'
|
||||
doSyncWorkflowDraft: (
|
||||
notRefreshWhenSyncError?: boolean,
|
||||
callback?: {
|
||||
@@ -76,6 +77,7 @@ export type Shape = {
|
||||
} & CommonHooksFnMap
|
||||
|
||||
export const createHooksStore = ({
|
||||
interactionMode = 'default',
|
||||
doSyncWorkflowDraft = async () => noop(),
|
||||
syncWorkflowDraftWhenPageClose = noop,
|
||||
handleRefreshWorkflowDraft = noop,
|
||||
@@ -118,6 +120,7 @@ export const createHooksStore = ({
|
||||
}: Partial<Shape>) => {
|
||||
return createStore<Shape>(set => ({
|
||||
refreshAll: props => set(state => ({ ...state, ...props })),
|
||||
interactionMode,
|
||||
doSyncWorkflowDraft,
|
||||
syncWorkflowDraftWhenPageClose,
|
||||
handleRefreshWorkflowDraft,
|
||||
|
||||
@@ -197,7 +197,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
|
||||
// Start nodes and Trigger nodes should not show unConnected error if they have validation errors
|
||||
// or if they are valid start nodes (even without incoming connections)
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
const isSubGraphNode = Boolean((node.data as { parent_node_id?: string }).parent_node_id)
|
||||
const canSkipConnectionCheck = isSubGraphNode || (shouldCheckStartNode ? isStartNodeMeta : true)
|
||||
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
const shouldShowError = errorMessage || (isUnconnected && !canSkipConnectionCheck)
|
||||
@@ -390,7 +391,8 @@ export const useChecklistBeforePublish = () => {
|
||||
}
|
||||
|
||||
const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
|
||||
const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
|
||||
const isSubGraphNode = Boolean((node.data as { parent_node_id?: string }).parent_node_id)
|
||||
const canSkipConnectionCheck = isSubGraphNode || (shouldCheckStartNode ? isStartNodeMeta : true)
|
||||
const isUnconnected = !validNodes.find(n => n.id === node.id)
|
||||
|
||||
if (isUnconnected && !canSkipConnectionCheck) {
|
||||
|
||||
@@ -10,6 +10,7 @@ import { useCallback } from 'react'
|
||||
import {
|
||||
useStoreApi,
|
||||
} from 'reactflow'
|
||||
import { BlockEnum } from '../types'
|
||||
import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../utils'
|
||||
import { useNodesSyncDraft } from './use-nodes-sync-draft'
|
||||
import { useNodesReadOnly } from './use-workflow'
|
||||
@@ -108,6 +109,50 @@ export const useEdgesInteractions = () => {
|
||||
return
|
||||
const currentEdge = edges[currentEdgeIndex]
|
||||
const nodes = getNodes()
|
||||
|
||||
// collect edges to delete (including corresponding real edges for temp edges)
|
||||
const edgesToDelete: Set<string> = new Set([currentEdge.id])
|
||||
|
||||
// if deleting a temp edge connected to a group, also delete the corresponding real hidden edge
|
||||
if (currentEdge.data?._isTemp) {
|
||||
const groupNode = nodes.find(n =>
|
||||
n.data.type === BlockEnum.Group
|
||||
&& (n.id === currentEdge.source || n.id === currentEdge.target),
|
||||
)
|
||||
|
||||
if (groupNode) {
|
||||
const memberIds = new Set((groupNode.data.members || []).map((m: { id: string }) => m.id))
|
||||
|
||||
if (currentEdge.target === groupNode.id) {
|
||||
// inbound temp edge: find real edge with same source, target is a head node
|
||||
edges.forEach((edge) => {
|
||||
if (edge.source === currentEdge.source
|
||||
&& memberIds.has(edge.target)
|
||||
&& edge.sourceHandle === currentEdge.sourceHandle) {
|
||||
edgesToDelete.add(edge.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
else if (currentEdge.source === groupNode.id) {
|
||||
// outbound temp edge: sourceHandle format is "leafNodeId-originalHandle"
|
||||
const sourceHandle = currentEdge.sourceHandle || ''
|
||||
const lastDashIndex = sourceHandle.lastIndexOf('-')
|
||||
if (lastDashIndex > 0) {
|
||||
const leafNodeId = sourceHandle.substring(0, lastDashIndex)
|
||||
const originalHandle = sourceHandle.substring(lastDashIndex + 1)
|
||||
|
||||
edges.forEach((edge) => {
|
||||
if (edge.source === leafNodeId
|
||||
&& edge.target === currentEdge.target
|
||||
&& (edge.sourceHandle || 'source') === originalHandle) {
|
||||
edgesToDelete.add(edge.id)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
|
||||
[
|
||||
{ type: 'remove', edge: currentEdge },
|
||||
@@ -126,7 +171,10 @@ export const useEdgesInteractions = () => {
|
||||
})
|
||||
setNodes(newNodes)
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
draft.splice(currentEdgeIndex, 1)
|
||||
for (let i = draft.length - 1; i >= 0; i--) {
|
||||
if (edgesToDelete.has(draft[i].id))
|
||||
draft.splice(i, 1)
|
||||
}
|
||||
})
|
||||
setEdges(newEdges)
|
||||
handleSyncWorkflowDraft()
|
||||
|
||||
web/app/components/workflow/hooks/use-make-group.ts (new file, 138 lines)
@@ -0,0 +1,138 @@
|
||||
import type { PredecessorHandle } from '../utils'
|
||||
import { useMemo } from 'react'
|
||||
import { useStore as useReactFlowStore } from 'reactflow'
|
||||
import { shallow } from 'zustand/shallow'
|
||||
import { BlockEnum } from '../types'
|
||||
import { getCommonPredecessorHandles } from '../utils'
|
||||
|
||||
export type MakeGroupAvailability = {
|
||||
canMakeGroup: boolean
|
||||
branchEntryNodeIds: string[]
|
||||
commonPredecessorHandle?: PredecessorHandle
|
||||
}
|
||||
|
||||
type MinimalEdge = {
|
||||
id: string
|
||||
source: string
|
||||
sourceHandle: string
|
||||
target: string
|
||||
}
|
||||
|
||||
/**
|
||||
* Pure function to check if the selected nodes can be grouped.
|
||||
* Can be called both from React hooks and imperatively.
|
||||
*/
|
||||
export const checkMakeGroupAvailability = (
|
||||
selectedNodeIds: string[],
|
||||
edges: MinimalEdge[],
|
||||
hasGroupNode = false,
|
||||
): MakeGroupAvailability => {
|
||||
if (selectedNodeIds.length <= 1 || hasGroupNode) {
|
||||
return {
|
||||
canMakeGroup: false,
|
||||
branchEntryNodeIds: [],
|
||||
commonPredecessorHandle: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
const selectedNodeIdSet = new Set(selectedNodeIds)
|
||||
const inboundFromOutsideTargets = new Set<string>()
|
||||
const incomingEdgeCounts = new Map<string, number>()
|
||||
const incomingFromSelectedTargets = new Set<string>()
|
||||
|
||||
edges.forEach((edge) => {
|
||||
// Only consider edges whose target is inside the selected subgraph.
|
||||
if (!selectedNodeIdSet.has(edge.target))
|
||||
return
|
||||
|
||||
incomingEdgeCounts.set(edge.target, (incomingEdgeCounts.get(edge.target) ?? 0) + 1)
|
||||
|
||||
if (selectedNodeIdSet.has(edge.source))
|
||||
incomingFromSelectedTargets.add(edge.target)
|
||||
else
|
||||
inboundFromOutsideTargets.add(edge.target)
|
||||
})
|
||||
|
||||
// Branch head (entry) definition:
|
||||
// - has at least one incoming edge
|
||||
// - and all its incoming edges come from outside the selected subgraph
|
||||
const branchEntryNodeIds = selectedNodeIds.filter((nodeId) => {
|
||||
const incomingEdgeCount = incomingEdgeCounts.get(nodeId) ?? 0
|
||||
if (incomingEdgeCount === 0)
|
||||
return false
|
||||
|
||||
return !incomingFromSelectedTargets.has(nodeId)
|
||||
})
|
||||
|
||||
// No branch head means we cannot tell how many branches are represented by this selection.
|
||||
if (branchEntryNodeIds.length === 0) {
|
||||
return {
|
||||
canMakeGroup: false,
|
||||
branchEntryNodeIds,
|
||||
commonPredecessorHandle: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
// Guardrail: disallow side entrances into the selected subgraph.
|
||||
// If an outside node connects to a non-entry node inside the selection, the grouping boundary is ambiguous.
|
||||
const branchEntryNodeIdSet = new Set(branchEntryNodeIds)
|
||||
const hasInboundToNonEntryNode = Array.from(inboundFromOutsideTargets).some(nodeId => !branchEntryNodeIdSet.has(nodeId))
|
||||
|
||||
if (hasInboundToNonEntryNode) {
|
||||
return {
|
||||
canMakeGroup: false,
|
||||
branchEntryNodeIds,
|
||||
commonPredecessorHandle: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
// Compare the branch heads by their common predecessor "handler" (source node + sourceHandle).
|
||||
// This is required for multi-handle nodes like If-Else / Classifier where different branches use different handles.
|
||||
const commonPredecessorHandles = getCommonPredecessorHandles(
|
||||
branchEntryNodeIds,
|
||||
// Only look at edges coming from outside the selected subgraph when determining the "pre" handler.
|
||||
edges.filter(edge => !selectedNodeIdSet.has(edge.source)),
|
||||
)
|
||||
|
||||
if (commonPredecessorHandles.length !== 1) {
|
||||
return {
|
||||
canMakeGroup: false,
|
||||
branchEntryNodeIds,
|
||||
commonPredecessorHandle: undefined,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
canMakeGroup: true,
|
||||
branchEntryNodeIds,
|
||||
commonPredecessorHandle: commonPredecessorHandles[0],
|
||||
}
|
||||
}
|
||||
|
||||
export const useMakeGroupAvailability = (selectedNodeIds: string[]): MakeGroupAvailability => {
|
||||
const edgeKeys = useReactFlowStore((state) => {
|
||||
const delimiter = '\u0000'
|
||||
const keys = state.edges.map(edge => `${edge.source}${delimiter}${edge.sourceHandle || 'source'}${delimiter}${edge.target}`)
|
||||
keys.sort()
|
||||
return keys
|
||||
}, shallow)
|
||||
|
||||
const hasGroupNode = useReactFlowStore((state) => {
|
||||
return state.getNodes().some(node => node.selected && node.data.type === BlockEnum.Group)
|
||||
})
|
||||
|
||||
return useMemo(() => {
|
||||
const delimiter = '\u0000'
|
||||
const edges = edgeKeys.map((key) => {
|
||||
const [source, handleId, target] = key.split(delimiter)
|
||||
return {
|
||||
id: key,
|
||||
source,
|
||||
sourceHandle: handleId || 'source',
|
||||
target,
|
||||
}
|
||||
})
|
||||
|
||||
return checkMakeGroupAvailability(selectedNodeIds, edges, hasGroupNode)
|
||||
}, [edgeKeys, selectedNodeIds, hasGroupNode])
|
||||
}
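A rough usage sketch of the exported pure function (illustration only, not part of the diff): the node and handle ids are invented, and the expected result assumes getCommonPredecessorHandles reports the single outside (source, sourceHandle) pair feeding the entry node.

import { checkMakeGroupAvailability } from './use-make-group'

// A two-node chain selected inside one branch: 'llm-a' is entered from outside
// via the 'true' handle of an if-else node, and 'code-b' is only fed by 'llm-a'.
const availability = checkMakeGroupAvailability(
  ['llm-a', 'code-b'],
  [
    { id: 'e1', source: 'if-else-1', sourceHandle: 'true', target: 'llm-a' },
    { id: 'e2', source: 'llm-a', sourceHandle: 'source', target: 'code-b' },
  ],
)
// Expected under the assumption above:
//   availability.branchEntryNodeIds -> ['llm-a']
//   availability.canMakeGroup      -> true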
|
||||
@@ -8,6 +8,7 @@ import type {
|
||||
ResizeParamsWithDirection,
|
||||
} from 'reactflow'
|
||||
import type { PluginDefaultValue } from '../block-selector/types'
|
||||
import type { GroupHandler, GroupMember, GroupNodeData } from '../nodes/group/types'
|
||||
import type { IterationNodeType } from '../nodes/iteration/types'
|
||||
import type { LoopNodeType } from '../nodes/loop/types'
|
||||
import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types'
|
||||
@@ -52,6 +53,7 @@ import { useWorkflowHistoryStore } from '../workflow-history-store'
|
||||
import { useAutoGenerateWebhookUrl } from './use-auto-generate-webhook-url'
|
||||
import { useHelpline } from './use-helpline'
|
||||
import useInspectVarsCrud from './use-inspect-vars-crud'
|
||||
import { checkMakeGroupAvailability } from './use-make-group'
|
||||
import { useNodesMetaData } from './use-nodes-meta-data'
|
||||
import { useNodesSyncDraft } from './use-nodes-sync-draft'
|
||||
import {
|
||||
@@ -73,6 +75,151 @@ const ENTRY_NODE_WRAPPER_OFFSET = {
|
||||
y: 21, // Adjusted based on visual testing feedback
|
||||
} as const
|
||||
|
||||
/**
|
||||
* Parse group handler id to get original node id and sourceHandle
|
||||
* Handler id format: `${nodeId}-${sourceHandle}`
|
||||
*/
|
||||
function parseGroupHandlerId(handlerId: string): { originalNodeId: string, originalSourceHandle: string } {
|
||||
const lastDashIndex = handlerId.lastIndexOf('-')
|
||||
return {
|
||||
originalNodeId: handlerId.substring(0, lastDashIndex),
|
||||
originalSourceHandle: handlerId.substring(lastDashIndex + 1),
|
||||
}
|
||||
}
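// Worked example (illustration only, not part of the diff): because only the last '-' is treated as the
// separator, a handler id like 'if-else-1-true' parses to
//   { originalNodeId: 'if-else-1', originalSourceHandle: 'true' }
// even though the node id itself contains dashes.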
|
||||
|
||||
/**
|
||||
* Create a pair of edges for group node connections:
|
||||
* - realEdge: hidden edge from original node to target (persisted to backend)
|
||||
* - uiEdge: visible temp edge from group to target (UI-only, not persisted)
|
||||
*/
|
||||
function createGroupEdgePair(params: {
|
||||
groupNodeId: string
|
||||
handlerId: string
|
||||
targetNodeId: string
|
||||
targetHandle: string
|
||||
nodes: Node[]
|
||||
baseEdgeData?: Partial<Edge['data']>
|
||||
zIndex?: number
|
||||
}): { realEdge: Edge, uiEdge: Edge } | null {
|
||||
const { groupNodeId, handlerId, targetNodeId, targetHandle, nodes, baseEdgeData = {}, zIndex = 0 } = params
|
||||
|
||||
const groupNode = nodes.find(node => node.id === groupNodeId)
|
||||
const groupData = groupNode?.data as GroupNodeData | undefined
|
||||
const handler = groupData?.handlers?.find(h => h.id === handlerId)
|
||||
|
||||
let originalNodeId: string
|
||||
let originalSourceHandle: string
|
||||
|
||||
if (handler?.nodeId && handler?.sourceHandle) {
|
||||
originalNodeId = handler.nodeId
|
||||
originalSourceHandle = handler.sourceHandle
|
||||
}
|
||||
else {
|
||||
const parsed = parseGroupHandlerId(handlerId)
|
||||
originalNodeId = parsed.originalNodeId
|
||||
originalSourceHandle = parsed.originalSourceHandle
|
||||
}
|
||||
|
||||
const originalNode = nodes.find(node => node.id === originalNodeId)
|
||||
const targetNode = nodes.find(node => node.id === targetNodeId)
|
||||
|
||||
if (!originalNode || !targetNode)
|
||||
return null
|
||||
|
||||
// Create the real edge (from original node to target) - hidden because original node is in group
|
||||
const realEdge: Edge = {
|
||||
id: `${originalNodeId}-${originalSourceHandle}-${targetNodeId}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: originalNodeId,
|
||||
sourceHandle: originalSourceHandle,
|
||||
target: targetNodeId,
|
||||
targetHandle,
|
||||
hidden: true,
|
||||
data: {
|
||||
...baseEdgeData,
|
||||
sourceType: originalNode.data.type,
|
||||
targetType: targetNode.data.type,
|
||||
_hiddenInGroupId: groupNodeId,
|
||||
},
|
||||
zIndex,
|
||||
}
|
||||
|
||||
// Create the UI edge (from group to target) - temporary, not persisted to backend
|
||||
const uiEdge: Edge = {
|
||||
id: `${groupNodeId}-${handlerId}-${targetNodeId}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: groupNodeId,
|
||||
sourceHandle: handlerId,
|
||||
target: targetNodeId,
|
||||
targetHandle,
|
||||
data: {
|
||||
...baseEdgeData,
|
||||
sourceType: BlockEnum.Group,
|
||||
targetType: targetNode.data.type,
|
||||
_isTemp: true,
|
||||
},
|
||||
zIndex,
|
||||
}
|
||||
|
||||
return { realEdge, uiEdge }
|
||||
}
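// Illustrative trace (not part of the diff), with made-up ids: calling createGroupEdgePair with
// groupNodeId 'group-1', handlerId 'if-else-1-true', targetNodeId 'llm-z', targetHandle 'target'
// would, when both nodes exist, produce:
//   realEdge id 'if-else-1-true-llm-z-target'          (hidden, data._hiddenInGroupId: 'group-1')
//   uiEdge   id 'group-1-if-else-1-true-llm-z-target'  (visible, data._isTemp: true)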
|
||||
|
||||
function createGroupInboundEdges(params: {
|
||||
sourceNodeId: string
|
||||
sourceHandle: string
|
||||
groupNodeId: string
|
||||
groupData: GroupNodeData
|
||||
nodes: Node[]
|
||||
baseEdgeData?: Partial<Edge['data']>
|
||||
zIndex?: number
|
||||
}): { realEdges: Edge[], uiEdge: Edge } | null {
|
||||
const { sourceNodeId, sourceHandle, groupNodeId, groupData, nodes, baseEdgeData = {}, zIndex = 0 } = params
|
||||
|
||||
const sourceNode = nodes.find(node => node.id === sourceNodeId)
|
||||
const headNodeIds = groupData.headNodeIds || []
|
||||
|
||||
if (!sourceNode || headNodeIds.length === 0)
|
||||
return null
|
||||
|
||||
const realEdges: Edge[] = headNodeIds.map((headNodeId) => {
|
||||
const headNode = nodes.find(node => node.id === headNodeId)
|
||||
return {
|
||||
id: `${sourceNodeId}-${sourceHandle}-${headNodeId}-target`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: sourceNodeId,
|
||||
sourceHandle,
|
||||
target: headNodeId,
|
||||
targetHandle: 'target',
|
||||
hidden: true,
|
||||
data: {
|
||||
...baseEdgeData,
|
||||
sourceType: sourceNode.data.type,
|
||||
targetType: headNode?.data.type,
|
||||
_hiddenInGroupId: groupNodeId,
|
||||
},
|
||||
zIndex,
|
||||
} as Edge
|
||||
})
|
||||
|
||||
const uiEdge: Edge = {
|
||||
id: `${sourceNodeId}-${sourceHandle}-${groupNodeId}-target`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: sourceNodeId,
|
||||
sourceHandle,
|
||||
target: groupNodeId,
|
||||
targetHandle: 'target',
|
||||
data: {
|
||||
...baseEdgeData,
|
||||
sourceType: sourceNode.data.type,
|
||||
targetType: BlockEnum.Group,
|
||||
_isTemp: true,
|
||||
},
|
||||
zIndex,
|
||||
}
|
||||
|
||||
return { realEdges, uiEdge }
|
||||
}
|
||||
|
||||
export const useNodesInteractions = () => {
|
||||
const { t } = useTranslation()
|
||||
const store = useStoreApi()
|
||||
@@ -448,6 +595,146 @@ export const useNodesInteractions = () => {
|
||||
return
|
||||
}
|
||||
|
||||
// Check if source is a group node - need special handling
|
||||
const isSourceGroup = sourceNode?.data.type === BlockEnum.Group
|
||||
|
||||
if (isSourceGroup && sourceHandle && target && targetHandle) {
|
||||
const { originalNodeId, originalSourceHandle } = parseGroupHandlerId(sourceHandle)
|
||||
|
||||
// Check if real edge already exists
|
||||
if (edges.find(edge =>
|
||||
edge.source === originalNodeId
|
||||
&& edge.sourceHandle === originalSourceHandle
|
||||
&& edge.target === target
|
||||
&& edge.targetHandle === targetHandle,
|
||||
)) {
|
||||
return
|
||||
}
|
||||
|
||||
const parentNode = nodes.find(node => node.id === targetNode?.parentId)
|
||||
const isInIteration = parentNode && parentNode.data.type === BlockEnum.Iteration
|
||||
const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop
|
||||
|
||||
const edgePair = createGroupEdgePair({
|
||||
groupNodeId: source!,
|
||||
handlerId: sourceHandle,
|
||||
targetNodeId: target,
|
||||
targetHandle,
|
||||
nodes,
|
||||
baseEdgeData: {
|
||||
isInIteration,
|
||||
iteration_id: isInIteration ? targetNode?.parentId : undefined,
|
||||
isInLoop,
|
||||
loop_id: isInLoop ? targetNode?.parentId : undefined,
|
||||
},
|
||||
})
|
||||
|
||||
if (!edgePair)
|
||||
return
|
||||
|
||||
const { realEdge, uiEdge } = edgePair
|
||||
|
||||
// Update connected handle ids for the original node
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap
|
||||
= getNodesConnectedSourceOrTargetHandleIdsMap(
|
||||
[{ type: 'add', edge: realEdge }],
|
||||
nodes,
|
||||
)
|
||||
const newNodes = produce(nodes, (draft: Node[]) => {
|
||||
draft.forEach((node) => {
|
||||
if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
|
||||
node.data = {
|
||||
...node.data,
|
||||
...nodesConnectedSourceOrTargetHandleIdsMap[node.id],
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
draft.push(realEdge)
|
||||
draft.push(uiEdge)
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
setEdges(newEdges)
|
||||
|
||||
handleSyncWorkflowDraft()
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
|
||||
nodeId: targetNode?.id,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
const isTargetGroup = targetNode?.data.type === BlockEnum.Group
|
||||
|
||||
if (isTargetGroup && source && sourceHandle) {
|
||||
const groupData = targetNode.data as GroupNodeData
|
||||
const headNodeIds = groupData.headNodeIds || []
|
||||
|
||||
if (edges.find(edge =>
|
||||
edge.source === source
|
||||
&& edge.sourceHandle === sourceHandle
|
||||
&& edge.target === target
|
||||
&& edge.targetHandle === targetHandle,
|
||||
)) {
|
||||
return
|
||||
}
|
||||
|
||||
const parentNode = nodes.find(node => node.id === sourceNode?.parentId)
|
||||
const isInIteration = parentNode && parentNode.data.type === BlockEnum.Iteration
|
||||
const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop
|
||||
|
||||
const inboundResult = createGroupInboundEdges({
|
||||
sourceNodeId: source,
|
||||
sourceHandle,
|
||||
groupNodeId: target!,
|
||||
groupData,
|
||||
nodes,
|
||||
baseEdgeData: {
|
||||
isInIteration,
|
||||
iteration_id: isInIteration ? sourceNode?.parentId : undefined,
|
||||
isInLoop,
|
||||
loop_id: isInLoop ? sourceNode?.parentId : undefined,
|
||||
},
|
||||
})
|
||||
|
||||
if (!inboundResult)
|
||||
return
|
||||
|
||||
const { realEdges, uiEdge } = inboundResult
|
||||
|
||||
const edgeChanges = realEdges.map(edge => ({ type: 'add' as const, edge }))
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap
|
||||
= getNodesConnectedSourceOrTargetHandleIdsMap(edgeChanges, nodes)
|
||||
|
||||
const newNodes = produce(nodes, (draft: Node[]) => {
|
||||
draft.forEach((node) => {
|
||||
if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
|
||||
node.data = {
|
||||
...node.data,
|
||||
...nodesConnectedSourceOrTargetHandleIdsMap[node.id],
|
||||
}
|
||||
}
|
||||
})
|
||||
})
|
||||
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
realEdges.forEach((edge) => {
|
||||
draft.push(edge)
|
||||
})
|
||||
draft.push(uiEdge)
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
setEdges(newEdges)
|
||||
|
||||
handleSyncWorkflowDraft()
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
|
||||
nodeId: headNodeIds[0],
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
if (
|
||||
edges.find(
|
||||
edge =>
|
||||
@@ -909,8 +1196,34 @@ export const useNodesInteractions = () => {
|
||||
}
|
||||
}
|
||||
|
||||
let newEdge = null
|
||||
if (nodeType !== BlockEnum.DataSource) {
|
||||
// Check if prevNode is a group node - need special handling
|
||||
const isPrevNodeGroup = prevNode.data.type === BlockEnum.Group
|
||||
let newEdge: Edge | null = null
|
||||
let newUiEdge: Edge | null = null
|
||||
|
||||
if (isPrevNodeGroup && prevNodeSourceHandle && nodeType !== BlockEnum.DataSource) {
|
||||
const edgePair = createGroupEdgePair({
|
||||
groupNodeId: prevNodeId,
|
||||
handlerId: prevNodeSourceHandle,
|
||||
targetNodeId: newNode.id,
|
||||
targetHandle,
|
||||
nodes: [...nodes, newNode],
|
||||
baseEdgeData: {
|
||||
isInIteration,
|
||||
isInLoop,
|
||||
iteration_id: isInIteration ? prevNode.parentId : undefined,
|
||||
loop_id: isInLoop ? prevNode.parentId : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
})
|
||||
|
||||
if (edgePair) {
|
||||
newEdge = edgePair.realEdge
|
||||
newUiEdge = edgePair.uiEdge
|
||||
}
|
||||
}
|
||||
else if (nodeType !== BlockEnum.DataSource) {
|
||||
// Normal case: prevNode is not a group
|
||||
newEdge = {
|
||||
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
@@ -935,9 +1248,10 @@ export const useNodesInteractions = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const edgesToAdd = [newEdge, newUiEdge].filter(Boolean).map(edge => ({ type: 'add' as const, edge: edge! }))
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap
|
||||
= getNodesConnectedSourceOrTargetHandleIdsMap(
|
||||
(newEdge ? [{ type: 'add', edge: newEdge }] : []),
|
||||
edgesToAdd,
|
||||
nodes,
|
||||
)
|
||||
const newNodes = produce(nodes, (draft: Node[]) => {
|
||||
@@ -1006,6 +1320,8 @@ export const useNodesInteractions = () => {
|
||||
})
|
||||
if (newEdge)
|
||||
draft.push(newEdge)
|
||||
if (newUiEdge)
|
||||
draft.push(newUiEdge)
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
@@ -1090,7 +1406,7 @@ export const useNodesInteractions = () => {
|
||||
|
||||
const afterNodesInSameBranch = getAfterNodesInSameBranch(nextNodeId!)
|
||||
const afterNodesInSameBranchIds = afterNodesInSameBranch.map(
|
||||
node => node.id,
|
||||
(node: Node) => node.id,
|
||||
)
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
draft.forEach((node) => {
|
||||
@@ -1200,37 +1516,113 @@ export const useNodesInteractions = () => {
|
||||
}
|
||||
}
|
||||
|
||||
const currentEdgeIndex = edges.findIndex(
|
||||
edge => edge.source === prevNodeId && edge.target === nextNodeId,
|
||||
)
|
||||
let newPrevEdge = null
|
||||
// Check if prevNode is a group node - need special handling
|
||||
const isPrevNodeGroup = prevNode.data.type === BlockEnum.Group
|
||||
let newPrevEdge: Edge | null = null
|
||||
let newPrevUiEdge: Edge | null = null
|
||||
const edgesToRemove: string[] = []
|
||||
|
||||
if (nodeType !== BlockEnum.DataSource) {
|
||||
newPrevEdge = {
|
||||
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: prevNodeId,
|
||||
sourceHandle: prevNodeSourceHandle,
|
||||
target: newNode.id,
|
||||
if (isPrevNodeGroup && prevNodeSourceHandle && nodeType !== BlockEnum.DataSource) {
|
||||
const { originalNodeId, originalSourceHandle } = parseGroupHandlerId(prevNodeSourceHandle)
|
||||
|
||||
// Find edges to remove: both hidden real edge and UI temp edge from group to nextNode
|
||||
const hiddenEdge = edges.find(
|
||||
edge => edge.source === originalNodeId
|
||||
&& edge.sourceHandle === originalSourceHandle
|
||||
&& edge.target === nextNodeId,
|
||||
)
|
||||
const uiTempEdge = edges.find(
|
||||
edge => edge.source === prevNodeId
|
||||
&& edge.sourceHandle === prevNodeSourceHandle
|
||||
&& edge.target === nextNodeId,
|
||||
)
|
||||
if (hiddenEdge)
|
||||
edgesToRemove.push(hiddenEdge.id)
|
||||
if (uiTempEdge)
|
||||
edgesToRemove.push(uiTempEdge.id)
|
||||
|
||||
const edgePair = createGroupEdgePair({
|
||||
groupNodeId: prevNodeId,
|
||||
handlerId: prevNodeSourceHandle,
|
||||
targetNodeId: newNode.id,
|
||||
targetHandle,
|
||||
data: {
|
||||
sourceType: prevNode.data.type,
|
||||
targetType: newNode.data.type,
|
||||
nodes: [...nodes, newNode],
|
||||
baseEdgeData: {
|
||||
isInIteration,
|
||||
isInLoop,
|
||||
iteration_id: isInIteration ? prevNode.parentId : undefined,
|
||||
loop_id: isInLoop ? prevNode.parentId : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: prevNode.parentId
|
||||
? isInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
})
|
||||
|
||||
if (edgePair) {
|
||||
newPrevEdge = edgePair.realEdge
|
||||
newPrevUiEdge = edgePair.uiEdge
|
||||
}
|
||||
}
|
||||
else {
|
||||
const isNextNodeGroupForRemoval = nextNode.data.type === BlockEnum.Group
|
||||
|
||||
if (isNextNodeGroupForRemoval) {
|
||||
const groupData = nextNode.data as GroupNodeData
|
||||
const headNodeIds = groupData.headNodeIds || []
|
||||
|
||||
headNodeIds.forEach((headNodeId) => {
|
||||
const realEdge = edges.find(
|
||||
edge => edge.source === prevNodeId
|
||||
&& edge.sourceHandle === prevNodeSourceHandle
|
||||
&& edge.target === headNodeId,
|
||||
)
|
||||
if (realEdge)
|
||||
edgesToRemove.push(realEdge.id)
|
||||
})
|
||||
|
||||
const uiEdge = edges.find(
|
||||
edge => edge.source === prevNodeId
|
||||
&& edge.sourceHandle === prevNodeSourceHandle
|
||||
&& edge.target === nextNodeId,
|
||||
)
|
||||
if (uiEdge)
|
||||
edgesToRemove.push(uiEdge.id)
|
||||
}
|
||||
else {
|
||||
const currentEdge = edges.find(
|
||||
edge => edge.source === prevNodeId && edge.target === nextNodeId,
|
||||
)
|
||||
if (currentEdge)
|
||||
edgesToRemove.push(currentEdge.id)
|
||||
}
|
||||
|
||||
if (nodeType !== BlockEnum.DataSource) {
|
||||
newPrevEdge = {
|
||||
id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: prevNodeId,
|
||||
sourceHandle: prevNodeSourceHandle,
|
||||
target: newNode.id,
|
||||
targetHandle,
|
||||
data: {
|
||||
sourceType: prevNode.data.type,
|
||||
targetType: newNode.data.type,
|
||||
isInIteration,
|
||||
isInLoop,
|
||||
iteration_id: isInIteration ? prevNode.parentId : undefined,
|
||||
loop_id: isInLoop ? prevNode.parentId : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: prevNode.parentId
|
||||
? isInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let newNextEdge: Edge | null = null
|
||||
let newNextUiEdge: Edge | null = null
|
||||
const newNextRealEdges: Edge[] = []
|
||||
|
||||
const nextNodeParentNode
|
||||
= nodes.find(node => node.id === nextNode.parentId) || null
|
||||
@@ -1241,49 +1633,113 @@ export const useNodesInteractions = () => {
|
||||
= !!nextNodeParentNode
|
||||
&& nextNodeParentNode.data.type === BlockEnum.Loop
|
||||
|
||||
const isNextNodeGroup = nextNode.data.type === BlockEnum.Group
|
||||
|
||||
if (
|
||||
nodeType !== BlockEnum.IfElse
|
||||
&& nodeType !== BlockEnum.QuestionClassifier
|
||||
&& nodeType !== BlockEnum.LoopEnd
|
||||
) {
|
||||
newNextEdge = {
|
||||
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: newNode.id,
|
||||
sourceHandle,
|
||||
target: nextNodeId,
|
||||
targetHandle: nextNodeTargetHandle,
|
||||
data: {
|
||||
sourceType: newNode.data.type,
|
||||
targetType: nextNode.data.type,
|
||||
isInIteration: isNextNodeInIteration,
|
||||
isInLoop: isNextNodeInLoop,
|
||||
iteration_id: isNextNodeInIteration
|
||||
? nextNode.parentId
|
||||
: undefined,
|
||||
loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: nextNode.parentId
|
||||
? isNextNodeInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
if (isNextNodeGroup) {
|
||||
const groupData = nextNode.data as GroupNodeData
|
||||
const headNodeIds = groupData.headNodeIds || []
|
||||
|
||||
headNodeIds.forEach((headNodeId) => {
|
||||
const headNode = nodes.find(node => node.id === headNodeId)
|
||||
newNextRealEdges.push({
|
||||
id: `${newNode.id}-${sourceHandle}-${headNodeId}-target`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: newNode.id,
|
||||
sourceHandle,
|
||||
target: headNodeId,
|
||||
targetHandle: 'target',
|
||||
hidden: true,
|
||||
data: {
|
||||
sourceType: newNode.data.type,
|
||||
targetType: headNode?.data.type,
|
||||
isInIteration: isNextNodeInIteration,
|
||||
isInLoop: isNextNodeInLoop,
|
||||
iteration_id: isNextNodeInIteration ? nextNode.parentId : undefined,
|
||||
loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
|
||||
_hiddenInGroupId: nextNodeId,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: nextNode.parentId
|
||||
? isNextNodeInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
} as Edge)
|
||||
})
|
||||
|
||||
newNextUiEdge = {
|
||||
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-target`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: newNode.id,
|
||||
sourceHandle,
|
||||
target: nextNodeId,
|
||||
targetHandle: 'target',
|
||||
data: {
|
||||
sourceType: newNode.data.type,
|
||||
targetType: BlockEnum.Group,
|
||||
isInIteration: isNextNodeInIteration,
|
||||
isInLoop: isNextNodeInLoop,
|
||||
iteration_id: isNextNodeInIteration ? nextNode.parentId : undefined,
|
||||
loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
|
||||
_isTemp: true,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: nextNode.parentId
|
||||
? isNextNodeInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
}
|
||||
}
|
||||
else {
|
||||
newNextEdge = {
|
||||
id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
|
||||
type: CUSTOM_EDGE,
|
||||
source: newNode.id,
|
||||
sourceHandle,
|
||||
target: nextNodeId,
|
||||
targetHandle: nextNodeTargetHandle,
|
||||
data: {
|
||||
sourceType: newNode.data.type,
|
||||
targetType: nextNode.data.type,
|
||||
isInIteration: isNextNodeInIteration,
|
||||
isInLoop: isNextNodeInLoop,
|
||||
iteration_id: isNextNodeInIteration
|
||||
? nextNode.parentId
|
||||
: undefined,
|
||||
loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
|
||||
_connectedNodeIsSelected: true,
|
||||
},
|
||||
zIndex: nextNode.parentId
|
||||
? isNextNodeInIteration
|
||||
? ITERATION_CHILDREN_Z_INDEX
|
||||
: LOOP_CHILDREN_Z_INDEX
|
||||
: 0,
|
||||
}
|
||||
}
|
||||
}
|
||||
const edgeChanges = [
|
||||
...edgesToRemove.map(id => ({ type: 'remove' as const, edge: edges.find(e => e.id === id)! })).filter(c => c.edge),
|
||||
...(newPrevEdge ? [{ type: 'add' as const, edge: newPrevEdge }] : []),
|
||||
...(newPrevUiEdge ? [{ type: 'add' as const, edge: newPrevUiEdge }] : []),
|
||||
...(newNextEdge ? [{ type: 'add' as const, edge: newNextEdge }] : []),
|
||||
...newNextRealEdges.map(edge => ({ type: 'add' as const, edge })),
|
||||
...(newNextUiEdge ? [{ type: 'add' as const, edge: newNextUiEdge }] : []),
|
||||
]
|
||||
const nodesConnectedSourceOrTargetHandleIdsMap
|
||||
= getNodesConnectedSourceOrTargetHandleIdsMap(
|
||||
[
|
||||
{ type: 'remove', edge: edges[currentEdgeIndex] },
|
||||
...(newPrevEdge ? [{ type: 'add', edge: newPrevEdge }] : []),
|
||||
...(newNextEdge ? [{ type: 'add', edge: newNextEdge }] : []),
|
||||
],
|
||||
edgeChanges,
|
||||
[...nodes, newNode],
|
||||
)
|
||||
|
||||
const afterNodesInSameBranch = getAfterNodesInSameBranch(nextNodeId!)
|
||||
const afterNodesInSameBranchIds = afterNodesInSameBranch.map(
|
||||
node => node.id,
|
||||
(node: Node) => node.id,
|
||||
)
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
draft.forEach((node) => {
|
||||
@@ -1342,7 +1798,10 @@ export const useNodesInteractions = () => {
|
||||
})
|
||||
}
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
draft.splice(currentEdgeIndex, 1)
|
||||
const filteredDraft = draft.filter(edge => !edgesToRemove.includes(edge.id))
|
||||
draft.length = 0
|
||||
draft.push(...filteredDraft)
|
||||
|
||||
draft.forEach((item) => {
|
||||
item.data = {
|
||||
...item.data,
|
||||
@@ -1351,9 +1810,15 @@ export const useNodesInteractions = () => {
|
||||
})
|
||||
if (newPrevEdge)
|
||||
draft.push(newPrevEdge)
|
||||
|
||||
if (newPrevUiEdge)
|
||||
draft.push(newPrevUiEdge)
|
||||
if (newNextEdge)
|
||||
draft.push(newNextEdge)
|
||||
newNextRealEdges.forEach((edge) => {
|
||||
draft.push(edge)
|
||||
})
|
||||
if (newNextUiEdge)
|
||||
draft.push(newNextUiEdge)
|
||||
})
|
||||
setEdges(newEdges)
|
||||
}
|
||||
@@ -2087,6 +2552,302 @@ export const useNodesInteractions = () => {
|
||||
setEdges(newEdges)
|
||||
}, [store])
|
||||
|
||||
// Check if there are any nodes selected via box selection
|
||||
const hasBundledNodes = useCallback(() => {
|
||||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
return nodes.some(node => node.data._isBundled)
|
||||
}, [store])
|
||||
|
||||
const getCanMakeGroup = useCallback(() => {
|
||||
const { getNodes, edges } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const bundledNodes = nodes.filter(node => node.data._isBundled)
|
||||
|
||||
if (bundledNodes.length <= 1)
|
||||
return false
|
||||
|
||||
const bundledNodeIds = bundledNodes.map(node => node.id)
|
||||
const minimalEdges = edges.map(edge => ({
|
||||
id: edge.id,
|
||||
source: edge.source,
|
||||
sourceHandle: edge.sourceHandle || 'source',
|
||||
target: edge.target,
|
||||
}))
|
||||
const hasGroupNode = bundledNodes.some(node => node.data.type === BlockEnum.Group)
|
||||
|
||||
const { canMakeGroup } = checkMakeGroupAvailability(bundledNodeIds, minimalEdges, hasGroupNode)
|
||||
return canMakeGroup
|
||||
}, [store])
|
||||
|
||||
const handleMakeGroup = useCallback(() => {
|
||||
const { getNodes, setNodes, edges, setEdges } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const bundledNodes = nodes.filter(node => node.data._isBundled)
|
||||
|
||||
if (bundledNodes.length <= 1)
|
||||
return
|
||||
|
||||
const bundledNodeIds = bundledNodes.map(node => node.id)
|
||||
const minimalEdges = edges.map(edge => ({
|
||||
id: edge.id,
|
||||
source: edge.source,
|
||||
sourceHandle: edge.sourceHandle || 'source',
|
||||
target: edge.target,
|
||||
}))
|
||||
const hasGroupNode = bundledNodes.some(node => node.data.type === BlockEnum.Group)
|
||||
|
||||
const { canMakeGroup } = checkMakeGroupAvailability(bundledNodeIds, minimalEdges, hasGroupNode)
|
||||
if (!canMakeGroup)
|
||||
return
|
||||
|
||||
const bundledNodeIdSet = new Set(bundledNodeIds)
|
||||
const bundledNodeIdIsLeaf = new Set<string>()
|
||||
const inboundEdges = edges.filter(edge => !bundledNodeIdSet.has(edge.source) && bundledNodeIdSet.has(edge.target))
|
||||
const outboundEdges = edges.filter(edge => bundledNodeIdSet.has(edge.source) && !bundledNodeIdSet.has(edge.target))
|
||||
|
||||
// leaf node: no outbound edges to other nodes in the selection
|
||||
const handlers: GroupHandler[] = []
|
||||
const leafNodeIdSet = new Set<string>()
|
||||
|
||||
    bundledNodes.forEach((node: Node) => {
      const targetBranches = node.data._targetBranches || [{ id: 'source', name: node.data.title }]
      targetBranches.forEach((branch) => {
        // A branch should be a handler if it's either:
        // 1. Connected to a node OUTSIDE the group
        // 2. NOT connected to any node INSIDE the group
        const isConnectedInside = edges.some(edge =>
          edge.source === node.id
          && (edge.sourceHandle === branch.id || (!edge.sourceHandle && branch.id === 'source'))
          && bundledNodeIdSet.has(edge.target),
        )
        const isConnectedOutside = edges.some(edge =>
          edge.source === node.id
          && (edge.sourceHandle === branch.id || (!edge.sourceHandle && branch.id === 'source'))
          && !bundledNodeIdSet.has(edge.target),
        )

        if (isConnectedOutside || !isConnectedInside) {
          const handlerId = `${node.id}-${branch.id}`
          handlers.push({
            id: handlerId,
            label: branch.name || node.data.title || node.id,
            nodeId: node.id,
            sourceHandle: branch.id,
          })
          leafNodeIdSet.add(node.id)
        }
      })
    })
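The predicate inside this loop decides which branches surface as handlers on the group's boundary. A stand-alone sketch of the same rule, assuming only the { source, sourceHandle, target } edge shape used in this hook; the names below are illustrative, not part of the diff:

  type MinimalEdge = { source: string; sourceHandle?: string | null; target: string }

  const branchMatches = (edge: MinimalEdge, nodeId: string, branchId: string) =>
    edge.source === nodeId
    && (edge.sourceHandle === branchId || (!edge.sourceHandle && branchId === 'source'))

  // A branch is exposed as a group handler when it already leaves the selection,
  // or when it is not wired to anything inside the selection at all.
  const shouldExposeBranchAsHandler = (
    nodeId: string,
    branchId: string,
    edges: MinimalEdge[],
    memberIds: Set<string>,
  ) => {
    const connectedInside = edges.some(e => branchMatches(e, nodeId, branchId) && memberIds.has(e.target))
    const connectedOutside = edges.some(e => branchMatches(e, nodeId, branchId) && !memberIds.has(e.target))
    return connectedOutside || !connectedInside
  }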
|
||||
|
||||
const leafNodeIds = Array.from(leafNodeIdSet)
|
||||
leafNodeIds.forEach(id => bundledNodeIdIsLeaf.add(id))
|
||||
|
||||
const members: GroupMember[] = bundledNodes.map((node) => {
|
||||
return {
|
||||
id: node.id,
|
||||
type: node.data.type,
|
||||
label: node.data.title,
|
||||
}
|
||||
})
|
||||
|
||||
// head nodes: nodes that receive input from outside the group
|
||||
const headNodeIds = [...new Set(inboundEdges.map(edge => edge.target))]
|
||||
|
||||
// put the group node at the top-left corner of the selection, slightly offset
|
||||
const { x: minX, y: minY } = getTopLeftNodePosition(bundledNodes)
|
||||
|
||||
const groupNodeData: GroupNodeData = {
|
||||
title: t('operator.makeGroup', { ns: 'workflow' }),
|
||||
desc: '',
|
||||
type: BlockEnum.Group,
|
||||
members,
|
||||
handlers,
|
||||
headNodeIds,
|
||||
leafNodeIds,
|
||||
selected: true,
|
||||
_targetBranches: handlers.map(handler => ({
|
||||
id: handler.id,
|
||||
name: handler.label || handler.id,
|
||||
})),
|
||||
}
|
||||
|
||||
const { newNode: groupNode } = generateNewNode({
|
||||
data: groupNodeData,
|
||||
position: {
|
||||
x: minX - 20,
|
||||
y: minY - 20,
|
||||
},
|
||||
})
|
||||
|
||||
const nodeTypeMap = new Map(nodes.map(node => [node.id, node.data.type]))
|
||||
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
draft.forEach((node) => {
|
||||
if (bundledNodeIdSet.has(node.id)) {
|
||||
node.data._isBundled = false
|
||||
node.selected = false
|
||||
node.hidden = true
|
||||
node.data._hiddenInGroupId = groupNode.id
|
||||
}
|
||||
else {
|
||||
node.data._isBundled = false
|
||||
}
|
||||
})
|
||||
draft.push(groupNode)
|
||||
})
|
||||
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
draft.forEach((edge) => {
|
||||
if (bundledNodeIdSet.has(edge.source) || bundledNodeIdSet.has(edge.target)) {
|
||||
edge.hidden = true
|
||||
edge.data = {
|
||||
...edge.data,
|
||||
_hiddenInGroupId: groupNode.id,
|
||||
_isBundled: false,
|
||||
}
|
||||
}
|
||||
else if (edge.data?._isBundled) {
|
||||
edge.data._isBundled = false
|
||||
}
|
||||
})
|
||||
|
||||
// re-add the external inbound edges to the group node as UI-only edges (not persisted to backend)
|
||||
inboundEdges.forEach((edge) => {
|
||||
draft.push({
|
||||
id: `${edge.id}__to-${groupNode.id}`,
|
||||
type: edge.type || CUSTOM_EDGE,
|
||||
source: edge.source,
|
||||
target: groupNode.id,
|
||||
sourceHandle: edge.sourceHandle,
|
||||
targetHandle: 'target',
|
||||
data: {
|
||||
...edge.data,
|
||||
sourceType: nodeTypeMap.get(edge.source)!,
|
||||
targetType: BlockEnum.Group,
|
||||
_hiddenInGroupId: undefined,
|
||||
_isBundled: false,
|
||||
_isTemp: true, // UI-only edge, not persisted to backend
|
||||
},
|
||||
zIndex: edge.zIndex,
|
||||
})
|
||||
})
|
||||
|
||||
// outbound edges of the group node as UI-only edges (not persisted to backend)
|
||||
outboundEdges.forEach((edge) => {
|
||||
if (!bundledNodeIdIsLeaf.has(edge.source))
|
||||
return
|
||||
|
||||
// Use the same handler id format: nodeId-sourceHandle
|
||||
const originalSourceHandle = edge.sourceHandle || 'source'
|
||||
const handlerId = `${edge.source}-${originalSourceHandle}`
|
||||
|
||||
draft.push({
|
||||
id: `${groupNode.id}-${edge.target}-${edge.targetHandle || 'target'}-${handlerId}`,
|
||||
type: edge.type || CUSTOM_EDGE,
|
||||
source: groupNode.id,
|
||||
target: edge.target,
|
||||
sourceHandle: handlerId,
|
||||
targetHandle: edge.targetHandle,
|
||||
data: {
|
||||
...edge.data,
|
||||
sourceType: BlockEnum.Group,
|
||||
targetType: nodeTypeMap.get(edge.target)!,
|
||||
_hiddenInGroupId: undefined,
|
||||
_isBundled: false,
|
||||
_isTemp: true,
|
||||
},
|
||||
zIndex: edge.zIndex,
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
setEdges(newEdges)
|
||||
workflowStore.setState({
|
||||
selectionMenu: undefined,
|
||||
})
|
||||
handleSyncWorkflowDraft()
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodeAdd, {
|
||||
nodeId: groupNode.id,
|
||||
})
|
||||
}, [handleSyncWorkflowDraft, saveStateToHistory, store, t, workflowStore])
|
||||
|
||||
// check if the current selection can be ungrouped (single selected Group node)
|
||||
const getCanUngroup = useCallback(() => {
|
||||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const selectedNodes = nodes.filter(node => node.selected)
|
||||
|
||||
if (selectedNodes.length !== 1)
|
||||
return false
|
||||
|
||||
return selectedNodes[0].data.type === BlockEnum.Group
|
||||
}, [store])
|
||||
|
||||
// get the selected group node id for ungroup operation
|
||||
const getSelectedGroupId = useCallback(() => {
|
||||
const { getNodes } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const selectedNodes = nodes.filter(node => node.selected)
|
||||
|
||||
if (selectedNodes.length === 1 && selectedNodes[0].data.type === BlockEnum.Group)
|
||||
return selectedNodes[0].id
|
||||
|
||||
return undefined
|
||||
}, [store])
|
||||
|
||||
const handleUngroup = useCallback((groupId: string) => {
|
||||
const { getNodes, setNodes, edges, setEdges } = store.getState()
|
||||
const nodes = getNodes()
|
||||
const groupNode = nodes.find(n => n.id === groupId)
|
||||
|
||||
if (!groupNode || groupNode.data.type !== BlockEnum.Group)
|
||||
return
|
||||
|
||||
const memberIds = new Set((groupNode.data.members || []).map((m: { id: string }) => m.id))
|
||||
|
||||
// restore hidden member nodes
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
draft.forEach((node) => {
|
||||
if (memberIds.has(node.id)) {
|
||||
node.hidden = false
|
||||
delete node.data._hiddenInGroupId
|
||||
}
|
||||
})
|
||||
// remove group node
|
||||
const groupIndex = draft.findIndex(n => n.id === groupId)
|
||||
if (groupIndex !== -1)
|
||||
draft.splice(groupIndex, 1)
|
||||
})
|
||||
|
||||
// restore hidden edges and remove temp edges in single pass O(E)
|
||||
const newEdges = produce(edges, (draft) => {
|
||||
const indicesToRemove: number[] = []
|
||||
|
||||
for (let i = 0; i < draft.length; i++) {
|
||||
const edge = draft[i]
|
||||
// restore hidden edges that involve member nodes
|
||||
if (edge.hidden && (memberIds.has(edge.source) || memberIds.has(edge.target)))
|
||||
edge.hidden = false
|
||||
// collect temp edges connected to group for removal
|
||||
if (edge.data?._isTemp && (edge.source === groupId || edge.target === groupId))
|
||||
indicesToRemove.push(i)
|
||||
}
|
||||
|
||||
// remove collected indices in reverse order to avoid index shift
|
||||
for (let i = indicesToRemove.length - 1; i >= 0; i--)
|
||||
draft.splice(indicesToRemove[i], 1)
|
||||
})
|
||||
|
||||
setNodes(newNodes)
|
||||
setEdges(newEdges)
|
||||
handleSyncWorkflowDraft()
|
||||
saveStateToHistory(WorkflowHistoryEvent.NodeDelete, {
|
||||
nodeId: groupId,
|
||||
})
|
||||
}, [handleSyncWorkflowDraft, saveStateToHistory, store])
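A hypothetical consumer of the grouping API this hook now returns; the component, import path and markup below are assumptions for illustration, and only the function names come from this diff:

  import { useNodesInteractions } from './use-nodes-interactions' // assumed path

  const GroupToolbarSketch = () => {
    const {
      getCanMakeGroup,
      handleMakeGroup,
      getCanUngroup,
      getSelectedGroupId,
      handleUngroup,
    } = useNodesInteractions()

    return (
      <>
        <button disabled={!getCanMakeGroup()} onClick={handleMakeGroup}>Group</button>
        <button
          disabled={!getCanUngroup()}
          onClick={() => {
            const groupId = getSelectedGroupId()
            if (groupId)
              handleUngroup(groupId)
          }}
        >
          Ungroup
        </button>
      </>
    )
  }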
|
||||
|
||||
return {
|
||||
handleNodeDragStart,
|
||||
handleNodeDrag,
|
||||
@@ -2107,11 +2868,17 @@ export const useNodesInteractions = () => {
|
||||
handleNodesPaste,
|
||||
handleNodesDuplicate,
|
||||
handleNodesDelete,
|
||||
handleMakeGroup,
|
||||
handleUngroup,
|
||||
handleNodeResize,
|
||||
handleNodeDisconnect,
|
||||
handleHistoryBack,
|
||||
handleHistoryForward,
|
||||
dimOtherNodes,
|
||||
undimAllNodes,
|
||||
hasBundledNodes,
|
||||
getCanMakeGroup,
|
||||
getCanUngroup,
|
||||
getSelectedGroupId,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import type { AvailableNodesMetaData } from '@/app/components/workflow/hooks-store'
|
||||
import type { Node } from '@/app/components/workflow/types'
|
||||
import { useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { CollectionType } from '@/app/components/tools/types'
|
||||
import { useHooksStore } from '@/app/components/workflow/hooks-store'
|
||||
import GroupDefault from '@/app/components/workflow/nodes/group/default'
|
||||
import { useStore } from '@/app/components/workflow/store'
|
||||
import { BlockEnum } from '@/app/components/workflow/types'
|
||||
import { useGetLanguage } from '@/context/i18n'
|
||||
@@ -25,6 +27,7 @@ export const useNodesMetaData = () => {
|
||||
}
|
||||
|
||||
export const useNodeMetaData = (node: Node) => {
|
||||
const { t } = useTranslation()
|
||||
const language = useGetLanguage()
|
||||
const { data: buildInTools } = useAllBuiltInTools()
|
||||
const { data: customTools } = useAllCustomTools()
|
||||
@@ -34,6 +37,9 @@ export const useNodeMetaData = (node: Node) => {
|
||||
const { data } = node
|
||||
const nodeMetaData = availableNodesMetaData.nodesMap?.[data.type]
|
||||
const author = useMemo(() => {
|
||||
if (data.type === BlockEnum.Group)
|
||||
return GroupDefault.metaData.author
|
||||
|
||||
if (data.type === BlockEnum.DataSource)
|
||||
return dataSourceList?.find(dataSource => dataSource.plugin_id === data.plugin_id)?.author
|
||||
|
||||
@@ -48,6 +54,9 @@ export const useNodeMetaData = (node: Node) => {
|
||||
}, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList])
|
||||
|
||||
const description = useMemo(() => {
|
||||
if (data.type === BlockEnum.Group)
|
||||
return t('blocksAbout.group', { ns: 'workflow' })
|
||||
|
||||
if (data.type === BlockEnum.DataSource)
|
||||
return dataSourceList?.find(dataSource => dataSource.plugin_id === data.plugin_id)?.description[language]
|
||||
if (data.type === BlockEnum.Tool) {
|
||||
@@ -58,7 +67,7 @@ export const useNodeMetaData = (node: Node) => {
|
||||
return customTools?.find(toolWithProvider => toolWithProvider.id === data.provider_id)?.description[language]
|
||||
}
|
||||
return nodeMetaData?.metaData.description
|
||||
}, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList, language])
|
||||
}, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList, language, t])
|
||||
|
||||
return useMemo(() => {
|
||||
return {
|
||||
|
||||
@@ -17,7 +17,7 @@ import {
|
||||
} from '../utils'
|
||||
import { useWorkflowHistoryStore } from '../workflow-history-store'
|
||||
|
||||
export const useShortcuts = (): void => {
|
||||
export const useShortcuts = (enabled = true): void => {
|
||||
const {
|
||||
handleNodesCopy,
|
||||
handleNodesPaste,
|
||||
@@ -27,6 +27,12 @@ export const useShortcuts = (): void => {
|
||||
handleHistoryForward,
|
||||
dimOtherNodes,
|
||||
undimAllNodes,
|
||||
hasBundledNodes,
|
||||
getCanMakeGroup,
|
||||
handleMakeGroup,
|
||||
getCanUngroup,
|
||||
getSelectedGroupId,
|
||||
handleUngroup,
|
||||
} = useNodesInteractions()
|
||||
const { shortcutsEnabled: workflowHistoryShortcutsEnabled } = useWorkflowHistoryStore()
|
||||
const { handleSyncWorkflowDraft } = useNodesSyncDraft()
|
||||
@@ -60,13 +66,17 @@ export const useShortcuts = (): void => {
|
||||
}
|
||||
|
||||
const shouldHandleShortcut = useCallback((e: KeyboardEvent) => {
|
||||
if (!enabled)
|
||||
return false
|
||||
return !isEventTargetInputArea(e.target as HTMLElement)
|
||||
}, [])
|
||||
}, [enabled])
|
||||
|
||||
const shouldHandleCopy = useCallback(() => {
|
||||
if (!enabled)
|
||||
return false
|
||||
const selection = document.getSelection()
|
||||
return !selection || selection.isCollapsed
|
||||
}, [])
|
||||
}, [enabled])
|
||||
|
||||
useKeyPress(['delete', 'backspace'], (e) => {
|
||||
if (shouldHandleShortcut(e)) {
|
||||
@@ -78,7 +88,8 @@ export const useShortcuts = (): void => {
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => {
|
||||
const { showDebugAndPreviewPanel } = workflowStore.getState()
|
||||
if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel) {
|
||||
// Only intercept when nodes are selected via box selection
|
||||
if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel && hasBundledNodes()) {
|
||||
e.preventDefault()
|
||||
handleNodesCopy()
|
||||
}
|
||||
@@ -99,6 +110,26 @@ export const useShortcuts = (): void => {
|
||||
}
|
||||
}, { exactMatch: true, useCapture: true })
|
||||
|
||||
  useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.g`, (e) => {
    // Only intercept when the selection can be grouped
    if (shouldHandleShortcut(e) && getCanMakeGroup()) {
      e.preventDefault()
      // Close selection context menu if open
      workflowStore.setState({ selectionMenu: undefined })
      handleMakeGroup()
    }
  }, { exactMatch: true, useCapture: true })

  useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.g`, (e) => {
    // Only intercept when the selection can be ungrouped
    if (shouldHandleShortcut(e) && getCanUngroup()) {
      e.preventDefault()
      const groupId = getSelectedGroupId()
      if (groupId)
        handleUngroup(groupId)
    }
  }, { exactMatch: true, useCapture: true })
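These two registrations bind the new grouping shortcuts. getKeyboardKeyCodeBySystem is a utility in this repo; the platform mapping below is an assumption used only to illustrate what the bound key combinations resolve to:

  // Assumed behaviour of getKeyboardKeyCodeBySystem('ctrl'): 'meta' on macOS, 'ctrl' elsewhere.
  const modKey = /Mac/.test(typeof navigator === 'undefined' ? '' : navigator.userAgent) ? 'meta' : 'ctrl'

  const groupShortcut = `${modKey}.g`         // Ctrl/Cmd + G         -> handleMakeGroup()
  const ungroupShortcut = `${modKey}.shift.g` // Ctrl/Cmd + Shift + G -> handleUngroup(groupId)

With exactMatch set to true the plain Ctrl/Cmd+G binding does not also fire while Shift is held, and useCapture registers the listener in the capture phase so it runs before handlers attached lower in the tree.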
|
||||
|
||||
useKeyPress(`${getKeyboardKeyCodeBySystem('alt')}.r`, (e) => {
|
||||
if (shouldHandleShortcut(e)) {
|
||||
e.preventDefault()
|
||||
@@ -255,6 +286,8 @@ export const useShortcuts = (): void => {
|
||||
|
||||
// Listen for zen toggle event from /zen command
|
||||
useEffect(() => {
|
||||
if (!enabled)
|
||||
return
|
||||
const handleZenToggle = () => {
|
||||
handleToggleMaximizeCanvas()
|
||||
}
|
||||
@@ -263,5 +296,5 @@ export const useShortcuts = (): void => {
|
||||
return () => {
|
||||
window.removeEventListener(ZEN_TOGGLE_EVENT, handleZenToggle)
|
||||
}
|
||||
}, [handleToggleMaximizeCanvas])
|
||||
}, [enabled, handleToggleMaximizeCanvas])
|
||||
}
|
||||
|
||||
@@ -37,7 +37,10 @@ export const useWorkflowNodeFinished = () => {
|
||||
}))
|
||||
|
||||
const newNodes = produce(nodes, (draft) => {
|
||||
const currentNode = draft.find(node => node.id === data.node_id)!
|
||||
const currentNode = draft.find(node => node.id === data.node_id)
|
||||
// Skip if node not found (e.g., virtual extraction nodes)
|
||||
if (!currentNode)
|
||||
return
|
||||
currentNode.data._runningStatus = data.status
|
||||
if (data.status === NodeRunningStatus.Exception) {
|
||||
if (data.execution_metadata?.error_strategy === ErrorHandleTypeEnum.failBranch)
|
||||
|
||||
@@ -45,6 +45,11 @@ export const useWorkflowNodeStarted = () => {
|
||||
} = reactflow
|
||||
const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
|
||||
const currentNode = nodes[currentNodeIndex]
|
||||
|
||||
// Skip if node not found (e.g., virtual extraction nodes)
|
||||
if (!currentNode)
|
||||
return
|
||||
|
||||
const position = currentNode.position
|
||||
const zoom = transform[2]
|
||||
|
||||
|
||||
@@ -1,10 +1,10 @@
|
||||
import type {
|
||||
Connection,
|
||||
} from 'reactflow'
|
||||
import type { GroupNodeData } from '../nodes/group/types'
|
||||
import type { IterationNodeType } from '../nodes/iteration/types'
|
||||
import type { LoopNodeType } from '../nodes/loop/types'
|
||||
import type {
|
||||
BlockEnum,
|
||||
Edge,
|
||||
Node,
|
||||
ValueSelector,
|
||||
@@ -28,14 +28,12 @@ import {
|
||||
} from '../constants'
|
||||
import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils'
|
||||
import { CUSTOM_NOTE_NODE } from '../note-node/constants'
|
||||
|
||||
import {
|
||||
useStore,
|
||||
useWorkflowStore,
|
||||
} from '../store'
|
||||
import {
|
||||
WorkflowRunningStatus,
|
||||
} from '../types'
|
||||
|
||||
import { BlockEnum, WorkflowRunningStatus } from '../types'
|
||||
import {
|
||||
getWorkflowEntryNode,
|
||||
isWorkflowEntryNode,
|
||||
@@ -381,7 +379,7 @@ export const useWorkflow = () => {
|
||||
return startNodes
|
||||
}, [nodesMap, getRootNodesById])
|
||||
|
||||
const isValidConnection = useCallback(({ source, sourceHandle: _sourceHandle, target }: Connection) => {
|
||||
const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => {
|
||||
const {
|
||||
edges,
|
||||
getNodes,
|
||||
@@ -396,15 +394,42 @@ export const useWorkflow = () => {
|
||||
if (sourceNode.parentId !== targetNode.parentId)
|
||||
return false
|
||||
|
||||
// For Group nodes, use the leaf node's type for validation
|
||||
// sourceHandle format: "${leafNodeId}-${originalSourceHandle}"
|
||||
let actualSourceType = sourceNode.data.type
|
||||
if (sourceNode.data.type === BlockEnum.Group && sourceHandle) {
|
||||
const lastDashIndex = sourceHandle.lastIndexOf('-')
|
||||
if (lastDashIndex > 0) {
|
||||
const leafNodeId = sourceHandle.substring(0, lastDashIndex)
|
||||
const leafNode = nodes.find(node => node.id === leafNodeId)
|
||||
if (leafNode)
|
||||
actualSourceType = leafNode.data.type
|
||||
}
|
||||
}
|
||||
|
||||
if (sourceNode && targetNode) {
|
||||
const sourceNodeAvailableNextNodes = getAvailableBlocks(sourceNode.data.type, !!sourceNode.parentId).availableNextBlocks
|
||||
const sourceNodeAvailableNextNodes = getAvailableBlocks(actualSourceType, !!sourceNode.parentId).availableNextBlocks
|
||||
const targetNodeAvailablePrevNodes = getAvailableBlocks(targetNode.data.type, !!targetNode.parentId).availablePrevBlocks
|
||||
|
||||
if (!sourceNodeAvailableNextNodes.includes(targetNode.data.type))
|
||||
return false
|
||||
if (targetNode.data.type === BlockEnum.Group) {
|
||||
const groupData = targetNode.data as GroupNodeData
|
||||
const headNodeIds = groupData.headNodeIds || []
|
||||
if (headNodeIds.length > 0) {
|
||||
const headNode = nodes.find(node => node.id === headNodeIds[0])
|
||||
if (headNode) {
|
||||
const headNodeAvailablePrevNodes = getAvailableBlocks(headNode.data.type, !!targetNode.parentId).availablePrevBlocks
|
||||
if (!headNodeAvailablePrevNodes.includes(actualSourceType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
}
|
||||
else {
|
||||
if (!sourceNodeAvailableNextNodes.includes(targetNode.data.type))
|
||||
return false
|
||||
|
||||
if (!targetNodeAvailablePrevNodes.includes(sourceNode.data.type))
|
||||
return false
|
||||
if (!targetNodeAvailablePrevNodes.includes(actualSourceType))
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
const hasCycle = (node: Node, visited = new Set()) => {
|
||||
@@ -473,13 +498,9 @@ export const useNodesReadOnly = () => {
|
||||
const isRestoring = useStore(s => s.isRestoring)
|
||||
|
||||
const getNodesReadOnly = useCallback((): boolean => {
|
||||
const {
|
||||
workflowRunningData,
|
||||
historyWorkflowData,
|
||||
isRestoring,
|
||||
} = workflowStore.getState()
|
||||
const state = workflowStore.getState()
|
||||
|
||||
return !!(workflowRunningData?.result.status === WorkflowRunningStatus.Running || historyWorkflowData || isRestoring)
|
||||
return !!(state.workflowRunningData?.result.status === WorkflowRunningStatus.Running || state.historyWorkflowData || state.isRestoring)
|
||||
}, [workflowStore])
|
||||
|
||||
return {
|
||||
@@ -525,6 +546,7 @@ export const useIsNodeInLoop = (loopId: string) => {
|
||||
return false
|
||||
|
||||
if (node.parentId === loopId)
|
||||
|
||||
return true
|
||||
|
||||
return false
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import type { FC } from 'react'
|
||||
import type {
|
||||
NodeMouseHandler,
|
||||
Viewport,
|
||||
} from 'reactflow'
|
||||
import type { Shape as HooksStoreShape } from './hooks-store'
|
||||
@@ -54,6 +55,14 @@ import {
|
||||
} from './constants'
|
||||
import CustomConnectionLine from './custom-connection-line'
|
||||
import CustomEdge from './custom-edge'
|
||||
import {
|
||||
CUSTOM_GROUP_EXIT_PORT_NODE,
|
||||
CUSTOM_GROUP_INPUT_NODE,
|
||||
CUSTOM_GROUP_NODE,
|
||||
CustomGroupExitPortNode,
|
||||
CustomGroupInputNode,
|
||||
CustomGroupNode,
|
||||
} from './custom-group-node'
|
||||
import DatasetsDetailProvider from './datasets-detail-store/provider'
|
||||
import HelpLine from './help-line'
|
||||
import {
|
||||
@@ -94,6 +103,7 @@ import {
|
||||
} from './store'
|
||||
import SyncingDataModal from './syncing-data-modal'
|
||||
import {
|
||||
BlockEnum,
|
||||
ControlMode,
|
||||
} from './types'
|
||||
import { setupScrollToNodeListener } from './utils/node-navigation'
|
||||
@@ -112,6 +122,9 @@ const nodeTypes = {
|
||||
[CUSTOM_ITERATION_START_NODE]: CustomIterationStartNode,
|
||||
[CUSTOM_LOOP_START_NODE]: CustomLoopStartNode,
|
||||
[CUSTOM_DATA_SOURCE_EMPTY_NODE]: CustomDataSourceEmptyNode,
|
||||
[CUSTOM_GROUP_NODE]: CustomGroupNode,
|
||||
[CUSTOM_GROUP_INPUT_NODE]: CustomGroupInputNode,
|
||||
[CUSTOM_GROUP_EXIT_PORT_NODE]: CustomGroupExitPortNode,
|
||||
}
|
||||
const edgeTypes = {
|
||||
[CUSTOM_EDGE]: CustomEdge,
|
||||
@@ -123,6 +136,9 @@ export type WorkflowProps = {
  viewport?: Viewport
  children?: React.ReactNode
  onWorkflowDataUpdate?: (v: any) => void
  allowSelectionWhenReadOnly?: boolean
  canvasReadOnly?: boolean
  interactionMode?: 'default' | 'subgraph'
}
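A hypothetical embedding that exercises the new props; the wrapper component and its data plumbing are assumptions, only the prop names and the 'subgraph' mode come from this diff:

  const SubGraphPreviewSketch = ({ nodes, edges }: { nodes: Node[]; edges: Edge[] }) => (
    <Workflow
      nodes={nodes}
      edges={edges}
      canvasReadOnly
      allowSelectionWhenReadOnly
      interactionMode="subgraph"
    />
  )

In 'subgraph' mode the component disables its own shortcuts, context menus, connect handlers, selection and zoom, and only LLM nodes remain clickable, matching the isSubGraph gating applied further below.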
|
||||
export const Workflow: FC<WorkflowProps> = memo(({
|
||||
nodes: originalNodes,
|
||||
@@ -130,6 +146,9 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
viewport,
|
||||
children,
|
||||
onWorkflowDataUpdate,
|
||||
allowSelectionWhenReadOnly = false,
|
||||
canvasReadOnly = false,
|
||||
interactionMode = 'default',
|
||||
}) => {
|
||||
const workflowContainerRef = useRef<HTMLDivElement>(null)
|
||||
const workflowStore = useWorkflowStore()
|
||||
@@ -182,9 +201,10 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
id: node.id,
|
||||
data: node.data,
|
||||
}))
|
||||
if (!isEqual(oldData, nodesData))
|
||||
if (!isEqual(oldData, nodesData)) {
|
||||
setNodesInStore(nodes)
|
||||
}, [setNodesInStore, workflowStore])
|
||||
}
|
||||
}, [setNodesInStore])
|
||||
useEffect(() => {
|
||||
setNodesOnlyChangeWithData(currentNodes as Node[])
|
||||
}, [currentNodes, setNodesOnlyChangeWithData])
|
||||
@@ -316,7 +336,8 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
},
|
||||
})
|
||||
|
||||
useShortcuts()
|
||||
const isSubGraph = interactionMode === 'subgraph'
|
||||
useShortcuts(!isSubGraph)
|
||||
// Initialize workflow node search functionality
|
||||
useWorkflowSearch()
|
||||
|
||||
@@ -370,6 +391,16 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
}
|
||||
}
|
||||
|
||||
const handleNodeClickInMode = useCallback<NodeMouseHandler>(
|
||||
(event, node) => {
|
||||
if (isSubGraph && node.data.type !== BlockEnum.LLM)
|
||||
return
|
||||
|
||||
handleNodeClick(event, node)
|
||||
},
|
||||
[handleNodeClick, isSubGraph],
|
||||
)
|
||||
|
||||
return (
|
||||
<div
|
||||
id="workflow-container"
|
||||
@@ -381,18 +412,18 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
ref={workflowContainerRef}
|
||||
>
|
||||
<SyncingDataModal />
|
||||
<CandidateNode />
|
||||
{!isSubGraph && <CandidateNode />}
|
||||
<div
|
||||
className="pointer-events-none absolute left-0 top-0 z-10 flex w-12 items-center justify-center p-1 pl-2"
|
||||
style={{ height: controlHeight }}
|
||||
>
|
||||
<Control />
|
||||
{!isSubGraph && <Control />}
|
||||
</div>
|
||||
<Operator handleRedo={handleHistoryForward} handleUndo={handleHistoryBack} />
|
||||
<PanelContextmenu />
|
||||
<NodeContextmenu />
|
||||
<SelectionContextmenu />
|
||||
<HelpLine />
|
||||
{!isSubGraph && <PanelContextmenu />}
|
||||
{!isSubGraph && <NodeContextmenu />}
|
||||
{!isSubGraph && <SelectionContextmenu />}
|
||||
{!isSubGraph && <HelpLine />}
|
||||
{
|
||||
!!showConfirm && (
|
||||
<Confirm
|
||||
@@ -415,38 +446,38 @@ export const Workflow: FC<WorkflowProps> = memo(({
|
||||
onNodeDragStop={handleNodeDragStop}
|
||||
onNodeMouseEnter={handleNodeEnter}
|
||||
onNodeMouseLeave={handleNodeLeave}
|
||||
onNodeClick={handleNodeClick}
|
||||
onNodeContextMenu={handleNodeContextMenu}
|
||||
onConnect={handleNodeConnect}
|
||||
onConnectStart={handleNodeConnectStart}
|
||||
onConnectEnd={handleNodeConnectEnd}
|
||||
onNodeClick={handleNodeClickInMode}
|
||||
onNodeContextMenu={isSubGraph ? undefined : handleNodeContextMenu}
|
||||
onConnect={isSubGraph ? undefined : handleNodeConnect}
|
||||
onConnectStart={isSubGraph ? undefined : handleNodeConnectStart}
|
||||
onConnectEnd={isSubGraph ? undefined : handleNodeConnectEnd}
|
||||
onEdgeMouseEnter={handleEdgeEnter}
|
||||
onEdgeMouseLeave={handleEdgeLeave}
|
||||
onEdgesChange={handleEdgesChange}
|
||||
onSelectionStart={handleSelectionStart}
|
||||
onSelectionChange={handleSelectionChange}
|
||||
onSelectionDrag={handleSelectionDrag}
|
||||
onPaneContextMenu={handlePaneContextMenu}
|
||||
onSelectionContextMenu={handleSelectionContextMenu}
|
||||
onSelectionStart={isSubGraph ? undefined : handleSelectionStart}
|
||||
onSelectionChange={isSubGraph ? undefined : handleSelectionChange}
|
||||
onSelectionDrag={isSubGraph ? undefined : handleSelectionDrag}
|
||||
onPaneContextMenu={isSubGraph ? undefined : handlePaneContextMenu}
|
||||
onSelectionContextMenu={isSubGraph ? undefined : handleSelectionContextMenu}
|
||||
connectionLineComponent={CustomConnectionLine}
|
||||
// NOTE: For LOOP node, how to distinguish between ITERATION and LOOP here? Maybe both are the same?
|
||||
connectionLineContainerStyle={{ zIndex: ITERATION_CHILDREN_Z_INDEX }}
|
||||
defaultViewport={viewport}
|
||||
multiSelectionKeyCode={null}
|
||||
deleteKeyCode={null}
|
||||
nodesDraggable={!nodesReadOnly}
|
||||
nodesConnectable={!nodesReadOnly}
|
||||
nodesFocusable={!nodesReadOnly}
|
||||
edgesFocusable={!nodesReadOnly}
|
||||
panOnScroll={controlMode === ControlMode.Pointer && !workflowReadOnly}
|
||||
panOnDrag={controlMode === ControlMode.Hand || [1]}
|
||||
zoomOnPinch={true}
|
||||
zoomOnScroll={true}
|
||||
zoomOnDoubleClick={true}
|
||||
nodesDraggable={!(nodesReadOnly || canvasReadOnly || isSubGraph)}
|
||||
nodesConnectable={!(nodesReadOnly || canvasReadOnly || isSubGraph)}
|
||||
nodesFocusable={allowSelectionWhenReadOnly ? true : !nodesReadOnly}
|
||||
edgesFocusable={isSubGraph ? false : (allowSelectionWhenReadOnly ? true : !nodesReadOnly)}
|
||||
panOnScroll={!isSubGraph && controlMode === ControlMode.Pointer && !workflowReadOnly}
|
||||
panOnDrag={!isSubGraph && (controlMode === ControlMode.Hand || [1])}
|
||||
selectionOnDrag={!isSubGraph && controlMode === ControlMode.Pointer && !workflowReadOnly && !canvasReadOnly}
|
||||
zoomOnPinch={!isSubGraph}
|
||||
zoomOnScroll={!isSubGraph}
|
||||
zoomOnDoubleClick={!isSubGraph}
|
||||
isValidConnection={isValidConnection}
|
||||
selectionKeyCode={null}
|
||||
selectionMode={SelectionMode.Partial}
|
||||
selectionOnDrag={controlMode === ControlMode.Pointer && !workflowReadOnly}
|
||||
minZoom={0.25}
|
||||
>
|
||||
<Background
Some files were not shown because too many files have changed in this diff.