mirror of https://github.com/langgenius/dify.git
synced 2026-02-14 04:34:02 +00:00

Compare commits (79 commits): refactor/r... → zhsama/age...
| SHA1 |
|---|
| 0f3156dfbe |
| b21875eaaf |
| 691554ad1c |
| f247ebfbe1 |
| d641c845dd |
| 2e10d67610 |
| e89d4e14ea |
| 5525f63032 |
| 8ee643e88d |
| ccb337e8eb |
| 1ff677c300 |
| 04145b19a1 |
| 56e537786f |
| 810f9eaaad |
| 4828348532 |
| c8c048c3a3 |
| 495d575ebc |
| b9052bc244 |
| b7025ad9d6 |
| c5482c2503 |
| d394adfaf7 |
| bc771d9c50 |
| 96ec176b83 |
| f57d2ef31f |
| e80bc78780 |
| ddbbddbd14 |
| 9b961fb41e |
| 4f79d09d7b |
| dbed937fc6 |
| 969c96b070 |
| 03e0c4c617 |
| 47790b49d4 |
| b25b069917 |
| bb190f9610 |
| d65ae68668 |
| f625350439 |
| f4e8f64bf7 |
| d91087492d |
| cab7cd37b8 |
| f925266c1b |
| 6e2cf23a73 |
| 8b0bc6937d |
| 872fd98eda |
| 5bcd3b6fe6 |
| 1aed585a19 |
| 831eba8b1c |
| 8b8e521c4e |
| 88248ad2d3 |
| 760a739e91 |
| d92c476388 |
| 9012dced6a |
| 50bed78d7a |
| 60250355cb |
| 75afc2dc0e |
| 225b13da93 |
| 37c748192d |
| b7a2957340 |
| a6ce6a249b |
| 8834e6e531 |
| 39010fd153 |
| bd338a9043 |
| 39d6383474 |
| add8980790 |
| 5157e1a96c |
| 4bb76acc37 |
| b513933040 |
| 18ea9d3f18 |
| 7b660a9ebc |
| 783a49bd97 |
| d3c6b09354 |
| 3d61496d25 |
| 16bff9e82f |
| 22f25731e8 |
| 035f51ad58 |
| e9795bd772 |
| 93b516a4ec |
| fc9d5b2a62 |
| e3bfb95c52 |
| 752cb9e4f4 |
.gitignore (vendored, 1 change)
@@ -209,6 +209,7 @@ api/.vscode
.history

.idea/
web/migration/

# pnpm
/.pnpm-store
@@ -55,6 +55,35 @@ class InstructionTemplatePayload(BaseModel):
    type: str = Field(..., description="Instruction template type")


class ContextGeneratePayload(BaseModel):
    """Payload for generating extractor code node."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name to generate code for")
    language: str = Field(default="python3", description="Code language (python3/javascript)")
    prompt_messages: list[dict[str, Any]] = Field(
        ..., description="Multi-turn conversation history, last message is the current instruction"
    )
    model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")


class SuggestedQuestionsPayload(BaseModel):
    """Payload for generating suggested questions."""

    workflow_id: str = Field(..., description="Workflow ID")
    node_id: str = Field(..., description="Current tool/llm node ID")
    parameter_name: str = Field(..., description="Parameter name")
    language: str = Field(
        default="English", description="Language for generated questions (e.g. English, Chinese, Japanese)"
    )
    model_config_data: dict[str, Any] | None = Field(
        default=None,
        alias="model_config",
        description="Model configuration (optional, uses system default if not provided)",
    )


def reg(cls: type[BaseModel]):
    console_ns.schema_model(cls.__name__, cls.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))

@@ -64,6 +93,8 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ContextGeneratePayload)
reg(SuggestedQuestionsPayload)


@console_ns.route("/rule-generate")
@@ -278,3 +309,74 @@ class InstructionGenerationTemplateApi(Resource):
                return {"data": INSTRUCTION_GENERATE_TEMPLATE_CODE}
            case _:
                raise ValueError(f"Invalid type: {args.type}")


@console_ns.route("/context-generate")
class ContextGenerateApi(Resource):
    @console_ns.doc("generate_with_context")
    @console_ns.doc(description="Generate with multi-turn conversation context")
    @console_ns.expect(console_ns.models[ContextGeneratePayload.__name__])
    @console_ns.response(200, "Content generated successfully")
    @console_ns.response(400, "Invalid request parameters or workflow not found")
    @console_ns.response(402, "Provider quota exceeded")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        from core.llm_generator.utils import deserialize_prompt_messages

        args = ContextGeneratePayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        prompt_messages = deserialize_prompt_messages(args.prompt_messages)

        try:
            return LLMGenerator.generate_with_context(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                prompt_messages=prompt_messages,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)


@console_ns.route("/context-generate/suggested-questions")
class SuggestedQuestionsApi(Resource):
    @console_ns.doc("generate_suggested_questions")
    @console_ns.doc(description="Generate suggested questions for context generation")
    @console_ns.expect(console_ns.models[SuggestedQuestionsPayload.__name__])
    @console_ns.response(200, "Questions generated successfully")
    @setup_required
    @login_required
    @account_initialization_required
    def post(self):
        args = SuggestedQuestionsPayload.model_validate(console_ns.payload)
        _, current_tenant_id = current_account_with_tenant()

        try:
            return LLMGenerator.generate_suggested_questions(
                tenant_id=current_tenant_id,
                workflow_id=args.workflow_id,
                node_id=args.node_id,
                parameter_name=args.parameter_name,
                language=args.language,
                model_config=args.model_config_data,
            )
        except ProviderTokenNotInitError as ex:
            raise ProviderNotInitializeError(ex.description)
        except QuotaExceededError:
            raise ProviderQuotaExceededError()
        except ModelCurrentlyNotSupportError:
            raise ProviderModelCurrentlyNotSupportError()
        except InvokeError as e:
            raise CompletionRequestError(e.description)
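For orientation, a request body matching `ContextGeneratePayload` might look like the following sketch; every ID and model name below is a placeholder, not a value from this changeset:

```python
# Illustrative request body for the /context-generate endpoint, mirroring
# ContextGeneratePayload; the JSON key "model_config" is validated into the
# model_config_data field via its alias.
payload = {
    "workflow_id": "app-uuid",        # workflow (app) ID, placeholder
    "node_id": "tool-node-1",         # current tool/llm node ID, placeholder
    "parameter_name": "query",        # parameter to generate extractor code for
    "language": "python3",            # or "javascript"
    "prompt_messages": [
        {"role": "user", "content": "Join the titles from the knowledge node with commas."},
    ],
    "model_config": {
        "provider": "openai",         # placeholder provider
        "name": "gpt-4o",             # placeholder model name
        "completion_params": {},
    },
}
```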
@@ -82,7 +82,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk

    @classmethod
@@ -110,7 +110,7 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
-               sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+               sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -120,6 +120,6 @@ class AdvancedChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

            yield response_chunk
@@ -81,7 +81,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk

    @classmethod
@@ -109,7 +109,7 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
-               sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+               sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class AgentChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

            yield response_chunk
@@ -81,7 +81,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk

    @classmethod
@@ -109,7 +109,7 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
-               sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+               sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                metadata = sub_stream_response_dict.get("metadata", {})
                sub_stream_response_dict["metadata"] = cls._get_simple_metadata(metadata)
                response_chunk.update(sub_stream_response_dict)
@@ -117,6 +117,6 @@ class ChatAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

            yield response_chunk
@@ -70,6 +70,8 @@ class _NodeSnapshot:
    """Empty string means the node is not executing inside an iteration."""
    loop_id: str = ""
    """Empty string means the node is not executing inside a loop."""
    mention_parent_id: str = ""
    """Empty string means the node is not an extractor node."""


class WorkflowResponseConverter:
@@ -131,6 +133,7 @@ class WorkflowResponseConverter:
            start_at=event.start_at,
            iteration_id=event.in_iteration_id or "",
            loop_id=event.in_loop_id or "",
            mention_parent_id=event.in_mention_parent_id or "",
        )
        node_execution_id = NodeExecutionId(event.node_execution_id)
        self._node_snapshots[node_execution_id] = snapshot
@@ -287,6 +290,7 @@ class WorkflowResponseConverter:
                created_at=int(snapshot.start_at.timestamp()),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
                agent_strategy=event.agent_strategy,
            ),
        )
@@ -373,6 +377,7 @@ class WorkflowResponseConverter:
                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
            ),
        )

@@ -422,6 +427,7 @@ class WorkflowResponseConverter:
                files=self.fetch_files_from_node_outputs(event.outputs or {}),
                iteration_id=event.in_iteration_id,
                loop_id=event.in_loop_id,
                mention_parent_id=event.in_mention_parent_id,
                retry_index=event.retry_index,
            ),
        )
@@ -79,7 +79,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk

    @classmethod
@@ -106,7 +106,7 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
            }

            if isinstance(sub_stream_response, MessageEndStreamResponse):
-               sub_stream_response_dict = sub_stream_response.model_dump(mode="json")
+               sub_stream_response_dict = sub_stream_response.model_dump(mode="json", exclude_none=True)
                metadata = sub_stream_response_dict.get("metadata", {})
                if not isinstance(metadata, dict):
                    metadata = {}
@@ -116,6 +116,6 @@ class CompletionAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))

            yield response_chunk
@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(cast(dict, data))
            else:
-               response_chunk.update(sub_stream_response.model_dump())
+               response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
            yield response_chunk

    @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(cast(dict, sub_stream_response.to_ignore_detail_dict()))
            else:
-               response_chunk.update(sub_stream_response.model_dump())
+               response_chunk.update(sub_stream_response.model_dump(exclude_none=True))
            yield response_chunk

@@ -60,7 +60,7 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
                data = cls._error_to_stream_response(sub_stream_response.err)
                response_chunk.update(data)
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk

    @classmethod
@@ -91,5 +91,5 @@ class WorkflowAppGenerateResponseConverter(AppGenerateResponseConverter):
            elif isinstance(sub_stream_response, NodeStartStreamResponse | NodeFinishStreamResponse):
                response_chunk.update(sub_stream_response.to_ignore_detail_dict())
            else:
-               response_chunk.update(sub_stream_response.model_dump(mode="json"))
+               response_chunk.update(sub_stream_response.model_dump(mode="json", exclude_none=True))
            yield response_chunk
@@ -385,6 +385,7 @@ class WorkflowBasedAppRunner:
                    start_at=event.start_at,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                    inputs=inputs,
                    process_data=process_data,
                    outputs=outputs,
@@ -405,6 +406,7 @@ class WorkflowBasedAppRunner:
                    start_at=event.start_at,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                    agent_strategy=event.agent_strategy,
                    provider_type=event.provider_type,
                    provider_id=event.provider_id,
@@ -428,6 +430,7 @@ class WorkflowBasedAppRunner:
                    execution_metadata=execution_metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunFailedEvent):
@@ -444,6 +447,7 @@ class WorkflowBasedAppRunner:
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunExceptionEvent):
@@ -460,6 +464,7 @@ class WorkflowBasedAppRunner:
                    execution_metadata=event.node_run_result.metadata,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunStreamChunkEvent):
@@ -469,6 +474,7 @@ class WorkflowBasedAppRunner:
                    from_variable_selector=list(event.selector),
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunRetrieverResourceEvent):
@@ -477,6 +483,7 @@ class WorkflowBasedAppRunner:
                    retriever_resources=event.retriever_resources,
                    in_iteration_id=event.in_iteration_id,
                    in_loop_id=event.in_loop_id,
                    in_mention_parent_id=event.in_mention_parent_id,
                )
            )
        elif isinstance(event, NodeRunAgentLogEvent):
@@ -190,6 +190,8 @@ class QueueTextChunkEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""


class QueueAgentMessageEvent(AppQueueEvent):
@@ -229,6 +231,8 @@ class QueueRetrieverResourcesEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""


class QueueAnnotationReplyEvent(AppQueueEvent):
@@ -306,6 +310,8 @@ class QueueNodeStartedEvent(AppQueueEvent):
    node_run_index: int = 1  # FIXME(-LAN-): may not used
    in_iteration_id: str | None = None
    in_loop_id: str | None = None
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime
    agent_strategy: AgentNodeStrategyInit | None = None

@@ -328,6 +334,8 @@ class QueueNodeSucceededEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -383,6 +391,8 @@ class QueueNodeExceptionEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -407,6 +417,8 @@ class QueueNodeFailedEvent(AppQueueEvent):
    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """parent node id if this is an extractor node event"""
    start_at: datetime

    inputs: Mapping[str, object] = Field(default_factory=dict)
@@ -262,6 +262,7 @@ class NodeStartStreamResponse(StreamResponse):
        extras: dict[str, object] = Field(default_factory=dict)
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None
        agent_strategy: AgentNodeStrategyInit | None = None

    event: StreamEvent = StreamEvent.NODE_STARTED
@@ -285,6 +286,7 @@ class NodeStartStreamResponse(StreamResponse):
                "extras": {},
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
            },
        }
@@ -320,6 +322,7 @@ class NodeFinishStreamResponse(StreamResponse):
        files: Sequence[Mapping[str, Any]] | None = []
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None

    event: StreamEvent = StreamEvent.NODE_FINISHED
    workflow_run_id: str
@@ -349,6 +352,7 @@ class NodeFinishStreamResponse(StreamResponse):
                "files": [],
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
            },
        }
@@ -384,6 +388,7 @@ class NodeRetryStreamResponse(StreamResponse):
        files: Sequence[Mapping[str, Any]] | None = []
        iteration_id: str | None = None
        loop_id: str | None = None
        mention_parent_id: str | None = None
        retry_index: int = 0

    event: StreamEvent = StreamEvent.NODE_RETRY
@@ -414,6 +419,7 @@ class NodeRetryStreamResponse(StreamResponse):
                "files": [],
                "iteration_id": self.data.iteration_id,
                "loop_id": self.data.loop_id,
                "mention_parent_id": self.data.mention_parent_id,
                "retry_index": self.data.retry_index,
            },
        }
@@ -1,8 +1,8 @@
 import json
 import logging
 import re
-from collections.abc import Sequence
-from typing import Protocol, cast
+from collections.abc import Mapping, Sequence
+from typing import Any, Protocol, cast

 import json_repair
@@ -398,6 +398,488 @@ class LLMGenerator:
            logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
            return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}

    @classmethod
    def generate_with_context(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        prompt_messages: list[PromptMessage],
        model_config: dict,
    ) -> dict:
        """
        Generate extractor code node based on conversation context.

        Args:
            tenant_id: Tenant/workspace ID
            workflow_id: Workflow ID
            node_id: Current tool/llm node ID
            parameter_name: Parameter name to generate code for
            language: Code language (python3/javascript)
            prompt_messages: Multi-turn conversation history (last message is instruction)
            model_config: Model configuration (provider, name, completion_params)

        Returns:
            dict with CodeNodeData format:
            - variables: Input variable selectors
            - code_language: Code language
            - code: Generated code
            - outputs: Output definitions
            - message: Explanation
            - error: Error message if any
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from services.workflow_service import WorkflowService

        # Get workflow
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return cls._error_response(f"App {workflow_id} not found")

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return cls._error_response(f"Workflow for app {workflow_id} not found")

        # Get upstream nodes via edge backtracking
        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)

        # Get current node info
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return cls._error_response(f"Node {node_id} not found")

        # Get parameter info
        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build system prompt
        system_prompt = cls._build_extractor_system_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        # Construct complete prompt_messages with system prompt
        complete_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
            *prompt_messages,
        ]

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output

        # Get model instance and schema
        provider = model_config.get("provider", "")
        model_name = model_config.get("name", "")
        model_instance = ModelManager().get_model_instance(
            tenant_id=tenant_id,
            model_type=ModelType.LLM,
            provider=provider,
            model=model_name,
        )

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return cls._error_response(f"Model schema not found for {model_name}")

        model_parameters = model_config.get("completion_params", {})
        json_schema = cls._get_code_node_json_schema()

        try:
            response = invoke_llm_with_structured_output(
                provider=provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=complete_messages,
                json_schema=json_schema,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            return cls._parse_code_node_output(
                response.structured_output, language, parameter_info.get("type", "string")
            )

        except InvokeError as e:
            return cls._error_response(str(e))
        except Exception as e:
            logger.exception("Failed to generate with context, model: %s", model_config.get("name"))
            return cls._error_response(f"An unexpected error occurred: {str(e)}")

    @classmethod
    def _error_response(cls, error: str) -> dict:
        """Return error response in CodeNodeData format."""
        return {
            "variables": [],
            "code_language": "python3",
            "code": "",
            "outputs": {},
            "message": "",
            "error": error,
        }

    @classmethod
    def generate_suggested_questions(
        cls,
        tenant_id: str,
        workflow_id: str,
        node_id: str,
        parameter_name: str,
        language: str,
        model_config: dict | None = None,
    ) -> dict:
        """
        Generate suggested questions for context generation.

        Returns dict with questions array and error field.
        """
        from sqlalchemy import select
        from sqlalchemy.orm import Session

        from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
        from services.workflow_service import WorkflowService

        # Get workflow context (reuse existing logic)
        with Session(db.engine) as session:
            stmt = select(App).where(App.id == workflow_id)
            app = session.scalar(stmt)
            if not app:
                return {"questions": [], "error": f"App {workflow_id} not found"}

            workflow = WorkflowService().get_draft_workflow(app_model=app)
            if not workflow:
                return {"questions": [], "error": f"Workflow for app {workflow_id} not found"}

        upstream_nodes = cls._get_upstream_nodes(workflow.graph_dict, node_id)
        current_node = cls._get_node_by_id(workflow.graph_dict, node_id)
        if not current_node:
            return {"questions": [], "error": f"Node {node_id} not found"}

        parameter_info = cls._get_parameter_info(
            tenant_id=tenant_id,
            node_data=current_node.get("data", {}),
            parameter_name=parameter_name,
        )

        # Build prompt
        system_prompt = cls._build_suggested_questions_prompt(
            upstream_nodes=upstream_nodes,
            current_node=current_node,
            parameter_info=parameter_info,
            language=language,
        )

        prompt_messages: list[PromptMessage] = [
            SystemPromptMessage(content=system_prompt),
        ]

        # Get model instance - use default if model_config not provided
        model_manager = ModelManager()
        if model_config:
            provider = model_config.get("provider", "")
            model_name = model_config.get("name", "")
            model_instance = model_manager.get_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
                provider=provider,
                model=model_name,
            )
        else:
            model_instance = model_manager.get_default_model_instance(
                tenant_id=tenant_id,
                model_type=ModelType.LLM,
            )
            model_name = model_instance.model

        model_schema = model_instance.model_type_instance.get_model_schema(model_name, model_instance.credentials)
        if not model_schema:
            return {"questions": [], "error": f"Model schema not found for {model_name}"}

        completion_params = model_config.get("completion_params", {}) if model_config else {}
        model_parameters = {**completion_params, "max_tokens": 256}
        json_schema = cls._get_suggested_questions_json_schema()

        try:
            response = invoke_llm_with_structured_output(
                provider=model_instance.provider,
                model_schema=model_schema,
                model_instance=model_instance,
                prompt_messages=prompt_messages,
                json_schema=json_schema,
                model_parameters=model_parameters,
                stream=False,
                tenant_id=tenant_id,
            )

            questions = response.structured_output.get("questions", []) if response.structured_output else []
            return {"questions": questions, "error": ""}

        except InvokeError as e:
            return {"questions": [], "error": str(e)}
        except Exception as e:
            logger.exception("Failed to generate suggested questions, model: %s", model_name)
            return {"questions": [], "error": f"An unexpected error occurred: {str(e)}"}

    @classmethod
    def _build_suggested_questions_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str = "English",
    ) -> str:
        """Build minimal prompt for suggested questions generation."""
        # Simplify upstream nodes to reduce tokens
        sources = [f"{n['title']}({','.join(n.get('outputs', {}).keys())})" for n in upstream_nodes[:5]]
        param_type = parameter_info.get("type", "string")
        param_desc = parameter_info.get("description", "")[:100]

        return f"""Suggest 3 code generation questions for extracting data.
Sources: {", ".join(sources)}
Target: {parameter_info.get("name")}({param_type}) - {param_desc}
Output 3 short, practical questions in {language}."""

    @classmethod
    def _get_suggested_questions_json_schema(cls) -> dict:
        """Return JSON Schema for suggested questions."""
        return {
            "type": "object",
            "properties": {
                "questions": {
                    "type": "array",
                    "items": {"type": "string"},
                    "minItems": 3,
                    "maxItems": 3,
                    "description": "3 suggested questions",
                },
            },
            "required": ["questions"],
        }

    @classmethod
    def _get_code_node_json_schema(cls) -> dict:
        """Return JSON Schema for structured output."""
        return {
            "type": "object",
            "properties": {
                "variables": {
                    "type": "array",
                    "items": {
                        "type": "object",
                        "properties": {
                            "variable": {"type": "string", "description": "Variable name in code"},
                            "value_selector": {
                                "type": "array",
                                "items": {"type": "string"},
                                "description": "Path like [node_id, output_name]",
                            },
                        },
                        "required": ["variable", "value_selector"],
                    },
                },
                "code": {"type": "string", "description": "Generated code with main function"},
                "outputs": {
                    "type": "object",
                    "additionalProperties": {
                        "type": "object",
                        "properties": {"type": {"type": "string"}},
                    },
                    "description": "Output definitions, key is output name",
                },
                "explanation": {"type": "string", "description": "Brief explanation of the code"},
            },
            "required": ["variables", "code", "outputs", "explanation"],
        }

    @classmethod
    def _get_upstream_nodes(cls, graph_dict: Mapping[str, Any], node_id: str) -> list[dict]:
        """
        Get all upstream nodes via edge backtracking.

        Traverses the graph backwards from node_id to collect all reachable nodes.
        """
        from collections import defaultdict

        nodes = {n["id"]: n for n in graph_dict.get("nodes", [])}
        edges = graph_dict.get("edges", [])

        # Build reverse adjacency list
        reverse_adj: dict[str, list[str]] = defaultdict(list)
        for edge in edges:
            reverse_adj[edge["target"]].append(edge["source"])

        # BFS to find all upstream nodes
        visited: set[str] = set()
        queue = [node_id]
        upstream: list[dict] = []

        while queue:
            current = queue.pop(0)
            for source in reverse_adj.get(current, []):
                if source not in visited:
                    visited.add(source)
                    queue.append(source)
                    if source in nodes:
                        upstream.append(cls._extract_node_info(nodes[source]))

        return upstream

    @classmethod
    def _get_node_by_id(cls, graph_dict: Mapping[str, Any], node_id: str) -> dict | None:
        """Get node by ID from graph."""
        for node in graph_dict.get("nodes", []):
            if node["id"] == node_id:
                return node
        return None

    @classmethod
    def _extract_node_info(cls, node: dict) -> dict:
        """Extract minimal node info with outputs based on node type."""
        node_type = node["data"]["type"]
        node_data = node.get("data", {})

        # Build outputs based on node type (only type, no description to reduce tokens)
        outputs: dict[str, str] = {}
        match node_type:
            case "start":
                for var in node_data.get("variables", []):
                    name = var.get("variable", var.get("name", ""))
                    outputs[name] = var.get("type", "string")
            case "llm":
                outputs["text"] = "string"
            case "code":
                for name, output in node_data.get("outputs", {}).items():
                    outputs[name] = output.get("type", "string")
            case "http-request":
                outputs = {"body": "string", "status_code": "number", "headers": "object"}
            case "knowledge-retrieval":
                outputs["result"] = "array[object]"
            case "tool":
                outputs = {"text": "string", "json": "object"}
            case _:
                outputs["output"] = "string"

        info: dict = {
            "id": node["id"],
            "title": node_data.get("title", node["id"]),
            "outputs": outputs,
        }
        # Only include description if not empty
        desc = node_data.get("desc", "")
        if desc:
            info["desc"] = desc

        return info

    @classmethod
    def _get_parameter_info(cls, tenant_id: str, node_data: dict, parameter_name: str) -> dict:
        """Get parameter info from tool schema using ToolManager."""
        default_info = {"name": parameter_name, "type": "string", "description": ""}

        if node_data.get("type") != "tool":
            return default_info

        try:
            from core.app.entities.app_invoke_entities import InvokeFrom
            from core.tools.entities.tool_entities import ToolProviderType
            from core.tools.tool_manager import ToolManager

            provider_type_str = node_data.get("provider_type", "")
            provider_type = ToolProviderType(provider_type_str) if provider_type_str else ToolProviderType.BUILT_IN

            tool_runtime = ToolManager.get_tool_runtime(
                provider_type=provider_type,
                provider_id=node_data.get("provider_id", ""),
                tool_name=node_data.get("tool_name", ""),
                tenant_id=tenant_id,
                invoke_from=InvokeFrom.DEBUGGER,
            )

            parameters = tool_runtime.get_merged_runtime_parameters()
            for param in parameters:
                if param.name == parameter_name:
                    return {
                        "name": param.name,
                        "type": param.type.value if hasattr(param.type, "value") else str(param.type),
                        "description": param.llm_description
                        or (param.human_description.en_US if param.human_description else ""),
                        "required": param.required,
                    }
        except Exception as e:
            logger.debug("Failed to get parameter info from ToolManager: %s", e)

        return default_info

    @classmethod
    def _build_extractor_system_prompt(
        cls,
        upstream_nodes: list[dict],
        current_node: dict,
        parameter_info: dict,
        language: str,
    ) -> str:
        """Build system prompt for extractor code generation."""
        upstream_json = json.dumps(upstream_nodes, indent=2, ensure_ascii=False)
        param_type = parameter_info.get("type", "string")
        return f"""You are a code generator for workflow automation.

Generate {language} code to extract/transform upstream node outputs for the target parameter.

## Upstream Nodes
{upstream_json}

## Target
Node: {current_node["data"].get("title", current_node["id"])}
Parameter: {parameter_info.get("name")} ({param_type}) - {parameter_info.get("description", "")}

## Requirements
- Write a main function that returns type: {param_type}
- Use value_selector format: ["node_id", "output_name"]
"""

    @classmethod
    def _parse_code_node_output(cls, content: Mapping[str, Any] | None, language: str, parameter_type: str) -> dict:
        """
        Parse structured output to CodeNodeData format.

        Args:
            content: Structured output dict from invoke_llm_with_structured_output
            language: Code language
            parameter_type: Expected parameter type

        Returns dict with variables, code_language, code, outputs, message, error.
        """
        if content is None:
            return cls._error_response("Empty or invalid response from LLM")

        # Validate and normalize variables
        variables = [
            {"variable": v.get("variable", ""), "value_selector": v.get("value_selector", [])}
            for v in content.get("variables", [])
            if isinstance(v, dict)
        ]

        outputs = content.get("outputs", {"result": {"type": parameter_type}})

        return {
            "variables": variables,
            "code_language": language,
            "code": content.get("code", ""),
            "outputs": outputs,
            "message": content.get("explanation", ""),
            "error": "",
        }

    @staticmethod
    def instruction_modify_legacy(
        tenant_id: str, flow_id: str, current: str, instruction: str, model_config: dict, ideal_output: str | None
api/core/llm_generator/output_parser/file_ref.py (new file, 188 lines)
@@ -0,0 +1,188 @@
"""
File reference detection and conversion for structured output.

This module provides utilities to:
1. Detect file reference fields in JSON Schema (format: "dify-file-ref")
2. Convert file ID strings to File objects after LLM returns
"""

import uuid
from collections.abc import Mapping
from typing import Any

from core.file import File
from core.variables.segments import ArrayFileSegment, FileSegment
from factories.file_factory import build_from_mapping

FILE_REF_FORMAT = "dify-file-ref"


def is_file_ref_property(schema: dict) -> bool:
    """Check if a schema property is a file reference."""
    return schema.get("type") == "string" and schema.get("format") == FILE_REF_FORMAT


def detect_file_ref_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
    """
    Recursively detect file reference fields in schema.

    Args:
        schema: JSON Schema to analyze
        path: Current path in the schema (used for recursion)

    Returns:
        List of JSON paths containing file refs, e.g., ["image_id", "files[*]"]
    """
    file_ref_paths: list[str] = []
    schema_type = schema.get("type")

    if schema_type == "object":
        for prop_name, prop_schema in schema.get("properties", {}).items():
            current_path = f"{path}.{prop_name}" if path else prop_name

            if is_file_ref_property(prop_schema):
                file_ref_paths.append(current_path)
            elif isinstance(prop_schema, dict):
                file_ref_paths.extend(detect_file_ref_fields(prop_schema, current_path))

    elif schema_type == "array":
        items_schema = schema.get("items", {})
        array_path = f"{path}[*]" if path else "[*]"

        if is_file_ref_property(items_schema):
            file_ref_paths.append(array_path)
        elif isinstance(items_schema, dict):
            file_ref_paths.extend(detect_file_ref_fields(items_schema, array_path))

    return file_ref_paths


def convert_file_refs_in_output(
    output: Mapping[str, Any],
    json_schema: Mapping[str, Any],
    tenant_id: str,
) -> dict[str, Any]:
    """
    Convert file ID strings to File objects based on schema.

    Args:
        output: The structured_output from LLM result
        json_schema: The original JSON schema (to detect file ref fields)
        tenant_id: Tenant ID for file lookup

    Returns:
        Output with file references converted to File objects
    """
    file_ref_paths = detect_file_ref_fields(json_schema)
    if not file_ref_paths:
        return dict(output)

    result = _deep_copy_dict(output)

    for path in file_ref_paths:
        _convert_path_in_place(result, path.split("."), tenant_id)

    return result


def _deep_copy_dict(obj: Mapping[str, Any]) -> dict[str, Any]:
    """Deep copy a mapping to a mutable dict."""
    result: dict[str, Any] = {}
    for key, value in obj.items():
        if isinstance(value, Mapping):
            result[key] = _deep_copy_dict(value)
        elif isinstance(value, list):
            result[key] = [_deep_copy_dict(item) if isinstance(item, Mapping) else item for item in value]
        else:
            result[key] = value
    return result


def _convert_path_in_place(obj: dict, path_parts: list[str], tenant_id: str) -> None:
    """Convert file refs at the given path in place, wrapping in Segment types."""
    if not path_parts:
        return

    current = path_parts[0]
    remaining = path_parts[1:]

    # Handle array notation like "files[*]"
    if current.endswith("[*]"):
        key = current[:-3] if current != "[*]" else None
        target = obj.get(key) if key else obj

        if isinstance(target, list):
            if remaining:
                # Nested array with remaining path - recurse into each item
                for item in target:
                    if isinstance(item, dict):
                        _convert_path_in_place(item, remaining, tenant_id)
            else:
                # Array of file IDs - convert all and wrap in ArrayFileSegment
                files: list[File] = []
                for item in target:
                    file = _convert_file_id(item, tenant_id)
                    if file is not None:
                        files.append(file)
                # Replace the array with ArrayFileSegment
                if key:
                    obj[key] = ArrayFileSegment(value=files)
        return

    if not remaining:
        # Leaf node - convert the value and wrap in FileSegment
        if current in obj:
            file = _convert_file_id(obj[current], tenant_id)
            if file is not None:
                obj[current] = FileSegment(value=file)
            else:
                obj[current] = None
    else:
        # Recurse into nested object
        if current in obj and isinstance(obj[current], dict):
            _convert_path_in_place(obj[current], remaining, tenant_id)


def _convert_file_id(file_id: Any, tenant_id: str) -> File | None:
    """
    Convert a file ID string to a File object.

    Tries multiple file sources in order:
    1. ToolFile (files generated by tools/workflows)
    2. UploadFile (files uploaded by users)
    """
    if not isinstance(file_id, str):
        return None

    # Validate UUID format
    try:
        uuid.UUID(file_id)
    except ValueError:
        return None

    # Try ToolFile first (files generated by tools/workflows)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "tool_file",
                "tool_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # Try UploadFile (files uploaded by users)
    try:
        return build_from_mapping(
            mapping={
                "transfer_method": "local_file",
                "upload_file_id": file_id,
            },
            tenant_id=tenant_id,
        )
    except ValueError:
        pass

    # File not found in any source
    return None
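To make the detection rules above concrete, a small sketch using only `detect_file_ref_fields` as defined in this file (the schema is illustrative):

```python
from core.llm_generator.output_parser.file_ref import detect_file_ref_fields

schema = {
    "type": "object",
    "properties": {
        # Single file reference: a string carrying the "dify-file-ref" format marker
        "image_id": {"type": "string", "format": "dify-file-ref"},
        # Array of file references: each item carries the marker
        "files": {"type": "array", "items": {"type": "string", "format": "dify-file-ref"}},
        # Ordinary field: ignored by the detector
        "title": {"type": "string"},
    },
}

print(detect_file_ref_fields(schema))  # -> ["image_id", "files[*]"]
```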
@@ -8,6 +8,7 @@ import json_repair
 from pydantic import TypeAdapter, ValidationError

 from core.llm_generator.output_parser.errors import OutputParserError
+from core.llm_generator.output_parser.file_ref import convert_file_refs_in_output
 from core.llm_generator.prompts import STRUCTURED_OUTPUT_PROMPT
 from core.model_manager import ModelInstance
 from core.model_runtime.callbacks.base_callback import Callback
@@ -57,6 +58,7 @@ def invoke_llm_with_structured_output(
    stream: Literal[True],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
@overload
def invoke_llm_with_structured_output(
@@ -72,6 +74,7 @@ def invoke_llm_with_structured_output(
    stream: Literal[False],
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput: ...
@overload
def invoke_llm_with_structured_output(
@@ -87,6 +90,7 @@ def invoke_llm_with_structured_output(
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]: ...
def invoke_llm_with_structured_output(
    *,
@@ -101,20 +105,28 @@ def invoke_llm_with_structured_output(
    stream: bool = True,
    user: str | None = None,
    callbacks: list[Callback] | None = None,
    tenant_id: str | None = None,
) -> LLMResultWithStructuredOutput | Generator[LLMResultChunkWithStructuredOutput, None, None]:
    """
-   Invoke large language model with structured output
-   1. This method invokes model_instance.invoke_llm with json_schema
-   2. Try to parse the result as structured output
+   Invoke large language model with structured output.
+
+   This method invokes model_instance.invoke_llm with json_schema and parses
+   the result as structured output.

    :param provider: model provider name
    :param model_schema: model schema entity
    :param model_instance: model instance to invoke
    :param prompt_messages: prompt messages
-   :param json_schema: json schema
+   :param json_schema: json schema for structured output
    :param model_parameters: model parameters
    :param tools: tools for tool calling
    :param stop: stop words
    :param stream: is stream response
    :param user: unique user id
    :param callbacks: callbacks
+   :param tenant_id: tenant ID for file reference conversion. When provided and
+       json_schema contains file reference fields (format: "dify-file-ref"),
+       file IDs in the output will be automatically converted to File objects.
    :return: full response or stream response chunk generator result
    """
@@ -153,8 +165,18 @@ def invoke_llm_with_structured_output(
                f"Failed to parse structured output, LLM result is not a string: {llm_result.message.content}"
            )

        structured_output = _parse_structured_output(llm_result.message.content)

        # Convert file references if tenant_id is provided
        if tenant_id is not None:
            structured_output = convert_file_refs_in_output(
                output=structured_output,
                json_schema=json_schema,
                tenant_id=tenant_id,
            )

        return LLMResultWithStructuredOutput(
-           structured_output=_parse_structured_output(llm_result.message.content),
+           structured_output=structured_output,
            model=llm_result.model,
            message=llm_result.message,
            usage=llm_result.usage,
@@ -186,8 +208,18 @@ def invoke_llm_with_structured_output(
                delta=event.delta,
            )

        structured_output = _parse_structured_output(result_text)

        # Convert file references if tenant_id is provided
        if tenant_id is not None:
            structured_output = convert_file_refs_in_output(
                output=structured_output,
                json_schema=json_schema,
                tenant_id=tenant_id,
            )

        yield LLMResultChunkWithStructuredOutput(
-           structured_output=_parse_structured_output(result_text),
+           structured_output=structured_output,
            model=model_schema.model,
            prompt_messages=prompt_messages,
            system_fingerprint=system_fingerprint,
api/core/llm_generator/utils.py (new file, 45 lines)
@@ -0,0 +1,45 @@
"""Utility functions for LLM generator."""

from core.model_runtime.entities.message_entities import (
    AssistantPromptMessage,
    PromptMessage,
    PromptMessageRole,
    SystemPromptMessage,
    ToolPromptMessage,
    UserPromptMessage,
)


def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
    """
    Deserialize list of dicts to list[PromptMessage].

    Expected format:
    [
        {"role": "user", "content": "..."},
        {"role": "assistant", "content": "..."},
    ]
    """
    result: list[PromptMessage] = []
    for msg in messages:
        role = PromptMessageRole.value_of(msg["role"])
        content = msg.get("content", "")

        match role:
            case PromptMessageRole.USER:
                result.append(UserPromptMessage(content=content))
            case PromptMessageRole.ASSISTANT:
                result.append(AssistantPromptMessage(content=content))
            case PromptMessageRole.SYSTEM:
                result.append(SystemPromptMessage(content=content))
            case PromptMessageRole.TOOL:
                result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))

    return result


def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
    """
    Serialize list[PromptMessage] to list of dicts.
    """
    return [{"role": msg.role.value, "content": msg.content} for msg in messages]
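A minimal round-trip sketch of these two helpers (the sample history is illustrative):

```python
from core.llm_generator.utils import deserialize_prompt_messages, serialize_prompt_messages

history = [
    {"role": "user", "content": "Extract the order ID from the HTTP response."},
    {"role": "assistant", "content": "def main(body: str) -> dict: ..."},
    {"role": "user", "content": "Also return the status code."},
]

# Wire format -> typed PromptMessage objects (last message is the current instruction)
messages = deserialize_prompt_messages(history)

# Typed objects -> wire format; round-trips cleanly for plain text content
assert serialize_prompt_messages(messages) == history
```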
api/core/memory/README.md (new file, 434 lines)
@@ -0,0 +1,434 @@
# Memory Module

This module provides memory management for LLM conversations, enabling context retention across dialogue turns.

## Overview

The memory module contains two types of memory implementations:

1. **TokenBufferMemory** - Conversation-level memory (existing)
2. **NodeTokenBufferMemory** - Node-level memory (to be implemented, **Chatflow only**)

> **Note**: `NodeTokenBufferMemory` is only available in **Chatflow** (advanced-chat mode).
> This is because it requires both `conversation_id` and `node_id`, which are only present in Chatflow.
> Standard Workflow mode does not have `conversation_id` and therefore cannot use node-level memory.

```
┌────────────────────────────────────────────────────────────────┐
│                      Memory Architecture                       │
├────────────────────────────────────────────────────────────────┤
│                                                                │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ TokenBufferMemory                                        │  │
│  │ Scope:   Conversation                                    │  │
│  │ Storage: Database (Message table)                        │  │
│  │ Key:     conversation_id                                 │  │
│  └──────────────────────────────────────────────────────────┘  │
│                                                                │
│  ┌──────────────────────────────────────────────────────────┐  │
│  │ NodeTokenBufferMemory                                    │  │
│  │ Scope:   Node within Conversation                        │  │
│  │ Storage: Object Storage (JSON file)                      │  │
│  │ Key:     (app_id, conversation_id, node_id)              │  │
│  └──────────────────────────────────────────────────────────┘  │
│                                                                │
└────────────────────────────────────────────────────────────────┘
```

---

## TokenBufferMemory (Existing)

### Purpose

`TokenBufferMemory` retrieves conversation history from the `Message` table and converts it to `PromptMessage` objects for LLM context.

### Key Features

- **Conversation-scoped**: All messages within a conversation are candidates
- **Thread-aware**: Uses `parent_message_id` to extract only the current thread (supports regeneration scenarios)
- **Token-limited**: Truncates history to fit within `max_token_limit`
- **File support**: Handles `MessageFile` attachments (images, documents, etc.)

### Data Flow

```
Message Table                     TokenBufferMemory                  LLM
      │                                   │                            │
      │  SELECT * FROM messages           │                            │
      │  WHERE conversation_id = ?        │                            │
      │  ORDER BY created_at DESC         │                            │
      ├──────────────────────────────────▶│                            │
      │                                   │                            │
      │            extract_thread_messages()                           │
      │                                   │                            │
      │            build_prompt_message_with_files()                   │
      │                                   │                            │
      │            truncate by max_token_limit                         │
      │                                   │                            │
      │                                   │  Sequence[PromptMessage]   │
      │                                   ├───────────────────────────▶│
      │                                   │                            │
```

### Thread Extraction

When a user regenerates a response, a new thread is created:

```
Message A (user)
├── Message A' (assistant)
│   └── Message B (user)
│       └── Message B' (assistant)
└── Message A'' (assistant, regenerated)   ← New thread
    └── Message C (user)
        └── Message C' (assistant)
```

`extract_thread_messages()` traces back from the latest message using `parent_message_id` to get only the current thread: `[A, A'', C, C']`
|
||||
### Usage

```python
from core.memory.token_buffer_memory import TokenBufferMemory

memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
history = memory.get_history_prompt_messages(max_token_limit=2000, message_limit=100)
```
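
The same history can also be rendered as plain text (useful for completion-style prompt templates) via `get_history_prompt_text()`, which this change defines on `BaseMemory` below; a minimal sketch:

```python
# Same history, rendered as "Human: ..." / "Assistant: ..." lines.
history_text = memory.get_history_prompt_text(
    human_prefix="Human",
    ai_prefix="Assistant",
    max_token_limit=2000,
)
```
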
---

## NodeTokenBufferMemory (To Be Implemented)

### Purpose

`NodeTokenBufferMemory` provides **node-scoped memory** within a conversation. Each LLM node in a workflow can maintain its own independent conversation history.

### Use Cases

1. **Multi-LLM Workflows**: Different LLM nodes need separate context
2. **Iterative Processing**: An LLM node in a loop needs to accumulate context across iterations
3. **Specialized Agents**: Each agent node maintains its own dialogue history
### Design Decisions

#### Storage: Object Storage for Messages (No New Database Table)

| Aspect                    | Database             | Object Storage     |
| ------------------------- | -------------------- | ------------------ |
| Cost                      | High                 | Low                |
| Query Flexibility         | High                 | Low                |
| Schema Changes            | Migration required   | None               |
| Consistency with existing | ConversationVariable | File uploads, logs |

**Decision**: Store message data in object storage, but still use existing database tables for file metadata.

**What is stored in Object Storage:**

- Message content (text)
- Message metadata (role, token_count, created_at)
- File references (upload_file_id, tool_file_id, etc.)
- Thread relationships (message_id, parent_message_id)

**What still requires Database queries:**

- File reconstruction: When reading node memory, file references are used to query the
  `UploadFile` / `ToolFile` tables via `file_factory.build_from_mapping()` to rebuild
  complete `File` objects with storage_key, mime_type, etc.

**Why this hybrid approach:**

- No database migration required (no new tables)
- Message data may be large, and object storage is cost-effective for it
- File metadata is already in the database, so there is no need to duplicate it
- Aligns with existing storage patterns (file uploads, logs)
#### Storage Key Format

```
node_memory/{app_id}/{conversation_id}/{node_id}.json
```
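
To make the round trip concrete, here is a minimal sketch of composing this key and persisting a payload through the storage extension used by the implementation below (`extensions.ext_storage.storage`); the payload bytes here are illustrative:

```python
from extensions.ext_storage import storage


def node_memory_key(app_id: str, conversation_id: str, node_id: str) -> str:
    """Compose the object-storage key for one node's memory."""
    return f"node_memory/{app_id}/{conversation_id}/{node_id}.json"


key = node_memory_key(app_id, conversation_id, node_id)
storage.save(key, b'{"version": 1, "messages": []}')  # write JSON bytes
raw = storage.load_once(key)                           # read back as bytes
```
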
#### Data Structure

```json
{
  "version": 1,
  "messages": [
    {
      "message_id": "msg-001",
      "parent_message_id": null,
      "role": "user",
      "content": "Analyze this image",
      "files": [
        {
          "type": "image",
          "transfer_method": "local_file",
          "upload_file_id": "file-uuid-123",
          "belongs_to": "user"
        }
      ],
      "token_count": 15,
      "created_at": "2026-01-07T10:00:00Z"
    },
    {
      "message_id": "msg-002",
      "parent_message_id": "msg-001",
      "role": "assistant",
      "content": "This is a landscape image...",
      "files": [],
      "token_count": 50,
      "created_at": "2026-01-07T10:00:01Z"
    }
  ]
}
```
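
A quick way to validate this draft schema is a set of Pydantic models mirroring it; the `Draft*` names here are illustrative only (note that the shipped implementation below ends up storing turns keyed by `workflow_run_id` instead of a flat message list):

```python
from pydantic import BaseModel


class DraftMemoryFile(BaseModel):
    type: str
    transfer_method: str
    upload_file_id: str | None = None
    belongs_to: str = "user"


class DraftMemoryMessage(BaseModel):
    message_id: str
    parent_message_id: str | None = None
    role: str
    content: str
    files: list[DraftMemoryFile] = []
    token_count: int = 0
    created_at: str


class DraftMemoryData(BaseModel):
    version: int = 1
    messages: list[DraftMemoryMessage] = []


data = DraftMemoryData.model_validate_json(raw)  # raw: bytes/str loaded from storage
```
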
### Thread Support

Node memory also supports thread extraction (for regeneration scenarios):

```python
def _extract_thread(
    self,
    messages: list[NodeMemoryMessage],
    current_message_id: str
) -> list[NodeMemoryMessage]:
    """
    Extract messages belonging to the thread of current_message_id.
    Similar to extract_thread_messages() in TokenBufferMemory.
    """
    ...
```
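
For reference, a minimal sketch of how this back-trace could be implemented against the draft schema above, assuming each `NodeMemoryMessage` carries `message_id` and `parent_message_id` (not the final implementation):

```python
def _extract_thread(self, messages, current_message_id):
    """Walk parent_message_id links from current_message_id back to the root."""
    by_id = {m.message_id: m for m in messages}
    thread = []
    current = by_id.get(current_message_id)
    while current is not None:
        thread.append(current)
        parent_id = current.parent_message_id
        current = by_id.get(parent_id) if parent_id else None
    thread.reverse()  # oldest first, ready for prompt assembly
    return thread
```
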
### File Handling

Files are stored as references (not full metadata):

```python
class NodeMemoryFile(BaseModel):
    type: str  # image, audio, video, document, custom
    transfer_method: str  # local_file, remote_url, tool_file
    upload_file_id: str | None  # for local_file
    tool_file_id: str | None  # for tool_file
    url: str | None  # for remote_url
    belongs_to: str  # user / assistant
```

When reading, files are rebuilt using `file_factory.build_from_mapping()`.
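
A minimal sketch of that rebuild step, mirroring the `_memory_file_to_mapping()` helper in the implementation below; `tenant_id` is needed because file metadata lives in tenant-scoped database tables:

```python
from factories import file_factory


def rebuild_file(memory_file: NodeMemoryFile, tenant_id: str):
    """Turn a stored reference back into a full File object (DB lookup inside)."""
    mapping = {
        "type": memory_file.type,
        "transfer_method": memory_file.transfer_method,
    }
    if memory_file.upload_file_id:
        mapping["upload_file_id"] = memory_file.upload_file_id
    if memory_file.tool_file_id:
        mapping["tool_file_id"] = memory_file.tool_file_id
    if memory_file.url:
        mapping["url"] = memory_file.url
    return file_factory.build_from_mapping(mapping=mapping, tenant_id=tenant_id)
```
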
### API Design

```python
class NodeTokenBufferMemory:
    def __init__(
        self,
        app_id: str,
        conversation_id: str,
        node_id: str,
        model_instance: ModelInstance,
    ):
        """
        Initialize node-level memory.

        :param app_id: Application ID
        :param conversation_id: Conversation ID
        :param node_id: Node ID in the workflow
        :param model_instance: Model instance for token counting
        """
        ...

    def add_messages(
        self,
        message_id: str,
        parent_message_id: str | None,
        user_content: str,
        user_files: Sequence[File],
        assistant_content: str,
        assistant_files: Sequence[File],
    ) -> None:
        """
        Append a dialogue turn (user + assistant) to node memory.
        Call this after LLM node execution completes.

        :param message_id: Current message ID (from Message table)
        :param parent_message_id: Parent message ID (for thread tracking)
        :param user_content: User's text input
        :param user_files: Files attached by user
        :param assistant_content: Assistant's text response
        :param assistant_files: Files generated by assistant
        """
        ...

    def get_history_prompt_messages(
        self,
        current_message_id: str,
        tenant_id: str,
        max_token_limit: int = 2000,
        file_upload_config: FileUploadConfig | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Retrieve history as PromptMessage sequence.

        :param current_message_id: Current message ID (for thread extraction)
        :param tenant_id: Tenant ID (for file reconstruction)
        :param max_token_limit: Maximum tokens for history
        :param file_upload_config: File upload configuration
        :return: Sequence of PromptMessage for LLM context
        """
        ...

    def flush(self) -> None:
        """
        Persist buffered changes to object storage.
        Call this at the end of node execution.
        """
        ...

    def clear(self) -> None:
        """
        Clear all messages in this node's memory.
        """
        ...
```
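
Putting the draft API together, a typical node execution would look roughly like this (names such as `current_message_id` and `response_files` stand in for values the node already has):

```python
memory = NodeTokenBufferMemory(
    app_id=app_id,
    conversation_id=conversation_id,
    node_id=node_id,
    model_instance=model_instance,
)

# Before invoking the model: pull this node's own history.
history = memory.get_history_prompt_messages(
    current_message_id=current_message_id,
    tenant_id=tenant_id,
    max_token_limit=2000,
)

# ... invoke the LLM with [*history, *current_messages] ...

# After the turn completes: record it and persist.
memory.add_messages(
    message_id=current_message_id,
    parent_message_id=parent_message_id,
    user_content=user_input,
    user_files=user_files,
    assistant_content=assistant_text,
    assistant_files=response_files,
)
memory.flush()
```
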
### Data Flow

```
Object Storage                 NodeTokenBufferMemory                 LLM Node
      │                                  │                               │
      │                                  │◀── get_history_prompt_messages()
      │  storage.load(key)               │                               │
      │◀─────────────────────────────────┤                               │
      │                                  │                               │
      │  JSON data                       │                               │
      ├─────────────────────────────────▶│                               │
      │                                  │                               │
      │                     _extract_thread()                            │
      │                                  │                               │
      │                     _rebuild_files() via file_factory            │
      │                                  │                               │
      │                     _build_prompt_messages()                     │
      │                                  │                               │
      │                     _truncate_by_tokens()                        │
      │                                  │                               │
      │                                  │   Sequence[PromptMessage]     │
      │                                  ├──────────────────────────────▶│
      │                                  │                               │
      │                                  │◀── LLM execution complete     │
      │                                  │                               │
      │                                  │◀── add_messages()             │
      │                                  │                               │
      │  storage.save(key, data)         │                               │
      │◀─────────────────────────────────┤                               │
      │                                  │                               │
```
### Integration with LLM Node

```python
# In LLM Node execution

# 1. Fetch memory based on mode
if node_data.memory and node_data.memory.mode == MemoryMode.NODE:
    # Node-level memory (Chatflow only)
    memory = fetch_node_memory(
        variable_pool=variable_pool,
        app_id=app_id,
        node_id=self.node_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
elif node_data.memory and node_data.memory.mode == MemoryMode.CONVERSATION:
    # Conversation-level memory (existing behavior)
    memory = fetch_memory(
        variable_pool=variable_pool,
        app_id=app_id,
        node_data_memory=node_data.memory,
        model_instance=model_instance,
    )
else:
    memory = None

# 2. Get history for context
if memory:
    if isinstance(memory, NodeTokenBufferMemory):
        history = memory.get_history_prompt_messages(
            current_message_id=current_message_id,
            tenant_id=tenant_id,
            max_token_limit=max_token_limit,
        )
    else:  # TokenBufferMemory
        history = memory.get_history_prompt_messages(
            max_token_limit=max_token_limit,
        )
    prompt_messages = [*history, *current_messages]
else:
    prompt_messages = current_messages

# 3. Call LLM
response = model_instance.invoke(prompt_messages)

# 4. Append to node memory (only for NodeTokenBufferMemory)
if isinstance(memory, NodeTokenBufferMemory):
    memory.add_messages(
        message_id=message_id,
        parent_message_id=parent_message_id,
        user_content=user_input,
        user_files=user_files,
        assistant_content=response.content,
        assistant_files=response_files,
    )
    memory.flush()
```
### Configuration

Add to `MemoryConfig` in `core/workflow/nodes/llm/entities.py`:

```python
class MemoryMode(StrEnum):
    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
    NODE = "node"  # Use NodeTokenBufferMemory (new, Chatflow only)


class MemoryConfig(BaseModel):
    # Existing fields
    role_prefix: RolePrefix | None = None
    window: MemoryWindowConfig | None = None
    query_prompt_template: str | None = None

    # Memory mode (new)
    mode: MemoryMode = MemoryMode.CONVERSATION
```
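
In a node's serialized configuration this surfaces as one extra field under `memory`; a hypothetical fragment for illustration (the `window` shape follows the existing `MemoryConfig`, and exact keys may differ):

```python
# Hypothetical node-data fragment enabling node-level memory.
node_data = {
    "memory": {
        "mode": "node",  # "conversation" (default) keeps the existing behavior
        "window": {"enabled": True, "size": 10},
        "query_prompt_template": "{{#sys.query#}}",
    },
}
```
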
**Mode Behavior:**

| Mode           | Memory Class          | Scope                    | Availability  |
| -------------- | --------------------- | ------------------------ | ------------- |
| `conversation` | TokenBufferMemory     | Entire conversation      | All app modes |
| `node`         | NodeTokenBufferMemory | Per-node in conversation | Chatflow only |

> When `mode=node` is used in a non-Chatflow context (no conversation_id), it should
> fall back to no memory or raise a configuration error.
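
A minimal sketch of that guard for the fallback-to-no-memory variant, assuming the helper can read the variable pool the way the node implementations below do:

```python
def fetch_node_memory(variable_pool, app_id, node_id, tenant_id, model_instance):
    """Return NodeTokenBufferMemory, or None when there is no conversation_id."""
    # "conversation_id" is only populated for Chatflow (advanced-chat) runs.
    conversation_id_var = variable_pool.get(["sys", "conversation_id"])
    if conversation_id_var is None:
        return None  # not Chatflow: fall back to no memory
    return NodeTokenBufferMemory(
        app_id=app_id,
        conversation_id=conversation_id_var.value,
        node_id=node_id,
        tenant_id=tenant_id,
        model_instance=model_instance,
    )
```
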
---

## Comparison

| Feature        | TokenBufferMemory        | NodeTokenBufferMemory     |
| -------------- | ------------------------ | ------------------------- |
| Scope          | Conversation             | Node within Conversation  |
| Storage        | Database (Message table) | Object Storage (JSON)     |
| Thread Support | Yes                      | Yes                       |
| File Support   | Yes (via MessageFile)    | Yes (via file references) |
| Token Limit    | Yes                      | Yes                       |
| Use Case       | Standard chat apps       | Complex workflows         |

---

## Future Considerations

1. **Cleanup Task**: Add a Celery task to clean up old node memory files
2. **Concurrency**: Consider a Redis lock for concurrent node executions (see the sketch below)
3. **Compression**: Compress large memory files to reduce storage costs
4. **Extension**: Other nodes (Agent, Tool) may also benefit from node-level memory
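
For the concurrency item, one possible shape is a per-key Redis lock around `flush()`, using the existing `redis_client` extension; the lock name and timeouts are illustrative:

```python
from extensions.ext_redis import redis_client


def flush_with_lock(memory: NodeTokenBufferMemory) -> None:
    """Serialize concurrent writers to the same node-memory object."""
    lock_key = f"node_memory_lock:{memory.app_id}:{memory.conversation_id}:{memory.node_id}"
    # timeout: auto-release guard; blocking_timeout: how long to wait for the lock
    with redis_client.lock(lock_key, timeout=10, blocking_timeout=5):
        memory.flush()
```
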
15
api/core/memory/__init__.py
Normal file
@@ -0,0 +1,15 @@
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import (
    NodeMemoryData,
    NodeMemoryFile,
    NodeTokenBufferMemory,
)
from core.memory.token_buffer_memory import TokenBufferMemory

__all__ = [
    "BaseMemory",
    "NodeMemoryData",
    "NodeMemoryFile",
    "NodeTokenBufferMemory",
    "TokenBufferMemory",
]
83
api/core/memory/base.py
Normal file
@@ -0,0 +1,83 @@
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from core.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from core.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
353
api/core/memory/node_token_buffer_memory.py
Normal file
@@ -0,0 +1,353 @@
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- Storage is indexed by workflow_run_id (each execution stores one turn)
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
- On read: query Message table for current thread, then filter Node Memory by workflow_run_ids
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
ImagePromptMessageContent,
|
||||
PromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeMemoryFile(BaseModel):
|
||||
"""File reference stored in node memory."""
|
||||
|
||||
type: str # image, audio, video, document, custom
|
||||
transfer_method: str # local_file, remote_url, tool_file
|
||||
upload_file_id: str | None = None
|
||||
tool_file_id: str | None = None
|
||||
url: str | None = None
|
||||
|
||||
|
||||
class NodeMemoryTurn(BaseModel):
|
||||
"""A single dialogue turn (user + assistant) in node memory."""
|
||||
|
||||
user_content: str = ""
|
||||
user_files: list[NodeMemoryFile] = []
|
||||
assistant_content: str = ""
|
||||
assistant_files: list[NodeMemoryFile] = []
|
||||
|
||||
|
||||
class NodeMemoryData(BaseModel):
|
||||
"""Root data structure for node memory storage."""
|
||||
|
||||
version: int = 1
|
||||
# Key: workflow_run_id, Value: dialogue turn
|
||||
turns: dict[str, NodeMemoryTurn] = {}
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history, stored in object storage.
|
||||
|
||||
Key design: Thread tracking is delegated to Message table's parent_message_id.
|
||||
Storage is indexed by workflow_run_id for easy filtering.
|
||||
|
||||
Storage key format: node_memory/{app_id}/{conversation_id}/{node_id}.json
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
"""
|
||||
Initialize node-level memory.
|
||||
|
||||
:param app_id: Application ID
|
||||
:param conversation_id: Conversation ID
|
||||
:param node_id: Node ID in the workflow
|
||||
:param tenant_id: Tenant ID for file reconstruction
|
||||
:param model_instance: Model instance for token counting
|
||||
"""
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
self._storage_key = f"node_memory/{app_id}/{conversation_id}/{node_id}.json"
|
||||
self._data: NodeMemoryData | None = None
|
||||
self._dirty = False
|
||||
|
||||
def _load(self) -> NodeMemoryData:
|
||||
"""Load data from object storage."""
|
||||
if self._data is not None:
|
||||
return self._data
|
||||
|
||||
try:
|
||||
raw = storage.load_once(self._storage_key)
|
||||
self._data = NodeMemoryData.model_validate_json(raw)
|
||||
except Exception:
|
||||
# File not found or parse error, start fresh
|
||||
self._data = NodeMemoryData()
|
||||
|
||||
return self._data
|
||||
|
||||
def _save(self) -> None:
|
||||
"""Save data to object storage."""
|
||||
if self._data is not None:
|
||||
storage.save(self._storage_key, self._data.model_dump_json().encode("utf-8"))
|
||||
self._dirty = False
|
||||
|
||||
def _file_to_memory_file(self, file: File) -> NodeMemoryFile:
|
||||
"""Convert File object to NodeMemoryFile reference."""
|
||||
return NodeMemoryFile(
|
||||
type=file.type.value if hasattr(file.type, "value") else str(file.type),
|
||||
transfer_method=(
|
||||
file.transfer_method.value if hasattr(file.transfer_method, "value") else str(file.transfer_method)
|
||||
),
|
||||
upload_file_id=file.related_id if file.transfer_method == FileTransferMethod.LOCAL_FILE else None,
|
||||
tool_file_id=file.related_id if file.transfer_method == FileTransferMethod.TOOL_FILE else None,
|
||||
url=file.remote_url if file.transfer_method == FileTransferMethod.REMOTE_URL else None,
|
||||
)
|
||||
|
||||
def _memory_file_to_mapping(self, memory_file: NodeMemoryFile) -> dict:
|
||||
"""Convert NodeMemoryFile to mapping for file_factory."""
|
||||
mapping: dict = {
|
||||
"type": memory_file.type,
|
||||
"transfer_method": memory_file.transfer_method,
|
||||
}
|
||||
if memory_file.upload_file_id:
|
||||
mapping["upload_file_id"] = memory_file.upload_file_id
|
||||
if memory_file.tool_file_id:
|
||||
mapping["tool_file_id"] = memory_file.tool_file_id
|
||||
if memory_file.url:
|
||||
mapping["url"] = memory_file.url
|
||||
return mapping
|
||||
|
||||
def _rebuild_files(self, memory_files: list[NodeMemoryFile]) -> list[File]:
|
||||
"""Rebuild File objects from NodeMemoryFile references."""
|
||||
if not memory_files:
|
||||
return []
|
||||
|
||||
from factories import file_factory
|
||||
|
||||
files = []
|
||||
for mf in memory_files:
|
||||
try:
|
||||
mapping = self._memory_file_to_mapping(mf)
|
||||
file = file_factory.build_from_mapping(mapping=mapping, tenant_id=self.tenant_id)
|
||||
files.append(file)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to rebuild file from memory: %s", e)
|
||||
continue
|
||||
return files
|
||||
|
||||
def _build_prompt_message(
|
||||
self,
|
||||
role: str,
|
||||
content: str,
|
||||
files: list[File],
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH,
|
||||
) -> PromptMessage:
|
||||
"""Build PromptMessage from content and files."""
|
||||
from core.file import file_manager
|
||||
|
||||
if not files:
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=content)
|
||||
else:
|
||||
return AssistantPromptMessage(content=content)
|
||||
|
||||
# Build multimodal content
|
||||
prompt_contents: list = []
|
||||
for file in files:
|
||||
try:
|
||||
prompt_content = file_manager.to_prompt_message_content(file, image_detail_config=detail)
|
||||
prompt_contents.append(prompt_content)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to convert file to prompt content: %s", e)
|
||||
continue
|
||||
|
||||
prompt_contents.append(TextPromptMessageContent(data=content))
|
||||
|
||||
if role == "user":
|
||||
return UserPromptMessage(content=prompt_contents)
|
||||
else:
|
||||
return AssistantPromptMessage(content=prompt_contents)
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
# Query messages for this conversation
|
||||
stmt = (
|
||||
select(Message).where(Message.conversation_id == self.conversation_id).order_by(Message.created_at.desc())
|
||||
)
|
||||
messages = db.session.scalars(stmt.limit(500)).all()
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
workflow_run_ids = []
|
||||
for msg in reversed(thread_messages):
|
||||
if msg.workflow_run_id:
|
||||
workflow_run_ids.append(msg.workflow_run_id)
|
||||
|
||||
return workflow_run_ids
|
||||
|
||||
def add_messages(
|
||||
self,
|
||||
workflow_run_id: str,
|
||||
user_content: str,
|
||||
user_files: Sequence[File] | None = None,
|
||||
assistant_content: str = "",
|
||||
assistant_files: Sequence[File] | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Add a dialogue turn to node memory.
|
||||
Call this after LLM node execution completes.
|
||||
|
||||
:param workflow_run_id: Current workflow execution ID
|
||||
:param user_content: User's text input
|
||||
:param user_files: Files attached by user
|
||||
:param assistant_content: Assistant's text response
|
||||
:param assistant_files: Files generated by assistant
|
||||
"""
|
||||
data = self._load()
|
||||
|
||||
# Convert files to memory file references
|
||||
user_memory_files = [self._file_to_memory_file(f) for f in (user_files or [])]
|
||||
assistant_memory_files = [self._file_to_memory_file(f) for f in (assistant_files or [])]
|
||||
|
||||
# Store the turn indexed by workflow_run_id
|
||||
data.turns[workflow_run_id] = NodeMemoryTurn(
|
||||
user_content=user_content,
|
||||
user_files=user_memory_files,
|
||||
assistant_content=assistant_content,
|
||||
assistant_files=assistant_memory_files,
|
||||
)
|
||||
|
||||
self._dirty = True
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
*,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
|
||||
Thread tracking is handled by querying Message table's parent_message_id structure.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: unused, for interface compatibility
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
# message_limit is unused in NodeTokenBufferMemory (uses token limit instead)
|
||||
_ = message_limit
|
||||
detail = ImagePromptMessageContent.DETAIL.HIGH
|
||||
data = self._load()
|
||||
|
||||
if not data.turns:
|
||||
return []
|
||||
|
||||
# Get workflow_run_ids for current thread from Message table
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Build prompt messages in thread order
|
||||
prompt_messages: list[PromptMessage] = []
|
||||
for wf_run_id in thread_workflow_run_ids:
|
||||
turn = data.turns.get(wf_run_id)
|
||||
if not turn:
|
||||
# This workflow execution didn't have node memory stored
|
||||
continue
|
||||
|
||||
# Build user message
|
||||
user_files = self._rebuild_files(turn.user_files) if turn.user_files else []
|
||||
user_msg = self._build_prompt_message(
|
||||
role="user",
|
||||
content=turn.user_content,
|
||||
files=user_files,
|
||||
detail=detail,
|
||||
)
|
||||
prompt_messages.append(user_msg)
|
||||
|
||||
# Build assistant message
|
||||
assistant_files = self._rebuild_files(turn.assistant_files) if turn.assistant_files else []
|
||||
assistant_msg = self._build_prompt_message(
|
||||
role="assistant",
|
||||
content=turn.assistant_content,
|
||||
files=assistant_files,
|
||||
detail=detail,
|
||||
)
|
||||
prompt_messages.append(assistant_msg)
|
||||
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def flush(self) -> None:
|
||||
"""
|
||||
Persist buffered changes to object storage.
|
||||
Call this at the end of node execution.
|
||||
"""
|
||||
if self._dirty:
|
||||
self._save()
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Clear all messages in this node's memory."""
|
||||
self._data = NodeMemoryData()
|
||||
self._save()
|
||||
|
||||
def exists(self) -> bool:
|
||||
"""Check if node memory exists in storage."""
|
||||
return storage.exists(self._storage_key)
|
||||
@@ -5,12 +5,12 @@ from sqlalchemy.orm import sessionmaker

from core.app.app_config.features.file_upload.manager import FileUploadConfigManager
from core.file import file_manager
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities import (
    AssistantPromptMessage,
    ImagePromptMessageContent,
    PromptMessage,
    PromptMessageRole,
    TextPromptMessageContent,
    UserPromptMessage,
)

@@ -24,7 +24,7 @@ from repositories.api_workflow_run_repository import APIWorkflowRunRepository

from repositories.factory import DifyAPIRepositoryFactory


class TokenBufferMemory:
class TokenBufferMemory(BaseMemory):
    def __init__(
        self,
        conversation: Conversation,

@@ -115,10 +115,14 @@ class TokenBufferMemory:

        return AssistantPromptMessage(content=prompt_message_contents)

    def get_history_prompt_messages(
        self, max_token_limit: int = 2000, message_limit: int | None = None
        self,
        *,
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> Sequence[PromptMessage]:
        """
        Get history prompt messages.

        :param max_token_limit: max token limit
        :param message_limit: message limit
        """

@@ -200,44 +204,3 @@ class TokenBufferMemory:

        curr_message_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)

        return prompt_messages

    def get_history_prompt_text(
        self,
        human_prefix: str = "Human",
        ai_prefix: str = "Assistant",
        max_token_limit: int = 2000,
        message_limit: int | None = None,
    ) -> str:
        """
        Get history prompt text.
        :param human_prefix: human prefix
        :param ai_prefix: ai prefix
        :param max_token_limit: max token limit
        :param message_limit: message limit
        :return:
        """
        prompt_messages = self.get_history_prompt_messages(max_token_limit=max_token_limit, message_limit=message_limit)

        string_messages = []
        for m in prompt_messages:
            if m.role == PromptMessageRole.USER:
                role = human_prefix
            elif m.role == PromptMessageRole.ASSISTANT:
                role = ai_prefix
            else:
                continue

            if isinstance(m.content, list):
                inner_msg = ""
                for content in m.content:
                    if isinstance(content, TextPromptMessageContent):
                        inner_msg += f"{content.data}\n"
                    elif isinstance(content, ImagePromptMessageContent):
                        inner_msg += "[image]\n"

                string_messages.append(f"{role}: {inner_msg.strip()}")
            else:
                message = f"{role}: {m.content}"
                string_messages.append(message)

        return "\n".join(string_messages)
@@ -5,7 +5,7 @@ from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEnti

from core.file import file_manager
from core.file.models import File
from core.helper.code_executor.jinja2.jinja2_formatter import Jinja2Formatter
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_runtime.entities import (
    AssistantPromptMessage,
    PromptMessage,

@@ -43,7 +43,7 @@ class AdvancedPromptTransform(PromptTransform):

        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:

@@ -84,7 +84,7 @@ class AdvancedPromptTransform(PromptTransform):

        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:

@@ -145,7 +145,7 @@ class AdvancedPromptTransform(PromptTransform):

        files: Sequence[File],
        context: str | None,
        memory_config: MemoryConfig | None,
        memory: TokenBufferMemory | None,
        memory: BaseMemory | None,
        model_config: ModelConfigWithCredentialsEntity,
        image_detail_config: ImagePromptMessageContent.DETAIL | None = None,
    ) -> list[PromptMessage]:

@@ -270,7 +270,7 @@ class AdvancedPromptTransform(PromptTransform):

    def _set_histories_variable(
        self,
        memory: TokenBufferMemory,
        memory: BaseMemory,
        memory_config: MemoryConfig,
        raw_prompt: str,
        role_prefix: MemoryConfig.RolePrefix,
@@ -1,3 +1,4 @@

from enum import StrEnum
from typing import Literal

from pydantic import BaseModel

@@ -5,6 +6,13 @@ from pydantic import BaseModel

from core.model_runtime.entities.message_entities import PromptMessageRole


class MemoryMode(StrEnum):
    """Memory mode for LLM nodes."""

    CONVERSATION = "conversation"  # Use TokenBufferMemory (default, existing behavior)
    NODE = "node"  # Use NodeTokenBufferMemory (Chatflow only)


class ChatModelMessage(BaseModel):
    """
    Chat Message.

@@ -48,3 +56,4 @@ class MemoryConfig(BaseModel):

    role_prefix: RolePrefix | None = None
    window: WindowConfig
    query_prompt_template: str | None = None
    mode: MemoryMode = MemoryMode.CONVERSATION
@@ -1,7 +1,7 @@

from typing import Any

from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
from core.memory.token_buffer_memory import TokenBufferMemory
from core.memory.base import BaseMemory
from core.model_manager import ModelInstance
from core.model_runtime.entities.message_entities import PromptMessage
from core.model_runtime.entities.model_entities import ModelPropertyKey

@@ -11,7 +11,7 @@ from core.prompt.entities.advanced_prompt_entities import MemoryConfig

class PromptTransform:
    def _append_chat_histories(
        self,
        memory: TokenBufferMemory,
        memory: BaseMemory,
        memory_config: MemoryConfig,
        prompt_messages: list[PromptMessage],
        model_config: ModelConfigWithCredentialsEntity,

@@ -52,7 +52,7 @@ class PromptTransform:

    def _get_history_messages_from_memory(
        self,
        memory: TokenBufferMemory,
        memory: BaseMemory,
        memory_config: MemoryConfig,
        max_token_limit: int,
        human_prefix: str | None = None,

@@ -73,7 +73,7 @@ class PromptTransform:

        return memory.get_history_prompt_text(**kwargs)

    def _get_history_messages_list_from_memory(
        self, memory: TokenBufferMemory, memory_config: MemoryConfig, max_token_limit: int
        self, memory: BaseMemory, memory_config: MemoryConfig, max_token_limit: int
    ) -> list[PromptMessage]:
        """Get memory messages."""
        return list(
@@ -1047,6 +1047,8 @@ class ToolManager:

                continue
            tool_input = ToolNodeData.ToolInput.model_validate(tool_configurations.get(parameter.name, {}))
            if tool_input.type == "variable":
                if not isinstance(tool_input.value, list):
                    raise ToolParameterError(f"Invalid variable selector for {parameter.name}")
                variable = variable_pool.get(tool_input.value)
                if variable is None:
                    raise ToolParameterError(f"Variable {tool_input.value} does not exist")

@@ -1056,6 +1058,11 @@ class ToolManager:

            elif tool_input.type == "mixed":
                segment_group = variable_pool.convert_template(str(tool_input.value))
                parameter_value = segment_group.text
            elif tool_input.type == "mention":
                # Mention type not supported in agent mode
                raise ToolParameterError(
                    f"Mention type not supported in agent for parameter '{parameter.name}'"
                )
            else:
                raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
            runtime_parameters[parameter.name] = parameter_value
@@ -4,6 +4,7 @@ from .segments import (

    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArrayPromptMessageSegment,
    ArraySegment,
    ArrayStringSegment,
    FileSegment,

@@ -20,6 +21,7 @@ from .variables import (

    ArrayFileVariable,
    ArrayNumberVariable,
    ArrayObjectVariable,
    ArrayPromptMessageVariable,
    ArrayStringVariable,
    ArrayVariable,
    FileVariable,

@@ -42,6 +44,8 @@ __all__ = [

    "ArrayNumberVariable",
    "ArrayObjectSegment",
    "ArrayObjectVariable",
    "ArrayPromptMessageSegment",
    "ArrayPromptMessageVariable",
    "ArraySegment",
    "ArrayStringSegment",
    "ArrayStringVariable",
@@ -6,6 +6,7 @@ from typing import Annotated, Any, TypeAlias

from pydantic import BaseModel, ConfigDict, Discriminator, Tag, field_validator

from core.file import File
from core.model_runtime.entities import PromptMessage

from .types import SegmentType

@@ -208,6 +209,15 @@ class ArrayBooleanSegment(ArraySegment):

    value: Sequence[bool]


class ArrayPromptMessageSegment(ArraySegment):
    value_type: SegmentType = SegmentType.ARRAY_PROMPT_MESSAGE
    value: Sequence[PromptMessage]

    def to_object(self):
        """Convert to JSON-serializable format for database storage and frontend."""
        return [msg.model_dump() for msg in self.value]


def get_segment_discriminator(v: Any) -> SegmentType | None:
    if isinstance(v, Segment):
        return v.value_type

@@ -248,6 +258,7 @@ SegmentUnion: TypeAlias = Annotated[

    | Annotated[ArrayObjectSegment, Tag(SegmentType.ARRAY_OBJECT)]
    | Annotated[ArrayFileSegment, Tag(SegmentType.ARRAY_FILE)]
    | Annotated[ArrayBooleanSegment, Tag(SegmentType.ARRAY_BOOLEAN)]
    | Annotated[ArrayPromptMessageSegment, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
    ),
    Discriminator(get_segment_discriminator),
]
@@ -45,6 +45,7 @@ class SegmentType(StrEnum):

    ARRAY_OBJECT = "array[object]"
    ARRAY_FILE = "array[file]"
    ARRAY_BOOLEAN = "array[boolean]"
    ARRAY_PROMPT_MESSAGE = "array[message]"

    NONE = "none"
@@ -3,8 +3,10 @@ from typing import Any

import orjson

from core.model_runtime.entities import PromptMessage

from .segment_group import SegmentGroup
from .segments import ArrayFileSegment, FileSegment, Segment
from .segments import ArrayFileSegment, ArrayPromptMessageSegment, FileSegment, Segment


def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[str]:

@@ -16,7 +18,7 @@ def to_selector(node_id: str, name: str, paths: Iterable[str] = ()) -> Sequence[

def segment_orjson_default(o: Any):
    """Default function for orjson serialization of Segment types"""
    if isinstance(o, ArrayFileSegment):
    if isinstance(o, (ArrayFileSegment, ArrayPromptMessageSegment)):
        return [v.model_dump() for v in o.value]
    elif isinstance(o, FileSegment):
        return o.value.model_dump()

@@ -24,6 +26,8 @@ def segment_orjson_default(o: Any):

        return [segment_orjson_default(seg) for seg in o.value]
    elif isinstance(o, Segment):
        return o.value
    elif isinstance(o, PromptMessage):
        return o.model_dump()
    raise TypeError(f"Object of type {type(o).__name__} is not JSON serializable")
@@ -12,6 +12,7 @@ from .segments import (

    ArrayFileSegment,
    ArrayNumberSegment,
    ArrayObjectSegment,
    ArrayPromptMessageSegment,
    ArraySegment,
    ArrayStringSegment,
    BooleanSegment,

@@ -110,6 +111,10 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):

    pass


class ArrayPromptMessageVariable(ArrayPromptMessageSegment, ArrayVariable):
    pass


class RAGPipelineVariable(BaseModel):
    belong_to_node_id: str = Field(description="belong to which node id, shared means public")
    type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")

@@ -160,6 +165,7 @@ Variable: TypeAlias = Annotated[

    | Annotated[ArrayObjectVariable, Tag(SegmentType.ARRAY_OBJECT)]
    | Annotated[ArrayFileVariable, Tag(SegmentType.ARRAY_FILE)]
    | Annotated[ArrayBooleanVariable, Tag(SegmentType.ARRAY_BOOLEAN)]
    | Annotated[ArrayPromptMessageVariable, Tag(SegmentType.ARRAY_PROMPT_MESSAGE)]
    | Annotated[SecretVariable, Tag(SegmentType.SECRET)]
    ),
    Discriminator(get_segment_discriminator),
1418
api/core/workflow/docs/variable_extraction_design.md
Normal file
File diff suppressed because it is too large
@@ -63,6 +63,7 @@ class NodeType(StrEnum):

    TRIGGER_SCHEDULE = "trigger-schedule"
    TRIGGER_PLUGIN = "trigger-plugin"
    HUMAN_INPUT = "human-input"
    GROUP = "group"

    @property
    def is_trigger_node(self) -> bool:

@@ -252,6 +253,7 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):

    LOOP_VARIABLE_MAP = "loop_variable_map"  # single loop variable output
    DATASOURCE_INFO = "datasource_info"
    COMPLETED_REASON = "completed_reason"  # completed reason for loop node
    MENTION_PARENT_ID = "mention_parent_id"  # parent node id for extractor nodes


class WorkflowNodeExecutionStatus(StrEnum):
@@ -307,7 +307,14 @@ class Graph:

        if not node_configs:
            raise ValueError("Graph must have at least one node")

        node_configs = [node_config for node_config in node_configs if node_config.get("type", "") != "custom-note"]
        # Filter out UI-only node types:
        # - custom-note: top-level type (node_config.type == "custom-note")
        # - group: data-level type (node_config.data.type == "group")
        node_configs = [
            node_config
            for node_config in node_configs
            if node_config.get("type", "") != "custom-note" and node_config.get("data", {}).get("type", "") != "group"
        ]

        # Parse node configurations
        node_configs_map = cls._parse_node_configs(node_configs)
@@ -93,8 +93,8 @@ class EventHandler:

        Args:
            event: The event to handle
        """
        # Events in loops or iterations are always collected
        if event.in_loop_id or event.in_iteration_id:
        # Events in loops, iterations, or extractor groups are always collected
        if event.in_loop_id or event.in_iteration_id or event.in_mention_parent_id:
            self._event_collector.collect(event)
            return
        return self._dispatch(event)

@@ -125,6 +125,11 @@ class EventHandler:

        Args:
            event: The node started event
        """
        # Check if this is an extractor node (has parent_node_id)
        if self._is_extractor_node(event.node_id):
            self._handle_extractor_node_started(event)
            return

        # Track execution in domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        is_initial_attempt = node_execution.retry_count == 0

@@ -164,6 +169,11 @@ class EventHandler:

        Args:
            event: The node succeeded event
        """
        # Check if this is an extractor node (has parent_node_id)
        if self._is_extractor_node(event.node_id):
            self._handle_extractor_node_success(event)
            return

        # Update domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_taken()

@@ -226,6 +236,11 @@ class EventHandler:

        Args:
            event: The node failed event
        """
        # Check if this is an extractor node (has parent_node_id)
        if self._is_extractor_node(event.node_id):
            self._handle_extractor_node_failed(event)
            return

        # Update domain model
        node_execution = self._graph_execution.get_or_create_node_execution(event.node_id)
        node_execution.mark_failed(event.error)

@@ -345,3 +360,57 @@ class EventHandler:

            self._graph_runtime_state.set_output("answer", value)
        else:
            self._graph_runtime_state.set_output(key, value)

    def _is_extractor_node(self, node_id: str) -> bool:
        """
        Check if node_id represents an extractor node (has parent_node_id).

        Extractor nodes extract values from list[PromptMessage] for their parent node.
        They have a parent_node_id field pointing to their parent node.
        """
        node = self._graph.nodes.get(node_id)
        if node is None:
            return False
        return node.node_data.is_extractor_node

    def _handle_extractor_node_started(self, event: NodeRunStartedEvent) -> None:
        """
        Handle extractor node started event.

        Extractor nodes don't need full execution tracking, just collect the event.
        """
        # Track in response coordinator for stream ordering
        self._response_coordinator.track_node_execution(event.node_id, event.id)

        # Collect the event
        self._event_collector.collect(event)

    def _handle_extractor_node_success(self, event: NodeRunSucceededEvent) -> None:
        """
        Handle extractor node success event.

        Extractor nodes need special handling:
        - Store outputs in variable pool (for reference by other nodes)
        - Accumulate token usage
        - Collect the event for logging
        - Do NOT process edges or enqueue next nodes (parent node handles that)
        """
        self._accumulate_node_usage(event.node_run_result.llm_usage)

        # Store outputs in variable pool
        self._store_node_outputs(event.node_id, event.node_run_result.outputs)

        # Collect the event
        self._event_collector.collect(event)

    def _handle_extractor_node_failed(self, event: NodeRunFailedEvent) -> None:
        """
        Handle extractor node failed event.

        Extractor node failures are collected for logging,
        but the parent node is responsible for handling the error.
        """
        self._accumulate_node_usage(event.node_run_result.llm_usage)

        # Collect the event for logging
        self._event_collector.collect(event)
@@ -68,6 +68,7 @@ class _NodeRuntimeSnapshot:

    predecessor_node_id: str | None
    iteration_id: str | None
    loop_id: str | None
    mention_parent_id: str | None
    created_at: datetime


@@ -230,6 +231,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):

        metadata = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: event.in_iteration_id,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: event.in_loop_id,
            WorkflowNodeExecutionMetadataKey.MENTION_PARENT_ID: event.in_mention_parent_id,
        }

        domain_execution = WorkflowNodeExecution(

@@ -256,6 +258,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):

            predecessor_node_id=event.predecessor_node_id,
            iteration_id=event.in_iteration_id,
            loop_id=event.in_loop_id,
            mention_parent_id=event.in_mention_parent_id,
            created_at=event.start_at,
        )
        self._node_snapshots[event.id] = snapshot
@@ -21,6 +21,12 @@ class GraphNodeEventBase(GraphEngineEvent):

    """iteration id if node is in iteration"""
    in_loop_id: str | None = None
    """loop id if node is in loop"""
    in_mention_parent_id: str | None = None
    """Parent node id if this is an extractor node event.

    When set, indicates this event belongs to an extractor node that
    is extracting values for the specified parent node.
    """

    # The version of the node, or "1" if not specified.
    node_version: str = "1"
@@ -12,11 +12,14 @@ from sqlalchemy.orm import Session

from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.file import File, FileTransferMethod
from core.memory.base import BaseMemory
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
from core.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
from core.model_runtime.entities.model_entities import AIModelEntity, ModelType
from core.model_runtime.utils.encoders import jsonable_encoder
from core.prompt.entities.advanced_prompt_entities import MemoryMode
from core.provider_manager import ProviderManager
from core.tools.entities.tool_entities import (
    ToolIdentity,

@@ -136,6 +139,9 @@ class AgentNode(Node[AgentNodeData]):

            )
            return

        # Fetch memory for node memory saving
        memory = self._fetch_memory_for_save()

        try:
            yield from self._transform_message(
                messages=message_stream,

@@ -149,6 +155,7 @@ class AgentNode(Node[AgentNodeData]):

                node_type=self.node_type,
                node_id=self._node_id,
                node_execution_id=self.id,
                memory=memory,
            )
        except PluginDaemonClientSideError as e:
            transform_error = AgentMessageTransformError(

@@ -395,8 +402,20 @@ class AgentNode(Node[AgentNodeData]):

            icon = None
        return icon

    def _fetch_memory(self, model_instance: ModelInstance) -> TokenBufferMemory | None:
        # get conversation id
    def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | None:
        """
        Fetch memory based on configuration mode.

        Returns TokenBufferMemory for conversation mode (default),
        or NodeTokenBufferMemory for node mode (Chatflow only).
        """
        node_data = self.node_data
        memory_config = node_data.memory

        if not memory_config:
            return None

        # get conversation id (required for both modes in Chatflow)
        conversation_id_variable = self.graph_runtime_state.variable_pool.get(
            ["sys", SystemVariableKey.CONVERSATION_ID]
        )

@@ -404,16 +423,26 @@ class AgentNode(Node[AgentNodeData]):

            return None
        conversation_id = conversation_id_variable.value

        with Session(db.engine, expire_on_commit=False) as session:
            stmt = select(Conversation).where(Conversation.app_id == self.app_id, Conversation.id == conversation_id)
            conversation = session.scalar(stmt)

            if not conversation:
                return None

            memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)

            return memory
        # Return appropriate memory type based on mode
        if memory_config.mode == MemoryMode.NODE:
            # Node-level memory (Chatflow only)
            return NodeTokenBufferMemory(
                app_id=self.app_id,
                conversation_id=conversation_id,
                node_id=self._node_id,
                tenant_id=self.tenant_id,
                model_instance=model_instance,
            )
        else:
            # Conversation-level memory (default)
            with Session(db.engine, expire_on_commit=False) as session:
                stmt = select(Conversation).where(
                    Conversation.app_id == self.app_id, Conversation.id == conversation_id
                )
                conversation = session.scalar(stmt)
                if not conversation:
                    return None
                return TokenBufferMemory(conversation=conversation, model_instance=model_instance)

    def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
        provider_manager = ProviderManager()

@@ -457,6 +486,47 @@ class AgentNode(Node[AgentNodeData]):

        else:
            return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]

    def _fetch_memory_for_save(self) -> BaseMemory | None:
        """
        Fetch memory instance for saving node memory.
        This is a simplified version that doesn't require model_instance.
        """
        from core.model_manager import ModelManager
        from core.model_runtime.entities.model_entities import ModelType

        node_data = self.node_data
        if not node_data.memory:
            return None

        # Get conversation_id
        conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
        if not isinstance(conversation_id_var, StringSegment):
            return None
        conversation_id = conversation_id_var.value

        # Return appropriate memory type based on mode
        if node_data.memory.mode == MemoryMode.NODE:
            # For node memory, we need a model_instance for token counting
            # Use a simple default model for this purpose
            try:
                model_instance = ModelManager().get_default_model_instance(
                    tenant_id=self.tenant_id,
                    model_type=ModelType.LLM,
                )
            except Exception:
                return None

            return NodeTokenBufferMemory(
                app_id=self.app_id,
                conversation_id=conversation_id,
                node_id=self._node_id,
                tenant_id=self.tenant_id,
                model_instance=model_instance,
            )
        else:
            # Conversation-level memory doesn't need saving here
            return None

    def _transform_message(
        self,
        messages: Generator[ToolInvokeMessage, None, None],

@@ -467,6 +537,7 @@ class AgentNode(Node[AgentNodeData]):

        node_type: NodeType,
        node_id: str,
        node_execution_id: str,
        memory: BaseMemory | None = None,
    ) -> Generator[NodeEventBase, None, None]:
        """
        Convert ToolInvokeMessages into tuple[plain_text, files]

@@ -711,6 +782,21 @@ class AgentNode(Node[AgentNodeData]):

            is_final=True,
        )

        # Save to node memory if in node memory mode
        from core.workflow.nodes.llm import llm_utils

        # Get user query from sys.query
        user_query_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.QUERY])
        user_query = user_query_var.text if user_query_var else ""

        llm_utils.save_node_memory(
            memory=memory,
            variable_pool=self.graph_runtime_state.variable_pool,
            user_query=user_query,
            assistant_response=text,
            assistant_files=files,
        )

        yield StreamCompletedEvent(
            node_run_result=NodeRunResult(
                status=WorkflowNodeExecutionStatus.SUCCEEDED,
@@ -1,4 +1,10 @@

from .entities import BaseIterationNodeData, BaseIterationState, BaseLoopNodeData, BaseLoopState, BaseNodeData
from .entities import (
    BaseIterationNodeData,
    BaseIterationState,
    BaseLoopNodeData,
    BaseLoopState,
    BaseNodeData,
)
from .usage_tracking_mixin import LLMUsageTrackingMixin

__all__ = [
@@ -175,6 +175,16 @@ class BaseNodeData(ABC, BaseModel):

    default_value: list[DefaultValue] | None = None
    retry_config: RetryConfig = RetryConfig()

    # Parent node ID when this node is used as an extractor.
    # If set, this node is an "attached" extractor node that extracts values
    # from list[PromptMessage] for the parent node's parameters.
    parent_node_id: str | None = None

    @property
    def is_extractor_node(self) -> bool:
        """Check if this node is an extractor node (has parent_node_id)."""
        return self.parent_node_id is not None

    @property
    def default_value_dict(self) -> dict[str, Any]:
        if self.default_value:
@@ -270,10 +270,87 @@ class Node(Generic[NodeDataT]):
         """Check if execution should be stopped."""
         return self.graph_runtime_state.stop_event.is_set()

+    def _find_extractor_node_configs(self) -> list[dict[str, Any]]:
+        """
+        Find all extractor node configurations that have parent_node_id == self._node_id.
+
+        Returns:
+            List of node configuration dicts for extractor nodes
+        """
+        nodes = self.graph_config.get("nodes", [])
+        extractor_configs = []
+        for node_config in nodes:
+            node_data = node_config.get("data", {})
+            if node_data.get("parent_node_id") == self._node_id:
+                extractor_configs.append(node_config)
+        return extractor_configs
+
+    def _execute_extractor_nodes(self) -> Generator[GraphNodeEventBase, None, None]:
+        """
+        Execute all extractor nodes associated with this node.
+
+        Extractor nodes are nodes with parent_node_id == self._node_id.
+        They are executed before the main node to extract values from list[PromptMessage].
+        """
+        from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
+
+        extractor_configs = self._find_extractor_node_configs()
+        logger.debug("[Extractor] Found %d extractor nodes for parent '%s'", len(extractor_configs), self._node_id)
+        if not extractor_configs:
+            return
+
+        for config in extractor_configs:
+            node_id = config.get("id")
+            node_data = config.get("data", {})
+            node_type_str = node_data.get("type")
+
+            if not node_id or not node_type_str:
+                continue
+
+            # Get node class
+            try:
+                node_type = NodeType(node_type_str)
+            except ValueError:
+                continue
+
+            node_mapping = NODE_TYPE_CLASSES_MAPPING.get(node_type)
+            if not node_mapping:
+                continue
+
+            node_version = str(node_data.get("version", "1"))
+            node_cls = node_mapping.get(node_version) or node_mapping.get(LATEST_VERSION)
+            if not node_cls:
+                continue
+
+            # Instantiate and execute the extractor node
+            extractor_node = node_cls(
+                id=node_id,
+                config=config,
+                graph_init_params=self._graph_init_params,
+                graph_runtime_state=self.graph_runtime_state,
+            )
+
+            # Execute and process extractor node events
+            for event in extractor_node.run():
+                # Tag event with parent node id for stream ordering and history tracking
+                if isinstance(event, GraphNodeEventBase):
+                    event.in_mention_parent_id = self._node_id
+
+                if isinstance(event, NodeRunSucceededEvent):
+                    # Store extractor node outputs in variable pool
+                    outputs: Mapping[str, Any] = event.node_run_result.outputs
+                    for variable_name, variable_value in outputs.items():
+                        self.graph_runtime_state.variable_pool.add((node_id, variable_name), variable_value)
+                if not isinstance(event, NodeRunStreamChunkEvent):
+                    yield event
+
     def run(self) -> Generator[GraphNodeEventBase, None, None]:
         execution_id = self.ensure_execution_id()
         self._start_at = naive_utc_now()

+        # Step 1: Execute associated extractor nodes before main node execution
+        yield from self._execute_extractor_nodes()
+
         # Create and push start event with required fields
         start_event = NodeRunStartedEvent(
             id=execution_id,
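To make the parent/child wiring concrete, here is a standalone sketch of the lookup that `_find_extractor_node_configs` performs; the graph dict below is illustrative, not taken from this changeset:

    graph_config = {
        "nodes": [
            {"id": "tool_1", "data": {"type": "tool"}},
            {"id": "tool_1_ext_query", "data": {"type": "llm", "parent_node_id": "tool_1"}},
            {"id": "answer", "data": {"type": "answer"}},
        ]
    }

    def find_extractor_node_configs(graph: dict, parent_id: str) -> list[dict]:
        # Same predicate as the method above: match on data.parent_node_id
        return [n for n in graph.get("nodes", []) if n.get("data", {}).get("parent_node_id") == parent_id]

    assert [n["id"] for n in find_extractor_node_configs(graph_config, "tool_1")] == ["tool_1_ext_query"]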
@@ -1,7 +1,7 @@
 from collections.abc import Mapping, Sequence
-from typing import Any, Literal
+from typing import Annotated, Any, Literal, TypeAlias

-from pydantic import BaseModel, Field, field_validator
+from pydantic import BaseModel, ConfigDict, Field, field_validator

 from core.model_runtime.entities import ImagePromptMessageContent, LLMMode
 from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate, MemoryConfig
@@ -58,9 +58,28 @@ class LLMNodeCompletionModelPromptTemplate(CompletionModelPromptTemplate):
     jinja2_text: str | None = None


+class PromptMessageContext(BaseModel):
+    """Context variable reference in prompt template.
+
+    YAML/JSON format: { "$context": ["node_id", "variable_name"] }
+    This will be expanded to list[PromptMessage] at runtime.
+    """
+
+    model_config = ConfigDict(populate_by_name=True)
+
+    value_selector: Sequence[str] = Field(alias="$context")
+
+
+# Union type for prompt template items (static message or context variable reference)
+PromptTemplateItem: TypeAlias = Annotated[
+    LLMNodeChatModelMessage | PromptMessageContext,
+    Field(discriminator=None),
+]
+
+
 class LLMNodeData(BaseNodeData):
     model: ModelConfig
-    prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
+    prompt_template: Sequence[PromptTemplateItem] | LLMNodeCompletionModelPromptTemplate
     prompt_config: PromptConfig = Field(default_factory=PromptConfig)
     memory: MemoryConfig | None = None
     context: ContextConfig
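A minimal sketch of how the `$context` alias round-trips through pydantic; the model is re-declared standalone for illustration, and only the alias and field name come from the hunk above:

    from collections.abc import Sequence
    from pydantic import BaseModel, ConfigDict, Field

    class PromptMessageContext(BaseModel):
        model_config = ConfigDict(populate_by_name=True)
        value_selector: Sequence[str] = Field(alias="$context")

    # DSL form, as in the pav-test-extraction fixture below: {"$context": ["llm", "context"]}
    item = PromptMessageContext.model_validate({"$context": ["llm", "context"]})
    assert list(item.value_selector) == ["llm", "context"]
    # populate_by_name also allows constructing by field name in code
    item2 = PromptMessageContext(value_selector=["llm", "context"])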
@@ -8,12 +8,13 @@ from configs import dify_config
 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.entities.provider_entities import ProviderQuotaType, QuotaUnit
 from core.file.models import File
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory import NodeTokenBufferMemory, TokenBufferMemory
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities.llm_entities import LLMUsage
 from core.model_runtime.entities.model_entities import ModelType
 from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
-from core.prompt.entities.advanced_prompt_entities import MemoryConfig
+from core.prompt.entities.advanced_prompt_entities import MemoryConfig, MemoryMode
 from core.variables.segments import ArrayAnySegment, ArrayFileSegment, FileSegment, NoneSegment, StringSegment
 from core.workflow.enums import SystemVariableKey
 from core.workflow.nodes.llm.entities import ModelConfig
@@ -86,25 +87,100 @@ def fetch_files(variable_pool: VariablePool, selector: Sequence[str]) -> Sequenc


 def fetch_memory(
-    variable_pool: VariablePool, app_id: str, node_data_memory: MemoryConfig | None, model_instance: ModelInstance
-) -> TokenBufferMemory | None:
+    variable_pool: VariablePool,
+    app_id: str,
+    tenant_id: str,
+    node_data_memory: MemoryConfig | None,
+    model_instance: ModelInstance,
+    node_id: str = "",
+) -> BaseMemory | None:
+    """
+    Fetch memory based on configuration mode.
+
+    Returns TokenBufferMemory for conversation mode (default),
+    or NodeTokenBufferMemory for node mode (Chatflow only).
+
+    :param variable_pool: Variable pool containing system variables
+    :param app_id: Application ID
+    :param tenant_id: Tenant ID
+    :param node_data_memory: Memory configuration
+    :param model_instance: Model instance for token counting
+    :param node_id: Node ID in the workflow (required for node mode)
+    :return: Memory instance or None if not applicable
+    """
     if not node_data_memory:
         return None

-    # get conversation id
+    # Get conversation_id from variable pool (required for both modes in Chatflow)
     conversation_id_variable = variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
     if not isinstance(conversation_id_variable, StringSegment):
         return None
     conversation_id = conversation_id_variable.value

-    with Session(db.engine, expire_on_commit=False) as session:
-        stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
-        conversation = session.scalar(stmt)
-        if not conversation:
-            return None
-
-    memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
-    return memory
+    # Return appropriate memory type based on mode
+    if node_data_memory.mode == MemoryMode.NODE:
+        # Node-level memory (Chatflow only)
+        if not node_id:
+            return None
+        return NodeTokenBufferMemory(
+            app_id=app_id,
+            conversation_id=conversation_id,
+            node_id=node_id,
+            tenant_id=tenant_id,
+            model_instance=model_instance,
+        )
+    else:
+        # Conversation-level memory (default)
+        with Session(db.engine, expire_on_commit=False) as session:
+            stmt = select(Conversation).where(Conversation.app_id == app_id, Conversation.id == conversation_id)
+            conversation = session.scalar(stmt)
+            if not conversation:
+                return None
+            return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
+
+
+def save_node_memory(
+    memory: BaseMemory | None,
+    variable_pool: VariablePool,
+    user_query: str,
+    assistant_response: str,
+    user_files: Sequence["File"] | None = None,
+    assistant_files: Sequence["File"] | None = None,
+) -> None:
+    """
+    Save dialogue turn to node memory if applicable.
+
+    This function handles the storage logic for NodeTokenBufferMemory.
+    For TokenBufferMemory (conversation-level), no action is taken as it uses
+    the Message table which is managed elsewhere.
+
+    :param memory: Memory instance (NodeTokenBufferMemory or TokenBufferMemory)
+    :param variable_pool: Variable pool containing system variables
+    :param user_query: User's input text
+    :param assistant_response: Assistant's response text
+    :param user_files: Files attached by user (optional)
+    :param assistant_files: Files generated by assistant (optional)
+    """
+    if not isinstance(memory, NodeTokenBufferMemory):
+        return
+
+    # Get workflow_run_id as the key for this execution
+    workflow_run_id_var = variable_pool.get(["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID])
+    if not isinstance(workflow_run_id_var, StringSegment):
+        return
+
+    workflow_run_id = workflow_run_id_var.value
+    if not workflow_run_id:
+        return
+
+    memory.add_messages(
+        workflow_run_id=workflow_run_id,
+        user_content=user_query,
+        user_files=list(user_files) if user_files else None,
+        assistant_content=assistant_response,
+        assistant_files=list(assistant_files) if assistant_files else None,
+    )
+    memory.flush()


 def deduct_llm_quota(tenant_id: str, model_instance: ModelInstance, usage: LLMUsage):
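A toy stand-in illustrating the storage contract implied by `save_node_memory` (turns keyed by workflow_run_id, flushed after each turn); `TinyNodeMemory` is invented for illustration and is not the real NodeTokenBufferMemory:

    class TinyNodeMemory:
        """Invented stand-in for NodeTokenBufferMemory's add_messages/flush contract."""

        def __init__(self) -> None:
            self._pending: list[dict] = []
            self.stored: list[dict] = []

        def add_messages(self, *, workflow_run_id: str, user_content: str,
                         assistant_content: str, user_files=None, assistant_files=None) -> None:
            self._pending.append({
                "workflow_run_id": workflow_run_id,
                "user": user_content,
                "assistant": assistant_content,
            })

        def flush(self) -> None:
            # Persist pending turns (the real class writes them to storage)
            self.stored.extend(self._pending)
            self._pending.clear()

    memory = TinyNodeMemory()
    memory.add_messages(workflow_run_id="run-1", user_content="search dify", assistant_content="OK, searching")
    memory.flush()
    assert memory.stored[0]["workflow_run_id"] == "run-1"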
@@ -7,7 +7,7 @@ import logging
 import re
 import time
 from collections.abc import Generator, Mapping, Sequence
-from typing import TYPE_CHECKING, Any, Literal
+from typing import TYPE_CHECKING, Any, Literal, cast

 from sqlalchemy import select

@@ -16,10 +16,11 @@ from core.file import File, FileTransferMethod, FileType, file_manager
 from core.helper.code_executor import CodeExecutor, CodeLanguage
 from core.llm_generator.output_parser.errors import OutputParserError
 from core.llm_generator.output_parser.structured_output import invoke_llm_with_structured_output
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance, ModelManager
 from core.model_runtime.entities import (
     ImagePromptMessageContent,
+    MultiModalPromptMessageContent,
     PromptMessage,
     PromptMessageContentType,
     TextPromptMessageContent,
@@ -51,6 +52,7 @@ from core.rag.entities.citation_metadata import RetrievalSourceMetadata
 from core.tools.signature import sign_upload_file
 from core.variables import (
     ArrayFileSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     FileSegment,
     NoneSegment,
@@ -87,6 +89,7 @@ from .entities import (
     LLMNodeCompletionModelPromptTemplate,
     LLMNodeData,
     ModelConfig,
+    PromptMessageContext,
 )
 from .exc import (
     InvalidContextStructureError,
@@ -159,8 +162,9 @@ class LLMNode(Node[LLMNodeData]):
         variable_pool = self.graph_runtime_state.variable_pool

         try:
-            # init messages template
-            self.node_data.prompt_template = self._transform_chat_messages(self.node_data.prompt_template)
+            # Parse prompt template to separate static messages and context references
+            prompt_template = self.node_data.prompt_template
+            static_messages, context_refs, template_order = self._parse_prompt_template()

             # fetch variables and fetch values from variable pool
             inputs = self._fetch_inputs(node_data=self.node_data)
@@ -208,8 +212,10 @@ class LLMNode(Node[LLMNodeData]):
             memory = llm_utils.fetch_memory(
                 variable_pool=variable_pool,
                 app_id=self.app_id,
+                tenant_id=self.tenant_id,
                 node_data_memory=self.node_data.memory,
                 model_instance=model_instance,
+                node_id=self._node_id,
             )

             query: str | None = None
@@ -220,21 +226,40 @@ class LLMNode(Node[LLMNodeData]):
             ):
                 query = query_variable.text

-            prompt_messages, stop = LLMNode.fetch_prompt_messages(
-                sys_query=query,
-                sys_files=files,
-                context=context,
-                memory=memory,
-                model_config=model_config,
-                prompt_template=self.node_data.prompt_template,
-                memory_config=self.node_data.memory,
-                vision_enabled=self.node_data.vision.enabled,
-                vision_detail=self.node_data.vision.configs.detail,
-                variable_pool=variable_pool,
-                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
-                tenant_id=self.tenant_id,
-                context_files=context_files,
-            )
+            # Get prompt messages
+            prompt_messages: Sequence[PromptMessage]
+            stop: Sequence[str] | None
+            if isinstance(prompt_template, list) and context_refs:
+                prompt_messages, stop = self._build_prompt_messages_with_context(
+                    context_refs=context_refs,
+                    template_order=template_order,
+                    static_messages=static_messages,
+                    query=query,
+                    files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    context_files=context_files,
+                )
+            else:
+                prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                    sys_query=query,
+                    sys_files=files,
+                    context=context,
+                    memory=memory,
+                    model_config=model_config,
+                    prompt_template=cast(
+                        Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
+                        self.node_data.prompt_template,
+                    ),
+                    memory_config=self.node_data.memory,
+                    vision_enabled=self.node_data.vision.enabled,
+                    vision_detail=self.node_data.vision.configs.detail,
+                    variable_pool=variable_pool,
+                    jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                    tenant_id=self.tenant_id,
+                    context_files=context_files,
+                )

             # handle invoke result
             generator = LLMNode.invoke_llm(
@@ -250,6 +275,7 @@ class LLMNode(Node[LLMNodeData]):
                 node_id=self._node_id,
                 node_type=self.node_type,
                 reasoning_format=self.node_data.reasoning_format,
+                tenant_id=self.tenant_id,
             )

             structured_output: LLMStructuredOutput | None = None
@@ -301,12 +327,25 @@ class LLMNode(Node[LLMNodeData]):
                 "reasoning_content": reasoning_content,
                 "usage": jsonable_encoder(usage),
                 "finish_reason": finish_reason,
+                "context": self._build_context(prompt_messages, clean_text),
             }
             if structured_output:
                 outputs["structured_output"] = structured_output.structured_output
             if self._file_outputs:
                 outputs["files"] = ArrayFileSegment(value=self._file_outputs)

+            # Write to Node Memory if in node memory mode
+            # Resolve the query template to get actual user content
+            actual_query = variable_pool.convert_template(query or "").text
+            llm_utils.save_node_memory(
+                memory=memory,
+                variable_pool=variable_pool,
+                user_query=actual_query,
+                assistant_response=clean_text,
+                user_files=files,
+                assistant_files=self._file_outputs,
+            )
+
             # Send final chunk event to indicate streaming is complete
             yield StreamChunkEvent(
                 selector=[self._node_id, "text"],
@@ -367,6 +406,7 @@ class LLMNode(Node[LLMNodeData]):
         node_id: str,
         node_type: NodeType,
         reasoning_format: Literal["separated", "tagged"] = "tagged",
+        tenant_id: str | None = None,
     ) -> Generator[NodeEventBase | LLMStructuredOutput, None, None]:
         model_schema = model_instance.model_type_instance.get_model_schema(
             node_data_model.name, model_instance.credentials
@@ -390,6 +430,7 @@ class LLMNode(Node[LLMNodeData]):
                 stop=list(stop or []),
                 stream=True,
                 user=user_id,
+                tenant_id=tenant_id,
             )
         else:
             request_start_time = time.perf_counter()
@@ -566,6 +607,48 @@ class LLMNode(Node[LLMNodeData]):
         # Separated mode: always return clean text and reasoning_content
         return clean_text, reasoning_content or ""

+    @staticmethod
+    def _build_context(
+        prompt_messages: Sequence[PromptMessage],
+        assistant_response: str,
+    ) -> list[PromptMessage]:
+        """
+        Build context from prompt messages and assistant response.
+        Excludes system messages and includes the current LLM response.
+        Returns list[PromptMessage] for use with ArrayPromptMessageSegment.
+
+        Note: Multi-modal content base64 data is truncated to avoid storing large data in context.
+        """
+        context_messages: list[PromptMessage] = [
+            LLMNode._truncate_multimodal_content(m) for m in prompt_messages if m.role != PromptMessageRole.SYSTEM
+        ]
+        context_messages.append(AssistantPromptMessage(content=assistant_response))
+        return context_messages
+
+    @staticmethod
+    def _truncate_multimodal_content(message: PromptMessage) -> PromptMessage:
+        """
+        Truncate multi-modal content base64 data in a message to avoid storing large data.
+        Preserves the PromptMessage structure for ArrayPromptMessageSegment compatibility.
+        """
+        content = message.content
+        if content is None or isinstance(content, str):
+            return message
+
+        # Process list content, truncating multi-modal base64 data
+        new_content: list[PromptMessageContentUnionTypes] = []
+        for item in content:
+            if isinstance(item, MultiModalPromptMessageContent):
+                # Truncate base64_data similar to prompt_messages_to_prompt_for_saving
+                truncated_base64 = ""
+                if item.base64_data:
+                    truncated_base64 = item.base64_data[:10] + "...[TRUNCATED]..." + item.base64_data[-10:]
+                new_content.append(item.model_copy(update={"base64_data": truncated_base64}))
+            else:
+                new_content.append(item)
+
+        return message.model_copy(update={"content": new_content})
+
     def _transform_chat_messages(
         self, messages: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate, /
     ) -> Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate:
@@ -581,6 +664,106 @@ class LLMNode(Node[LLMNodeData]):

         return messages

+    def _parse_prompt_template(
+        self,
+    ) -> tuple[list[LLMNodeChatModelMessage], list[PromptMessageContext], list[tuple[int, str]]]:
+        """
+        Parse prompt_template to separate static messages and context references.
+
+        Returns:
+            Tuple of (static_messages, context_refs, template_order)
+            - static_messages: list of LLMNodeChatModelMessage
+            - context_refs: list of PromptMessageContext
+            - template_order: list of (index, type) tuples preserving original order
+        """
+        prompt_template = self.node_data.prompt_template
+        static_messages: list[LLMNodeChatModelMessage] = []
+        context_refs: list[PromptMessageContext] = []
+        template_order: list[tuple[int, str]] = []
+
+        if isinstance(prompt_template, list):
+            for idx, item in enumerate(prompt_template):
+                if isinstance(item, PromptMessageContext):
+                    context_refs.append(item)
+                    template_order.append((idx, "context"))
+                else:
+                    static_messages.append(item)
+                    template_order.append((idx, "static"))
+            # Transform static messages for jinja2
+            if static_messages:
+                self.node_data.prompt_template = self._transform_chat_messages(static_messages)
+
+        return static_messages, context_refs, template_order
+
+    def _build_prompt_messages_with_context(
+        self,
+        *,
+        context_refs: list[PromptMessageContext],
+        template_order: list[tuple[int, str]],
+        static_messages: list[LLMNodeChatModelMessage],
+        query: str | None,
+        files: Sequence[File],
+        context: str | None,
+        memory: BaseMemory | None,
+        model_config: ModelConfigWithCredentialsEntity,
+        context_files: list[File],
+    ) -> tuple[list[PromptMessage], Sequence[str] | None]:
+        """
+        Build prompt messages by combining static messages and context references in DSL order.
+
+        Returns:
+            Tuple of (prompt_messages, stop_sequences)
+        """
+        variable_pool = self.graph_runtime_state.variable_pool
+
+        # Build a map from context index to its messages
+        context_messages_map: dict[int, list[PromptMessage]] = {}
+        context_idx = 0
+        for idx, type_ in template_order:
+            if type_ == "context":
+                ctx_ref = context_refs[context_idx]
+                ctx_var = variable_pool.get(ctx_ref.value_selector)
+                if ctx_var is None:
+                    raise VariableNotFoundError(f"Variable {'.'.join(ctx_ref.value_selector)} not found")
+                if not isinstance(ctx_var, ArrayPromptMessageSegment):
+                    raise InvalidVariableTypeError(f"Variable {'.'.join(ctx_ref.value_selector)} is not array[message]")
+                context_messages_map[idx] = list(ctx_var.value)
+                context_idx += 1
+
+        # Process static messages
+        static_prompt_messages: Sequence[PromptMessage] = []
+        stop: Sequence[str] | None = None
+        if static_messages:
+            static_prompt_messages, stop = LLMNode.fetch_prompt_messages(
+                sys_query=query,
+                sys_files=files,
+                context=context,
+                memory=memory,
+                model_config=model_config,
+                prompt_template=cast(Sequence[LLMNodeChatModelMessage], self.node_data.prompt_template),
+                memory_config=self.node_data.memory,
+                vision_enabled=self.node_data.vision.enabled,
+                vision_detail=self.node_data.vision.configs.detail,
+                variable_pool=variable_pool,
+                jinja2_variables=self.node_data.prompt_config.jinja2_variables,
+                tenant_id=self.tenant_id,
+                context_files=context_files,
+            )

+        # Combine messages according to original DSL order
+        combined_messages: list[PromptMessage] = []
+        static_msg_iter = iter(static_prompt_messages)
+        for idx, type_ in template_order:
+            if type_ == "context":
+                combined_messages.extend(context_messages_map[idx])
+            else:
+                if msg := next(static_msg_iter, None):
+                    combined_messages.append(msg)
+        # Append any remaining static messages (e.g., memory messages)
+        combined_messages.extend(static_msg_iter)
+
+        return combined_messages, stop
+
     def _fetch_jinja_inputs(self, node_data: LLMNodeData) -> dict[str, str]:
         variables: dict[str, Any] = {}
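The merge step at the end of `_build_prompt_messages_with_context` can be hard to visualize; below is a standalone toy version of the same order-preserving interleave, with strings standing in for PromptMessage objects (all names invented):

    template_order = [(0, "context"), (1, "static"), (2, "static")]
    context_messages_map = {0: ["ctx_msg_a", "ctx_msg_b"]}
    static_prompt_messages = ["static_user", "static_assistant", "memory_turn"]

    combined: list[str] = []
    static_iter = iter(static_prompt_messages)
    for idx, kind in template_order:
        if kind == "context":
            combined.extend(context_messages_map[idx])
        elif (msg := next(static_iter, None)) is not None:
            combined.append(msg)
    combined.extend(static_iter)  # trailing items (e.g., memory messages) stay at the end

    assert combined == ["ctx_msg_a", "ctx_msg_b", "static_user", "static_assistant", "memory_turn"]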
@@ -778,7 +961,7 @@ class LLMNode(Node[LLMNodeData]):
         sys_query: str | None = None,
         sys_files: Sequence[File],
         context: str | None = None,
-        memory: TokenBufferMemory | None = None,
+        memory: BaseMemory | None = None,
         model_config: ModelConfigWithCredentialsEntity,
         prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate,
         memory_config: MemoryConfig | None = None,
@@ -1337,7 +1520,7 @@ def _calculate_rest_token(

 def _handle_memory_chat_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: BaseMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> Sequence[PromptMessage]:
@@ -1354,7 +1537,7 @@ def _handle_memory_chat_mode(

 def _handle_memory_completion_mode(
     *,
-    memory: TokenBufferMemory | None,
+    memory: BaseMemory | None,
     memory_config: MemoryConfig | None,
     model_config: ModelConfigWithCredentialsEntity,
 ) -> str:
@@ -7,7 +7,7 @@ from typing import Any, cast

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
 from core.file import File
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import ImagePromptMessageContent
 from core.model_runtime.entities.llm_entities import LLMUsage
@@ -145,8 +145,10 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         memory = llm_utils.fetch_memory(
             variable_pool=variable_pool,
             app_id=self.app_id,
+            tenant_id=self.tenant_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
+            node_id=self._node_id,
         )

         if (
@@ -244,6 +246,14 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         # transform result into standard format
         result = self._transform_result(data=node_data, result=result or {})

+        # Save to node memory if in node memory mode
+        llm_utils.save_node_memory(
+            memory=memory,
+            variable_pool=variable_pool,
+            user_query=query,
+            assistant_response=json.dumps(result, ensure_ascii=False),
+        )
+
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=inputs,
@@ -299,7 +309,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> tuple[list[PromptMessage], list[PromptMessageTool]]:
@@ -381,7 +391,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -419,7 +429,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -453,7 +463,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         query: str,
         variable_pool: VariablePool,
         model_config: ModelConfigWithCredentialsEntity,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         files: Sequence[File],
         vision_detail: ImagePromptMessageContent.DETAIL | None = None,
     ) -> list[PromptMessage]:
@@ -681,7 +691,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         max_token_limit: int = 2000,
     ) -> list[ChatModelMessage]:
         model_mode = ModelMode(node_data.model.mode)
@@ -708,7 +718,7 @@ class ParameterExtractorNode(Node[ParameterExtractorNodeData]):
         node_data: ParameterExtractorNodeData,
         query: str,
         variable_pool: VariablePool,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
@@ -4,7 +4,7 @@ from collections.abc import Mapping, Sequence
 from typing import TYPE_CHECKING, Any

 from core.app.entities.app_invoke_entities import ModelConfigWithCredentialsEntity
-from core.memory.token_buffer_memory import TokenBufferMemory
+from core.memory.base import BaseMemory
 from core.model_manager import ModelInstance
 from core.model_runtime.entities import LLMUsage, ModelPropertyKey, PromptMessageRole
 from core.model_runtime.utils.encoders import jsonable_encoder
@@ -96,8 +96,10 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         memory = llm_utils.fetch_memory(
             variable_pool=variable_pool,
             app_id=self.app_id,
+            tenant_id=self.tenant_id,
             node_data_memory=node_data.memory,
             model_instance=model_instance,
+            node_id=self._node_id,
         )
         # fetch instruction
         node_data.instruction = node_data.instruction or ""
@@ -203,6 +205,14 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
             "usage": jsonable_encoder(usage),
         }

+        # Save to node memory if in node memory mode
+        llm_utils.save_node_memory(
+            memory=memory,
+            variable_pool=variable_pool,
+            user_query=query or "",
+            assistant_response=f"class_name: {category_name}, class_id: {category_id}",
+        )
+
         return NodeRunResult(
             status=WorkflowNodeExecutionStatus.SUCCEEDED,
             inputs=variables,
@@ -312,7 +322,7 @@ class QuestionClassifierNode(Node[QuestionClassifierNodeData]):
         self,
         node_data: QuestionClassifierNodeData,
         query: str,
-        memory: TokenBufferMemory | None,
+        memory: BaseMemory | None,
         max_token_limit: int = 2000,
     ):
         model_mode = ModelMode(node_data.model.mode)
@@ -1,11 +1,63 @@
-from typing import Any, Literal, Union
+import re
+from collections.abc import Sequence
+from typing import Any, Literal, Self, Union

-from pydantic import BaseModel, field_validator
+from pydantic import BaseModel, field_validator, model_validator
 from pydantic_core.core_schema import ValidationInfo

 from core.tools.entities.tool_entities import ToolProviderType
 from core.workflow.nodes.base.entities import BaseNodeData

+# Pattern to match mention value format: {{@node.context@}}instruction
+# The placeholder {{@node.context@}} must appear at the beginning
+# Format: {{@agent_node_id.context@}} where agent_node_id is dynamic, context is fixed
+MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)
+
+
+def parse_mention_value(value: str) -> tuple[str, str]:
+    """Parse mention value into (node_id, instruction).
+
+    Args:
+        value: The mention value string like "{{@llm.context@}}extract keywords"
+
+    Returns:
+        Tuple of (node_id, instruction)
+
+    Raises:
+        ValueError: If value format is invalid
+    """
+    match = MENTION_VALUE_PATTERN.match(value)
+    if not match:
+        raise ValueError(
+            "For mention type, value must start with {{@node.context@}} placeholder, "
+            "e.g., '{{@llm.context@}}extract keywords'"
+        )
+    return match.group(1), match.group(2)
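A quick, standalone check of the regex's behavior (the sample strings are invented for illustration):

    import re

    MENTION_VALUE_PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)

    m = MENTION_VALUE_PATTERN.match("{{@llm.context@}}extract keywords")
    assert m and m.groups() == ("llm", "extract keywords")
    # The placeholder must be the prefix; anything else fails to parse
    assert MENTION_VALUE_PATTERN.match("extract keywords {{@llm.context@}}") is None
    # re.DOTALL lets the instruction span multiple lines
    assert MENTION_VALUE_PATTERN.match("{{@llm.context@}}line one\nline two")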
+class MentionConfig(BaseModel):
+    """Configuration for extracting value from context variable.
+
+    Used when a tool parameter needs to be extracted from list[PromptMessage]
+    context using an extractor LLM node.
+
+    Note: instruction is embedded in the value field as "{{@node.context@}}instruction"
+    """
+
+    # ID of the extractor LLM node
+    extractor_node_id: str
+
+    # Output variable selector from extractor node
+    # e.g., ["text"], ["structured_output", "query"]
+    output_selector: Sequence[str]
+
+    # Strategy when output is None
+    null_strategy: Literal["raise_error", "use_default"] = "raise_error"
+
+    # Default value when null_strategy is "use_default"
+    # Type should match the parameter's expected type
+    default_value: Any = None
+
+
 class ToolEntity(BaseModel):
     provider_id: str
@@ -35,7 +87,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):
     class ToolInput(BaseModel):
         # TODO: check this type
         value: Union[Any, list[str]]
-        type: Literal["mixed", "variable", "constant"]
+        type: Literal["mixed", "variable", "constant", "mention"]
+        # Required config for mention type, extracting value from context variable
+        mention_config: MentionConfig | None = None

         @field_validator("type", mode="before")
         @classmethod
@@ -48,6 +102,9 @@ class ToolNodeData(BaseNodeData, ToolEntity):

             if typ == "mixed" and not isinstance(value, str):
                 raise ValueError("value must be a string")
+            elif typ == "mention":
+                # Skip here, will be validated in model_validator
+                pass
             elif typ == "variable":
                 if not isinstance(value, list):
                     raise ValueError("value must be a list")
@@ -58,6 +115,26 @@ class ToolNodeData(BaseNodeData, ToolEntity):
                 raise ValueError("value must be a string, int, float, bool or dict")
             return typ

+        @model_validator(mode="after")
+        def check_mention_type(self) -> Self:
+            """Validate mention type with mention_config."""
+            if self.type != "mention":
+                return self
+
+            value = self.value
+            if value is None:
+                return self
+
+            if not isinstance(value, str):
+                raise ValueError("value must be a string for mention type")
+            # For mention type, value must match format: {{@node.context@}}instruction
+            # This will raise ValueError if format is invalid
+            parse_mention_value(value)
+            # mention_config is required for mention type
+            if self.mention_config is None:
+                raise ValueError("mention_config is required for mention type")
+            return self
+
     tool_parameters: dict[str, ToolInput]
     # The version of the tool parameter.
     # If this value is None, it indicates this is a previous version
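A standalone sketch of the same two-stage validation (regex check first, then the required config); `ToyToolInput` is a simplified stand-in, not the real nested ToolInput model:

    import re
    from typing import Any, Literal
    from pydantic import BaseModel, ValidationError, model_validator

    _PATTERN = re.compile(r"^\{\{@([a-zA-Z0-9_]+)\.context@\}\}(.*)$", re.DOTALL)

    class ToyToolInput(BaseModel):
        value: Any = None
        type: Literal["mixed", "variable", "constant", "mention"]
        mention_config: dict | None = None

        @model_validator(mode="after")
        def check_mention(self):
            if self.type != "mention" or self.value is None:
                return self
            if not isinstance(self.value, str) or not _PATTERN.match(self.value):
                raise ValueError("mention value must start with {{@node.context@}}")
            if self.mention_config is None:
                raise ValueError("mention_config is required for mention type")
            return self

    ToyToolInput(type="mention", value="{{@llm.context@}}extract keywords",
                 mention_config={"extractor_node_id": "llm_ext"})
    try:
        ToyToolInput(type="mention", value="no placeholder here")
    except ValidationError:
        pass  # bad format is rejected at model construction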
@@ -1,7 +1,10 @@
+import logging
 from collections.abc import Generator, Mapping, Sequence
 from typing import TYPE_CHECKING, Any

 from sqlalchemy import select
+
+logger = logging.getLogger(__name__)
 from sqlalchemy.orm import Session

 from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
@@ -184,6 +187,7 @@ class ToolNode(Node[ToolNodeData]):
             tool_parameters (Sequence[ToolParameter]): The list of tool parameters.
             variable_pool (VariablePool): The variable pool containing the variables.
             node_data (ToolNodeData): The data associated with the tool node.
+            for_log (bool): Whether to generate parameters for logging.

         Returns:
             Mapping[str, Any]: A dictionary containing the generated parameters.
@@ -199,14 +203,37 @@ class ToolNode(Node[ToolNodeData]):
                 continue
             tool_input = node_data.tool_parameters[parameter_name]
             if tool_input.type == "variable":
-                variable = variable_pool.get(tool_input.value)
+                if not isinstance(tool_input.value, list):
+                    raise ToolParameterError(f"Invalid variable selector for parameter '{parameter_name}'")
+                selector = tool_input.value
+                variable = variable_pool.get(selector)
                 if variable is None:
                     if parameter.required:
-                        raise ToolParameterError(f"Variable {tool_input.value} does not exist")
+                        raise ToolParameterError(f"Variable {selector} does not exist")
                     continue
                 parameter_value = variable.value
+            elif tool_input.type == "mention":
+                # Mention type: get value from extractor node's output
+                if tool_input.mention_config is None:
+                    raise ToolParameterError(
+                        f"mention_config is required for mention type parameter '{parameter_name}'"
+                    )
+                mention_config = tool_input.mention_config.model_dump()
+                try:
+                    parameter_value, found = variable_pool.resolve_mention(
+                        mention_config, parameter_name=parameter_name
+                    )
+                    if not found and parameter.required:
+                        raise ToolParameterError(
+                            f"Extractor output not found for required parameter '{parameter_name}'"
+                        )
+                    if not found:
+                        continue
+                except ValueError as e:
+                    raise ToolParameterError(str(e)) from e
             elif tool_input.type in {"mixed", "constant"}:
-                segment_group = variable_pool.convert_template(str(tool_input.value))
+                template = str(tool_input.value)
+                segment_group = variable_pool.convert_template(template)
                 parameter_value = segment_group.log if for_log else segment_group.text
             else:
                 raise ToolParameterError(f"Unknown tool input type '{tool_input.type}'")
@@ -488,8 +515,12 @@ class ToolNode(Node[ToolNodeData]):
             for selector in selectors:
                 result[selector.variable] = selector.value_selector
         elif input.type == "variable":
-            selector_key = ".".join(input.value)
-            result[f"#{selector_key}#"] = input.value
+            if isinstance(input.value, list):
+                selector_key = ".".join(input.value)
+                result[f"#{selector_key}#"] = input.value
+        elif input.type == "mention":
+            # Mention type: value is handled by extractor node, no direct variable reference
+            pass
         elif input.type == "constant":
             pass
@@ -268,6 +268,58 @@ class VariablePool(BaseModel):
                 continue
             self.add(selector, value)

+    def resolve_mention(
+        self,
+        mention_config: Mapping[str, Any],
+        /,
+        *,
+        parameter_name: str = "",
+    ) -> tuple[Any, bool]:
+        """
+        Resolve a mention parameter value from an extractor node's output.
+
+        Mention parameters reference values extracted by an extractor LLM node
+        from list[PromptMessage] context.
+
+        Args:
+            mention_config: A dict containing:
+                - extractor_node_id: ID of the extractor LLM node
+                - output_selector: Selector path for the output variable (e.g., ["text"])
+                - null_strategy: "raise_error" or "use_default"
+                - default_value: Value to use when null_strategy is "use_default"
+            parameter_name: Name of the parameter being resolved (for error messages)
+
+        Returns:
+            Tuple of (resolved_value, found):
+                - resolved_value: The extracted value, or default_value if not found
+                - found: True if value was found, False if using default
+
+        Raises:
+            ValueError: If extractor_node_id is missing, or if null_strategy is
+                "raise_error" and the value is not found
+        """
+        extractor_node_id = mention_config.get("extractor_node_id")
+        if not extractor_node_id:
+            raise ValueError(f"Missing extractor_node_id for mention parameter '{parameter_name}'")
+
+        output_selector = list(mention_config.get("output_selector", []))
+        null_strategy = mention_config.get("null_strategy", "raise_error")
+        default_value = mention_config.get("default_value")
+
+        # Build full selector: [extractor_node_id, ...output_selector]
+        full_selector = [extractor_node_id] + output_selector
+        variable = self.get(full_selector)
+
+        if variable is None:
+            if null_strategy == "use_default":
+                return default_value, False
+            raise ValueError(
+                f"Extractor node '{extractor_node_id}' output '{'.'.join(output_selector)}' "
+                f"not found for parameter '{parameter_name}'"
+            )
+
+        return variable.value, True
+
     @classmethod
     def empty(cls) -> VariablePool:
         """Create an empty variable pool."""
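A toy pool illustrating the two `null_strategy` paths of `resolve_mention`; the `ToyPool` below is invented for illustration, while the real method lives on `VariablePool`:

    from typing import Any

    class ToyPool:
        def __init__(self, values: dict[tuple[str, ...], Any]) -> None:
            self._values = values

        def resolve_mention(self, cfg: dict, *, parameter_name: str = "") -> tuple[Any, bool]:
            selector = (cfg["extractor_node_id"], *cfg.get("output_selector", []))
            if selector in self._values:
                return self._values[selector], True
            if cfg.get("null_strategy", "raise_error") == "use_default":
                return cfg.get("default_value"), False
            raise ValueError(f"output not found for parameter '{parameter_name}'")

    pool = ToyPool({("tool_1_ext_query", "structured_output", "query"): "dify workflow"})
    cfg = {"extractor_node_id": "tool_1_ext_query",
           "output_selector": ["structured_output", "query"],
           "null_strategy": "use_default", "default_value": ""}
    assert pool.resolve_mention(cfg) == ("dify workflow", True)
    cfg["output_selector"] = ["structured_output", "missing"]
    assert pool.resolve_mention(cfg) == ("", False)  # falls back to the default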
@@ -4,6 +4,7 @@ from uuid import uuid4

 from configs import dify_config
 from core.file import File
+from core.model_runtime.entities import PromptMessage
 from core.variables.exc import VariableError
 from core.variables.segments import (
     ArrayAnySegment,
@@ -11,6 +12,7 @@ from core.variables.segments import (
     ArrayFileSegment,
     ArrayNumberSegment,
     ArrayObjectSegment,
+    ArrayPromptMessageSegment,
     ArraySegment,
     ArrayStringSegment,
     BooleanSegment,
@@ -29,6 +31,7 @@ from core.variables.variables import (
     ArrayFileVariable,
     ArrayNumberVariable,
     ArrayObjectVariable,
+    ArrayPromptMessageVariable,
     ArrayStringVariable,
     BooleanVariable,
     FileVariable,
@@ -61,6 +64,7 @@ SEGMENT_TO_VARIABLE_MAP = {
     ArrayFileSegment: ArrayFileVariable,
     ArrayNumberSegment: ArrayNumberVariable,
     ArrayObjectSegment: ArrayObjectVariable,
+    ArrayPromptMessageSegment: ArrayPromptMessageVariable,
     ArrayStringSegment: ArrayStringVariable,
     BooleanSegment: BooleanVariable,
     FileSegment: FileVariable,
@@ -156,7 +160,13 @@ def build_segment(value: Any, /) -> Segment:
         return ObjectSegment(value=value)
     if isinstance(value, File):
         return FileSegment(value=value)
+    if isinstance(value, PromptMessage):
+        # Single PromptMessage should be wrapped in a list
+        return ArrayPromptMessageSegment(value=[value])
     if isinstance(value, list):
+        # Check if all items are PromptMessage
+        if value and all(isinstance(item, PromptMessage) for item in value):
+            return ArrayPromptMessageSegment(value=value)
         items = [build_segment(item) for item in value]
         types = {item.value_type for item in items}
         if all(isinstance(item, ArraySegment) for item in items):
@@ -200,6 +210,7 @@ _segment_factory: Mapping[SegmentType, type[Segment]] = {
     SegmentType.ARRAY_OBJECT: ArrayObjectSegment,
     SegmentType.ARRAY_FILE: ArrayFileSegment,
     SegmentType.ARRAY_BOOLEAN: ArrayBooleanSegment,
+    SegmentType.ARRAY_PROMPT_MESSAGE: ArrayPromptMessageSegment,
 }
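A condensed, standalone illustration of the new dispatch order in `build_segment` (types simplified to plain classes; the real function returns Segment instances):

    class PromptMessage:  # stand-in for core.model_runtime.entities.PromptMessage
        def __init__(self, content: str) -> None:
            self.content = content

    def build_segment(value):
        # Mirrors the new branches: a lone PromptMessage is normalized to a
        # one-element array segment, and a homogeneous non-empty list of
        # PromptMessage gets the array segment too.
        if isinstance(value, PromptMessage):
            return ("array[prompt_message]", [value])
        if isinstance(value, list) and value and all(isinstance(v, PromptMessage) for v in value):
            return ("array[prompt_message]", value)
        return ("other", value)

    assert build_segment(PromptMessage("hi"))[0] == "array[prompt_message]"
    assert build_segment([PromptMessage("a"), PromptMessage("b")])[0] == "array[prompt_message]"
    assert build_segment(["plain", "strings"])[0] == "other"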
@@ -1289,7 +1289,7 @@ class WorkflowDraftVariable(Base):
     # which may differ from the original value's type. Typically, they are the same,
     # but in cases where the structurally truncated value still exceeds the size limit,
     # text slicing is applied, and the `value_type` is converted to `STRING`.
-    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=20))
+    value_type: Mapped[SegmentType] = mapped_column(EnumText(SegmentType, length=21))

     # The variable's value serialized as a JSON string
     #
@@ -1663,7 +1663,7 @@ class WorkflowDraftVariableFile(Base):

     # The `value_type` field records the type of the original value.
     value_type: Mapped[SegmentType] = mapped_column(
-        EnumText(SegmentType, length=20),
+        EnumText(SegmentType, length=21),
         nullable=False,
     )
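The bump from length=20 to length=21 is presumably needed so the new segment type's stored name fits in the column; assuming the enum serializes as the string "array[prompt_message]" (an inference from the naming of the other array types, not confirmed by this changeset), the arithmetic is:

    # Assumption: SegmentType.ARRAY_PROMPT_MESSAGE serializes as "array[prompt_message]"
    assert len("array[prompt_message]") == 21  # one more than the previous length=20 column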
@@ -7,6 +7,7 @@ from typing import Any, Generic, TypeAlias, TypeVar, overload

 from configs import dify_config
 from core.file.models import File
+from core.model_runtime.entities import PromptMessage
 from core.variables.segments import (
     ArrayFileSegment,
     ArraySegment,
@@ -287,6 +288,10 @@ class VariableTruncator(BaseTruncator):
             if isinstance(item, File):
                 truncated_value.append(item)
                 continue
+            # Handle PromptMessage types - convert to dict for truncation
+            if isinstance(item, PromptMessage):
+                truncated_value.append(item)
+                continue
             if i >= target_length:
                 return _PartResult(truncated_value, used_size, True)
             if i > 0:
api/tests/fixtures/file output schema.yml (vendored, new file: 181 lines)
@@ -0,0 +1,181 @@
app:
  description: ''
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: advanced-chat
  name: file output schema
  use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
  version: null
kind: app
version: 0.5.0
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - remote_url
      - local_file
      enabled: true
      fileUploadConfig:
        attachment_image_file_size_limit: 2
        audio_file_size_limit: 50
        batch_count_limit: 5
        file_size_limit: 15
        file_upload_limit: 10
        image_file_batch_limit: 10
        image_file_size_limit: 10
        single_chunk_attachment_limit: 10
        video_file_size_limit: 100
        workflow_file_upload_limit: 10
      number_limits: 3
    opening_statement: ''
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        sourceType: start
        targetType: llm
      id: 1768292241666-llm
      source: '1768292241666'
      sourceHandle: source
      target: llm
      targetHandle: target
      type: custom
    - data:
        sourceType: llm
        targetType: answer
      id: llm-answer
      source: llm
      sourceHandle: source
      target: answer
      targetHandle: target
      type: custom
    nodes:
    - data:
        selected: false
        title: User Input
        type: start
        variables: []
      height: 73
      id: '1768292241666'
      position:
        x: 80
        y: 282
      positionAbsolute:
        x: 80
        y: 282
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        memory:
          query_prompt_template: '{{#sys.query#}}


            {{#sys.files#}}'
          role_prefix:
            assistant: ''
            user: ''
          window:
            enabled: false
            size: 10
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: gpt-4o-mini
          provider: langgenius/openai/openai
        prompt_template:
        - id: e30d75d7-7d85-49ec-be3c-3baf7f6d3c5a
          role: system
          text: ''
        selected: false
        structured_output:
          schema:
            additionalProperties: false
            properties:
              image:
                description: File ID (UUID) of the selected image
                format: dify-file-ref
                type: string
            required:
            - image
            type: object
        structured_output_enabled: true
        title: LLM
        type: llm
        vision:
          configs:
            detail: high
            variable_selector:
            - sys
            - files
          enabled: true
      height: 88
      id: llm
      position:
        x: 380
        y: 282
      positionAbsolute:
        x: 380
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        answer: '{{#llm.structured_output.image#}}'
        selected: false
        title: Answer
        type: answer
        variables: []
      height: 103
      id: answer
      position:
        x: 680
        y: 282
      positionAbsolute:
        x: 680
        y: 282
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    viewport:
      x: -149
      y: 97.5
      zoom: 1
  rag_pipeline_variables: []
api/tests/fixtures/pav-test-extraction.yml (vendored, new file: 307 lines)
@@ -0,0 +1,307 @@
app:
  description: Test for variable extraction feature
  icon: 🤖
  icon_background: '#FFEAD5'
  mode: advanced-chat
  name: pav-test-extraction
  use_icon_as_answer_icon: false
dependencies:
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
  version: null
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/openai:0.2.3@5a7f82fa86e28332ad51941d0b491c1e8a38ead539656442f7bf4c6129cd15fa
  version: null
- current_identifier: null
  type: marketplace
  value:
    marketplace_plugin_unique_identifier: langgenius/tongyi:0.1.16@d8bffbe45418f0c117fb3393e5e40e61faee98f9a2183f062e5a280e74b15d21
  version: null
kind: app
version: 0.5.0
workflow:
  conversation_variables: []
  environment_variables: []
  features:
    file_upload:
      allowed_file_extensions:
      - .JPG
      - .JPEG
      - .PNG
      - .GIF
      - .WEBP
      - .SVG
      allowed_file_types:
      - image
      allowed_file_upload_methods:
      - local_file
      - remote_url
      enabled: false
      image:
        enabled: false
        number_limits: 3
        transfer_methods:
        - local_file
        - remote_url
      number_limits: 3
    opening_statement: Hello! I am a search assistant. Please tell me what you would like to search for.
    retriever_resource:
      enabled: true
    sensitive_word_avoidance:
      enabled: false
    speech_to_text:
      enabled: false
    suggested_questions: []
    suggested_questions_after_answer:
      enabled: false
    text_to_speech:
      enabled: false
      language: ''
      voice: ''
  graph:
    edges:
    - data:
        sourceType: start
        targetType: llm
      id: 1767773675796-llm
      source: '1767773675796'
      sourceHandle: source
      target: llm
      targetHandle: target
      type: custom
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: llm
        targetType: tool
      id: llm-source-1767773709491-target
      source: llm
      sourceHandle: source
      target: '1767773709491'
      targetHandle: target
      type: custom
      zIndex: 0
    - data:
        isInIteration: false
        isInLoop: false
        sourceType: tool
        targetType: answer
      id: tool-source-answer-target
      source: '1767773709491'
      sourceHandle: source
      target: answer
      targetHandle: target
      type: custom
      zIndex: 0
    nodes:
    - data:
        selected: false
        title: User Input
        type: start
        variables: []
      height: 73
      id: '1767773675796'
      position:
        x: 80
        y: 282
      positionAbsolute:
        x: 80
        y: 282
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        memory:
          mode: node
          query_prompt_template: '{{#sys.query#}}'
          role_prefix:
            assistant: ''
            user: ''
          window:
            enabled: true
            size: 10
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: qwen-max
          provider: langgenius/tongyi/tongyi
        prompt_template:
        - id: 11d06d15-914a-4915-a5b1-0e35ab4fba51
          role: system
          text: 'You are an intelligent search assistant. The user will tell you what they want to search for.

            Talk with the user to understand their search needs.

            Once the user has clearly stated what they want to search for, you may reply "OK, I will search for it".

            '
        selected: false
        title: LLM
        type: llm
        vision:
          enabled: false
      height: 88
      id: llm
      position:
        x: 380
        y: 282
      positionAbsolute:
        x: 380
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        is_team_authorization: true
        paramSchemas:
        - auto_generate: null
          default: null
          form: llm
          human_description:
            en_US: used for searching
            ja_JP: used for searching
            pt_BR: used for searching
            zh_Hans: 用于搜索网页内容
          label:
            en_US: Query string
            ja_JP: Query string
            pt_BR: Query string
            zh_Hans: 查询语句
          llm_description: key words for searching
          max: null
          min: null
          name: query
          options: []
          placeholder: null
          precision: null
          required: true
          scope: null
          template: null
          type: string
        params:
          query: ''
        plugin_id: langgenius/google
        plugin_unique_identifier: langgenius/google:0.0.8@3efcf55ffeef9d0f77715e0afb23534952ae0cb385c051d0637e86d71199d1a6
        provider_icon: http://localhost:5001/console/api/workspaces/current/plugin/icon?tenant_id=7217e801-f6f5-49ec-8103-d7de97a4b98f&filename=1c5871163478957bac64c3fe33d72d003f767497d921c74b742aad27a8344a74.svg
        provider_id: langgenius/google/google
        provider_name: langgenius/google/google
        provider_type: builtin
        selected: false
        title: GoogleSearch
        tool_configurations: {}
        tool_description: A tool for performing a Google SERP search and extracting
          snippets and webpages.Input should be a search query.
        tool_label: GoogleSearch
        tool_name: google_search
        tool_node_version: '2'
        tool_parameters:
          query:
            type: mention
            value: '{{@llm.context@}}Extract the keywords the user wants to search for from the conversation history; return only the keywords themselves'
            mention_config:
              extractor_node_id: 1767773709491_ext_query
              output_selector:
              - structured_output
              - query
              null_strategy: use_default
              default_value: ''
        type: tool
      height: 52
      id: '1767773709491'
      position:
        x: 682
        y: 282
      positionAbsolute:
        x: 682
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        context:
          enabled: false
          variable_selector: []
        model:
          completion_params:
            temperature: 0.7
          mode: chat
          name: gpt-4o-mini
          provider: langgenius/openai/openai
        parent_node_id: '1767773709491'
        prompt_template:
        - $context:
          - llm
          - context
          id: 75d58e22-dc59-40c8-ba6f-aeb28f4f305a
        - id: 18ba6710-77f5-47f4-b144-9191833bb547
          role: user
          text: Extract the keywords the user wants to search for from the conversation history; return only the keywords themselves and nothing else
        selected: false
        structured_output:
          schema:
            additionalProperties: false
            properties:
              query:
                description: The keywords to search for
                type: string
            required:
            - query
            type: object
        structured_output_enabled: true
        title: Extract search keywords
        type: llm
        vision:
          enabled: false
      height: 88
      id: 1767773709491_ext_query
      position:
        x: 531
        y: 382
      positionAbsolute:
        x: 531
        y: 382
      selected: true
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    - data:
        answer: 'Search results:

          {{#1767773709491.text#}}

          '
        selected: false
        title: Answer
        type: answer
      height: 103
      id: answer
      position:
        x: 984
        y: 282
      positionAbsolute:
        x: 984
        y: 282
      selected: false
      sourcePosition: right
      targetPosition: left
      type: custom
      width: 242
    viewport:
      x: -151
      y: 123
      zoom: 1
  rag_pipeline_variables: []
@@ -0,0 +1,269 @@
"""
Unit tests for file reference detection and conversion.
"""

import uuid
from unittest.mock import MagicMock, patch

import pytest

from core.file import File, FileTransferMethod, FileType
from core.llm_generator.output_parser.file_ref import (
    FILE_REF_FORMAT,
    convert_file_refs_in_output,
    detect_file_ref_fields,
    is_file_ref_property,
)
from core.variables.segments import ArrayFileSegment, FileSegment


class TestIsFileRefProperty:
    """Tests for is_file_ref_property function."""

    def test_valid_file_ref(self):
        schema = {"type": "string", "format": FILE_REF_FORMAT}
        assert is_file_ref_property(schema) is True

    def test_invalid_type(self):
        schema = {"type": "number", "format": FILE_REF_FORMAT}
        assert is_file_ref_property(schema) is False

    def test_missing_format(self):
        schema = {"type": "string"}
        assert is_file_ref_property(schema) is False

    def test_wrong_format(self):
        schema = {"type": "string", "format": "uuid"}
        assert is_file_ref_property(schema) is False


class TestDetectFileRefFields:
    """Tests for detect_file_ref_fields function."""

    def test_simple_file_ref(self):
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["image"]

    def test_multiple_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "image": {"type": "string", "format": FILE_REF_FORMAT},
                "document": {"type": "string", "format": FILE_REF_FORMAT},
                "name": {"type": "string"},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert set(paths) == {"image", "document"}

    def test_array_of_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "files": {
                    "type": "array",
                    "items": {"type": "string", "format": FILE_REF_FORMAT},
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["files[*]"]

    def test_nested_file_ref(self):
        schema = {
            "type": "object",
            "properties": {
                "data": {
                    "type": "object",
                    "properties": {
                        "image": {"type": "string", "format": FILE_REF_FORMAT},
                    },
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == ["data.image"]

    def test_no_file_refs(self):
        schema = {
            "type": "object",
            "properties": {
                "name": {"type": "string"},
                "count": {"type": "number"},
            },
        }
        paths = detect_file_ref_fields(schema)
        assert paths == []

    def test_empty_schema(self):
        schema = {}
        paths = detect_file_ref_fields(schema)
        assert paths == []

    def test_mixed_schema(self):
        schema = {
            "type": "object",
            "properties": {
                "query": {"type": "string"},
                "image": {"type": "string", "format": FILE_REF_FORMAT},
                "documents": {
                    "type": "array",
                    "items": {"type": "string", "format": FILE_REF_FORMAT},
                },
            },
        }
        paths = detect_file_ref_fields(schema)
        assert set(paths) == {"image", "documents[*]"}


class TestConvertFileRefsInOutput:
    """Tests for convert_file_refs_in_output function."""

    @pytest.fixture
    def mock_file(self):
        """Create a mock File object with all required attributes."""
        file = MagicMock(spec=File)
        file.type = FileType.IMAGE
        file.transfer_method = FileTransferMethod.TOOL_FILE
        file.related_id = "test-related-id"
        file.remote_url = None
        file.tenant_id = "tenant_123"
        file.id = None
        file.filename = "test.png"
        file.extension = ".png"
        file.mime_type = "image/png"
        file.size = 1024
        file.dify_model_identity = "__dify__file__"
        return file

    @pytest.fixture
    def mock_build_from_mapping(self, mock_file):
        """Mock the build_from_mapping function."""
        with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.return_value = mock_file
|
||||
yield mock
|
||||
|
||||
def test_convert_simple_file_ref(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in FileSegment
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
mock_build_from_mapping.assert_called_once_with(
|
||||
mapping={"transfer_method": "tool_file", "tool_file_id": file_id},
|
||||
tenant_id="tenant_123",
|
||||
)
|
||||
|
||||
def test_convert_array_of_file_refs(self, mock_build_from_mapping, mock_file):
|
||||
file_id1 = str(uuid.uuid4())
|
||||
file_id2 = str(uuid.uuid4())
|
||||
output = {"files": [file_id1, file_id2]}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"files": {
|
||||
"type": "array",
|
||||
"items": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Result should be wrapped in ArrayFileSegment
|
||||
assert isinstance(result["files"], ArrayFileSegment)
|
||||
assert list(result["files"].value) == [mock_file, mock_file]
|
||||
assert mock_build_from_mapping.call_count == 2
|
||||
|
||||
def test_no_conversion_without_file_refs(self):
|
||||
output = {"name": "test", "count": 5}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result == {"name": "test", "count": 5}
|
||||
|
||||
def test_invalid_uuid_returns_none(self):
|
||||
output = {"image": "not-a-valid-uuid"}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_file_not_found_returns_none(self):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"image": file_id}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
with patch("core.llm_generator.output_parser.file_ref.build_from_mapping") as mock:
|
||||
mock.side_effect = ValueError("File not found")
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["image"] is None
|
||||
|
||||
def test_preserves_non_file_fields(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
output = {"query": "search term", "image": file_id, "count": 10}
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {"type": "string"},
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
"count": {"type": "number"},
|
||||
},
|
||||
}
|
||||
|
||||
result = convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
assert result["query"] == "search term"
|
||||
assert isinstance(result["image"], FileSegment)
|
||||
assert result["image"].value == mock_file
|
||||
assert result["count"] == 10
|
||||
|
||||
def test_does_not_modify_original_output(self, mock_build_from_mapping, mock_file):
|
||||
file_id = str(uuid.uuid4())
|
||||
original = {"image": file_id}
|
||||
output = dict(original)
|
||||
schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"image": {"type": "string", "format": FILE_REF_FORMAT},
|
||||
},
|
||||
}
|
||||
|
||||
convert_file_refs_in_output(output, schema, "tenant_123")
|
||||
|
||||
# Original should still contain the string ID
|
||||
assert original["image"] == file_id
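The schema walk these tests pin down can be transliterated to TypeScript as a short recursive sketch — useful as a reference if the same detection ever needs to run on the web side. `FILE_REF_FORMAT`'s concrete value lives inside the Python module, so the string below is a stand-in assumption:

// Sketch of the detection logic exercised by the tests above. The
// 'dify-file-ref' marker is an assumed stand-in for FILE_REF_FORMAT.
const FILE_REF_FORMAT = 'dify-file-ref'

type JsonSchemaSketch = {
  type?: string
  format?: string
  properties?: Record<string, JsonSchemaSketch>
  items?: JsonSchemaSketch
}

function isFileRefProperty(schema: JsonSchemaSketch): boolean {
  return schema.type === 'string' && schema.format === FILE_REF_FORMAT
}

// Collect dotted paths to file-ref fields; array items are marked "[*]",
// matching the expectations in TestDetectFileRefFields.
function detectFileRefFields(schema: JsonSchemaSketch, prefix = ''): string[] {
  const paths: string[] = []
  for (const [key, prop] of Object.entries(schema.properties ?? {})) {
    const path = prefix ? `${prefix}.${key}` : key
    if (isFileRefProperty(prop))
      paths.push(path)
    else if (prop.type === 'array' && prop.items && isFileRefProperty(prop.items))
      paths.push(`${path}[*]`)
    else if (prop.type === 'object')
      paths.push(...detectFileRefFields(prop, path))
  }
  return paths
}

// detectFileRefFields({ type: 'object', properties: { data: { type: 'object',
// properties: { image: { type: 'string', format: FILE_REF_FORMAT } } } } })
// -> ['data.image']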
@@ -25,6 +25,12 @@ class _StubErrorHandler:
    """Minimal error handler stub for tests."""


class _StubNodeData:
    """Simple node data stub with is_extractor_node property."""

    is_extractor_node = False


class _StubNode:
    """Simple node stub exposing the attributes needed by the state manager."""

@@ -36,6 +42,7 @@ class _StubNode:
        self.error_strategy = None
        self.retry_config = RetryConfig()
        self.retry = False
        self.node_data = _StubNodeData()


def _build_event_handler(node_id: str) -> tuple[EventHandler, EventManager, GraphExecution]:
@@ -14,7 +14,6 @@ import { BlockEnum } from '@/app/components/workflow/types'
import { useAppContext } from '@/context/app-context'
import { useDocLink } from '@/context/i18n'
import {
  useAppTriggers,
  useInvalidateAppTriggers,
  useUpdateTriggerStatus,
@@ -38,13 +38,16 @@ export const getInputVars = (text: string): ValueSelector[] => {
  if (!text || typeof text !== 'string')
    return []

  const allVars = text.match(/\{\{#([^#]*)#\}\}/g)
  const allVars = text.match(/\{\{[@#]([^@#]*)[@#]\}\}/g)
  if (allVars && allVars?.length > 0) {
    // {{#context#}}, {{#query#}} is not input vars
    const inputVars = allVars
      .filter(item => item.includes('.'))
      .map((item) => {
        const valueSelector = item.replace('{{#', '').replace('#}}', '').split('.')
        const valueSelector = item
          .replace(/^\{\{[@#]/, '')
          .replace(/[@#]\}\}$/, '')
          .split('.')
        if (valueSelector[1] === 'sys' && /^\d+$/.test(valueSelector[0]))
          return valueSelector.slice(1)
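The widened pattern accepts both `#` and `@` delimiters. A quick illustration of what the updated regex extracts (the sample strings are illustrative):

// The widened pattern now matches both the classic {{#node.field#}} variable
// syntax and the new {{@node.field@}} agent-context mention syntax.
const pattern = /\{\{[@#]([^@#]*)[@#]\}\}/g

const samples = [
  '{{#1767773709491.text#}}',           // classic variable reference -> matches
  '{{@llm.context@}} extract keywords', // agent mention -> matches
  '{{#context#}}',                      // matches, but has no '.' and is filtered out
]

for (const text of samples) {
  const selectors = (text.match(pattern) ?? [])
    .filter(item => item.includes('.'))
    .map(item => item.replace(/^\{\{[@#]/, '').replace(/[@#]\}\}$/, '').split('.'))
  console.log(selectors) // e.g. [['1767773709491', 'text']], [['llm', 'context']], []
}

So `getInputVars` now yields the same selector shape for agent mentions as it always did for `#`-style variable references.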

@@ -5,6 +5,7 @@ import type {
} from 'lexical'
import type { FC } from 'react'
import type {
  AgentBlockType,
  ContextBlockType,
  CurrentBlockType,
  ErrorMessageBlockType,
@@ -103,6 +104,7 @@ export type PromptEditorProps = {
  currentBlock?: CurrentBlockType
  errorMessageBlock?: ErrorMessageBlockType
  lastRunBlock?: LastRunBlockType
  agentBlock?: AgentBlockType
  isSupportFileVar?: boolean
}

@@ -128,6 +130,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
  currentBlock,
  errorMessageBlock,
  lastRunBlock,
  agentBlock,
  isSupportFileVar,
}) => {
  const { eventEmitter } = useEventEmitterContextContext()
@@ -139,6 +142,7 @@ const PromptEditor: FC<PromptEditorProps> = ({
    {
      replace: TextNode,
      with: (node: TextNode) => new CustomTextNode(node.__text),
      withKlass: CustomTextNode,
    },
    ContextBlockNode,
    HistoryBlockNode,
@@ -212,6 +216,22 @@ const PromptEditor: FC<PromptEditorProps> = ({
        lastRunBlock={lastRunBlock}
        isSupportFileVar={isSupportFileVar}
      />
      {(!agentBlock || agentBlock.show) && (
        <ComponentPickerBlock
          triggerString="@"
          contextBlock={contextBlock}
          historyBlock={historyBlock}
          queryBlock={queryBlock}
          variableBlock={variableBlock}
          externalToolBlock={externalToolBlock}
          workflowVariableBlock={workflowVariableBlock}
          currentBlock={currentBlock}
          errorMessageBlock={errorMessageBlock}
          lastRunBlock={lastRunBlock}
          agentBlock={agentBlock}
          isSupportFileVar={isSupportFileVar}
        />
      )}
      <ComponentPickerBlock
        triggerString="{"
        contextBlock={contextBlock}
@@ -1,6 +1,7 @@
import type { MenuRenderFn } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import type { TextNode } from 'lexical'
import type {
  AgentBlockType,
  ContextBlockType,
  CurrentBlockType,
  ErrorMessageBlockType,
@@ -20,7 +21,11 @@ import {
} from '@floating-ui/react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import { KEY_ESCAPE_COMMAND } from 'lexical'
import {
  $getRoot,
  $insertNodes,
  KEY_ESCAPE_COMMAND,
} from 'lexical'
import {
  Fragment,
  memo,
@@ -29,7 +34,9 @@ import {
} from 'react'
import ReactDOM from 'react-dom'
import { GeneratorType } from '@/app/components/app/configuration/config/automatic/types'
import AgentNodeList from '@/app/components/workflow/nodes/_base/components/agent-node-list'
import VarReferenceVars from '@/app/components/workflow/nodes/_base/components/variable/var-reference-vars'
import { BlockEnum } from '@/app/components/workflow/types'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { useBasicTypeaheadTriggerMatch } from '../../hooks'
import { $splitNodeContainingQuery } from '../../utils'
@@ -38,6 +45,7 @@ import { INSERT_ERROR_MESSAGE_BLOCK_COMMAND } from '../error-message-block'
import { INSERT_LAST_RUN_BLOCK_COMMAND } from '../last-run-block'
import { INSERT_VARIABLE_VALUE_BLOCK_COMMAND } from '../variable-block'
import { INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND } from '../workflow-variable-block'
import { $createWorkflowVariableBlockNode } from '../workflow-variable-block/node'
import { useOptions } from './hooks'

type ComponentPickerProps = {
@@ -51,6 +59,7 @@ type ComponentPickerProps = {
  currentBlock?: CurrentBlockType
  errorMessageBlock?: ErrorMessageBlockType
  lastRunBlock?: LastRunBlockType
  agentBlock?: AgentBlockType
  isSupportFileVar?: boolean
}
const ComponentPicker = ({
@@ -64,6 +73,7 @@ const ComponentPicker = ({
  currentBlock,
  errorMessageBlock,
  lastRunBlock,
  agentBlock,
  isSupportFileVar,
}: ComponentPickerProps) => {
  const { eventEmitter } = useEventEmitterContextContext()
@@ -151,12 +161,41 @@ const ComponentPicker = ({
    editor.dispatchCommand(KEY_ESCAPE_COMMAND, escapeEvent)
  }, [editor])

  const handleSelectAgent = useCallback((agent: { id: string, title: string }) => {
    editor.update(() => {
      const needRemove = $splitNodeContainingQuery(checkForTriggerMatch(triggerString, editor)!)
      if (needRemove)
        needRemove.remove()

      const root = $getRoot()
      const firstChild = root.getFirstChild()
      if (firstChild) {
        const selection = firstChild.selectStart()
        if (selection) {
          const workflowVariableBlockNode = $createWorkflowVariableBlockNode([agent.id, 'text'], {}, undefined)
          $insertNodes([workflowVariableBlockNode])
        }
      }
    })
    agentBlock?.onSelect?.(agent)
    handleClose()
  }, [editor, checkForTriggerMatch, triggerString, agentBlock, handleClose])

  const isAgentTrigger = triggerString === '@' && agentBlock?.show
  const agentNodes = agentBlock?.agentNodes || []

  const renderMenu = useCallback<MenuRenderFn<PickerBlockMenuOption>>((
    anchorElementRef,
    { options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
  ) => {
    if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
      return null
    if (isAgentTrigger) {
      if (!(anchorElementRef.current && agentNodes.length))
        return null
    }
    else {
      if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
        return null
    }

    setTimeout(() => {
      if (anchorElementRef.current)
@@ -167,9 +206,6 @@ const ComponentPicker = ({
      <>
        {
          ReactDOM.createPortal(
            // The `LexicalMenu` will try to calculate the position of the floating menu based on the first child.
            // Since we use floating ui, we need to wrap it with a div to prevent the position calculation being affected.
            // See https://github.com/facebook/lexical/blob/ac97dfa9e14a73ea2d6934ff566282d7f758e8bb/packages/lexical-react/src/shared/LexicalMenu.ts#L493
            <div className="h-0 w-0">
              <div
                className="w-[260px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
@@ -179,56 +215,73 @@ const ComponentPicker = ({
                }}
                ref={refs.setFloating}
              >
                {
                  workflowVariableBlock?.show && (
                    <div className="p-1">
                      <VarReferenceVars
                        searchBoxClassName="mt-1"
                        vars={workflowVariableOptions}
                        onChange={(variables: string[]) => {
                          handleSelectWorkflowVariable(variables)
                        }}
                        maxHeightClass="max-h-[34vh]"
                        isSupportFileVar={isSupportFileVar}
                {isAgentTrigger
                  ? (
                    <AgentNodeList
                      nodes={agentNodes.map(node => ({
                        ...node,
                        type: BlockEnum.Agent,
                      }))}
                      onSelect={handleSelectAgent}
                        onClose={handleClose}
                        onBlur={handleClose}
                        showManageInputField={workflowVariableBlock.showManageInputField}
                        onManageInputField={workflowVariableBlock.onManageInputField}
                      maxHeightClass="max-h-[34vh]"
                      autoFocus={false}
                        isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
                      />
                    </div>
                  )
                }
                {
                  workflowVariableBlock?.show && !!options.length && (
                    <div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
                  )
                }
                <div>
                  {
                    options.map((option, index) => (
                      <Fragment key={option.key}>
                  )
                  : (
                    <>
                      {
                        // Divider
                        index !== 0 && options.at(index - 1)?.group !== option.group && (
                        workflowVariableBlock?.show && (
                          <div className="p-1">
                            <VarReferenceVars
                              searchBoxClassName="mt-1"
                              vars={workflowVariableOptions}
                              onChange={(variables: string[]) => {
                                handleSelectWorkflowVariable(variables)
                              }}
                              maxHeightClass="max-h-[34vh]"
                              isSupportFileVar={isSupportFileVar}
                              onClose={handleClose}
                              onBlur={handleClose}
                              showManageInputField={workflowVariableBlock.showManageInputField}
                              onManageInputField={workflowVariableBlock.onManageInputField}
                              autoFocus={false}
                              isInCodeGeneratorInstructionEditor={currentBlock?.generatorType === GeneratorType.code}
                            />
                          </div>
                        )
                      }
                      {
                        workflowVariableBlock?.show && !!options.length && (
                          <div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
                        )
                      }
                      {option.renderMenuOption({
                        queryString,
                        isSelected: selectedIndex === index,
                        onSelect: () => {
                          selectOptionAndCleanUp(option)
                        },
                        onSetHighlight: () => {
                          setHighlightedIndex(index)
                        },
                      })}
                      </Fragment>
                      ))
                  }
                  </div>
                      <div>
                        {
                          options.map((option, index) => (
                            <Fragment key={option.key}>
                              {
                                index !== 0 && options.at(index - 1)?.group !== option.group && (
                                  <div className="my-1 h-px w-full -translate-x-1 bg-divider-subtle"></div>
                                )
                              }
                              {option.renderMenuOption({
                                queryString,
                                isSelected: selectedIndex === index,
                                onSelect: () => {
                                  selectOptionAndCleanUp(option)
                                },
                                onSetHighlight: () => {
                                  setHighlightedIndex(index)
                                },
                              })}
                            </Fragment>
                          ))
                        }
                      </div>
                    </>
                  )}
              </div>
            </div>,
            anchorElementRef.current,
@@ -236,7 +289,7 @@ const ComponentPicker = ({
        )
      }
    </>
  )
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
}, [isAgentTrigger, agentNodes, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, handleSelectAgent, handleClose, workflowVariableOptions, isSupportFileVar, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])

  return (
    <LexicalTypeaheadMenuPlugin
@@ -21,6 +21,7 @@ import {
  VariableLabelInEditor,
} from '@/app/components/workflow/nodes/_base/components/variable/variable-label'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import { BlockEnum } from '@/app/components/workflow/types'
import { isExceptionVariable } from '@/app/components/workflow/utils'
import { useSelectOrDelete } from '../../hooks'
import {
@@ -66,6 +67,7 @@ const WorkflowVariableBlockComponent = ({
  )()
  const [localWorkflowNodesMap, setLocalWorkflowNodesMap] = useState<WorkflowNodesMap>(workflowNodesMap)
  const node = localWorkflowNodesMap![variables[isRagVar ? 1 : 0]]
  const isAgentContextVariable = node?.type === BlockEnum.Agent && variables[variablesLength - 1] === 'context'

  const isException = isExceptionVariable(varName, node?.type)
  const variableValid = useMemo(() => {
@@ -134,6 +136,9 @@ const WorkflowVariableBlockComponent = ({
    })
  }, [node, reactflow, store])

  if (isAgentContextVariable)
    return <span className="hidden" ref={ref} />

  const Item = (
    <VariableLabelInEditor
      nodeType={node?.type}
@@ -2,6 +2,7 @@ import type { LexicalNode, NodeKey, SerializedLexicalNode } from 'lexical'
import type { GetVarType, WorkflowVariableBlockType } from '../../types'
import type { Var } from '@/app/components/workflow/types'
import { DecoratorNode } from 'lexical'
import { BlockEnum } from '@/app/components/workflow/types'
import WorkflowVariableBlockComponent from './component'

export type WorkflowNodesMap = WorkflowVariableBlockType['workflowNodesMap']
@@ -120,7 +121,11 @@ export class WorkflowVariableBlockNode extends DecoratorNode<React.JSX.Element>
  }

  getTextContent(): string {
    return `{{#${this.getVariables().join('.')}#}}`
    const variables = this.getVariables()
    const node = this.getWorkflowNodesMap()?.[variables[0]]
    const isAgentContextVariable = node?.type === BlockEnum.Agent && variables[variables.length - 1] === 'context'
    const marker = isAgentContextVariable ? '@' : '#'
    return `{{${marker}${variables.join('.')}${marker}}}`
  }
}
export function $createWorkflowVariableBlockNode(variables: string[], workflowNodesMap: WorkflowNodesMap, getVarType?: GetVarType, environmentVariables?: Var[], conversationVariables?: Var[], ragVariables?: Var[]): WorkflowVariableBlockNode {

@@ -73,6 +73,17 @@ export type WorkflowVariableBlockType = {
  onManageInputField?: () => void
}

export type AgentNode = {
  id: string
  title: string
}

export type AgentBlockType = {
  show?: boolean
  agentNodes?: AgentNode[]
  onSelect?: (agent: AgentNode) => void
}

export type MenuTextMatch = {
  leadOffset: number
  matchingString: string
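With the `getTextContent()` change in node.tsx above, serialization picks its delimiter from the referenced node's type. A small illustration of the expected outputs — the sample values mirror the test workflow earlier in this diff:

// Expected serializations after the getTextContent() change: a plain
// workflow variable keeps the '#' delimiters, while an Agent node's
// trailing 'context' segment switches to the '@' mention form.
const serialize = (variables: string[], isAgentContext: boolean) => {
  const marker = isAgentContext ? '@' : '#'
  return `{{${marker}${variables.join('.')}${marker}}}`
}

serialize(['1767773709491', 'text'], false) // '{{#1767773709491.text#}}'
serialize(['llm', 'context'], true)         // '{{@llm.context@}}'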

194	web/app/components/sub-graph/components/config-panel.tsx	Normal file
@@ -0,0 +1,194 @@
'use client'
import type { FC } from 'react'
import type { Item } from '@/app/components/base/select'
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
import type { Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types'
import { RiCheckLine } from '@remixicon/react'
import { memo, useCallback, useMemo, useState } from 'react'
import { useTranslation } from 'react-i18next'
import { SimpleSelect } from '@/app/components/base/select'
import CodeEditor from '@/app/components/workflow/nodes/_base/components/editor/code-editor'
import Field from '@/app/components/workflow/nodes/_base/components/field'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
import Tab, { TabType } from '@/app/components/workflow/nodes/_base/components/workflow-panel/tab'
import { CodeLanguage } from '@/app/components/workflow/nodes/code/types'
import { cn } from '@/utils/classnames'

type ConfigPanelProps = {
  agentName: string
  extractorNodeId: string
  mentionConfig: MentionConfig
  availableNodes: Node[]
  availableVars: NodeOutPutVar[]
  onMentionConfigChange: (config: MentionConfig) => void
}

const ConfigPanel: FC<ConfigPanelProps> = ({
  agentName,
  extractorNodeId,
  mentionConfig,
  availableNodes,
  availableVars,
  onMentionConfigChange,
}) => {
  const { t } = useTranslation()
  const [tabType, setTabType] = useState<TabType>(TabType.settings)

  const resolvedExtractorId = mentionConfig.extractor_node_id || extractorNodeId

  const selectedOutput = useMemo<ValueSelector>(() => {
    if (!resolvedExtractorId || !mentionConfig.output_selector?.length)
      return []

    return [resolvedExtractorId, ...(mentionConfig.output_selector || [])]
  }, [mentionConfig.output_selector, resolvedExtractorId])

  const handleOutputVarChange = useCallback((value: ValueSelector | string) => {
    const selector = Array.isArray(value) ? value : []
    const nextExtractorId = selector[0] || resolvedExtractorId
    const nextOutputSelector = selector.length > 1 ? selector.slice(1) : []

    onMentionConfigChange({
      ...mentionConfig,
      extractor_node_id: nextExtractorId,
      output_selector: nextOutputSelector,
    })
  }, [mentionConfig, onMentionConfigChange, resolvedExtractorId])

  const whenOutputNoneOptions = useMemo(() => ([
    {
      value: 'raise_error',
      name: t('subGraphModal.whenOutputNone.error', { ns: 'workflow' }),
      description: t('subGraphModal.whenOutputNone.errorDesc', { ns: 'workflow' }),
    },
    {
      value: 'use_default',
      name: t('subGraphModal.whenOutputNone.default', { ns: 'workflow' }),
      description: t('subGraphModal.whenOutputNone.defaultDesc', { ns: 'workflow' }),
    },
  ]), [t])
  const selectedWhenOutputNoneOption = useMemo(() => (
    whenOutputNoneOptions.find(item => item.value === mentionConfig.null_strategy) ?? whenOutputNoneOptions[0]
  ), [mentionConfig.null_strategy, whenOutputNoneOptions])

  const handleNullStrategyChange = useCallback((item: Item) => {
    if (typeof item.value !== 'string')
      return
    onMentionConfigChange({
      ...mentionConfig,
      null_strategy: item.value as MentionConfig['null_strategy'],
    })
  }, [mentionConfig, onMentionConfigChange])

  const handleDefaultValueChange = useCallback((value: string) => {
    const trimmed = value.trim()
    let nextValue: unknown = value
    if ((trimmed.startsWith('{') && trimmed.endsWith('}')) || (trimmed.startsWith('[') && trimmed.endsWith(']'))) {
      try {
        nextValue = JSON.parse(trimmed)
      }
      catch {
        nextValue = value
      }
    }

    onMentionConfigChange({
      ...mentionConfig,
      default_value: nextValue,
    })
  }, [mentionConfig, onMentionConfigChange])
  const defaultValue = mentionConfig.default_value ?? ''
  const shouldFormatDefaultValue = typeof defaultValue !== 'string'

  return (
    <div className="flex h-full flex-col">
      <div className="px-4 pb-2 pt-4">
        <div className="system-lg-semibold text-text-primary">
          {t('subGraphModal.internalStructure', { ns: 'workflow' })}
        </div>
        <div className="system-sm-regular text-text-tertiary">
          {t('subGraphModal.internalStructureDesc', { ns: 'workflow', name: agentName })}
        </div>
      </div>
      <div className="px-4 pb-2">
        <Tab value={tabType} onChange={setTabType} />
      </div>
      {tabType === TabType.lastRun && (
        <div className="flex flex-1 items-center justify-center p-4">
          <p className="system-sm-regular text-text-tertiary">
            {t('subGraphModal.noRunHistory', { ns: 'workflow' })}
          </p>
        </div>
      )}
      {tabType === TabType.settings && (
        <div className="flex-1 overflow-y-auto">
          <div className="space-y-4 px-4 py-4">
            <Field title={t('subGraphModal.outputVariables', { ns: 'workflow' })}>
              <VarReferencePicker
                nodeId={extractorNodeId}
                readonly={false}
                isShowNodeName
                value={selectedOutput}
                onChange={handleOutputVarChange}
                availableNodes={availableNodes}
                availableVars={availableVars}
              />
            </Field>
          </div>
          <div className="space-y-4 px-4 py-4">
            <Field
              title={t('subGraphModal.whenOutputIsNone', { ns: 'workflow' })}
              operations={(
                <div className="flex items-center">
                  <SimpleSelect
                    items={whenOutputNoneOptions}
                    defaultValue={mentionConfig.null_strategy}
                    allowSearch={false}
                    notClearable
                    wrapperClassName="min-w-[160px]"
                    onSelect={handleNullStrategyChange}
                    renderOption={({ item, selected }) => (
                      <div className="flex items-start gap-2">
                        <div className="mt-0.5 flex h-4 w-4 shrink-0 items-center justify-center">
                          {selected && (
                            <RiCheckLine className="h-4 w-4 text-[14px] text-text-accent" />
                          )}
                        </div>
                        <div className="min-w-0">
                          <div className="system-sm-medium text-text-secondary">{item.name}</div>
                          <div className="system-xs-regular mt-0.5 text-text-tertiary">{item.description}</div>
                        </div>
                      </div>
                    )}
                  />
                </div>
              )}
            >
              <div className="space-y-2">
                {selectedWhenOutputNoneOption?.description && (
                  <div className="system-xs-regular text-text-tertiary">
                    {selectedWhenOutputNoneOption.description}
                  </div>
                )}
                {mentionConfig.null_strategy === 'use_default' && (
                  <div className={cn('overflow-hidden rounded-lg border border-components-input-border-active bg-components-input-bg-normal p-1')}>
                    <CodeEditor
                      noWrapper
                      language={CodeLanguage.json}
                      value={defaultValue}
                      onChange={handleDefaultValueChange}
                      isJSONStringifyBeauty={shouldFormatDefaultValue}
                      className="min-h-[160px]"
                    />
                  </div>
                )}
              </div>
            </Field>
          </div>
        </div>
      )}
    </div>
  )
}

export default memo(ConfigPanel)
@@ -0,0 +1,87 @@
import type { FC } from 'react'
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
import type { NodeOutPutVar } from '@/app/components/workflow/types'
import { memo, useMemo } from 'react'
import { useStore as useReactFlowStore } from 'reactflow'
import { useShallow } from 'zustand/react/shallow'
import { useIsChatMode, useWorkflowVariables } from '@/app/components/workflow/hooks'
import Panel from '@/app/components/workflow/panel'
import { useStore } from '@/app/components/workflow/store'
import { BlockEnum } from '@/app/components/workflow/types'
import ConfigPanel from './config-panel'

type SubGraphChildrenProps = {
  agentName: string
  extractorNodeId: string
  mentionConfig: MentionConfig
  onMentionConfigChange: (config: MentionConfig) => void
}

const SubGraphChildren: FC<SubGraphChildrenProps> = ({
  agentName,
  extractorNodeId,
  mentionConfig,
  onMentionConfigChange,
}) => {
  const { getNodeAvailableVars } = useWorkflowVariables()
  const isChatMode = useIsChatMode()
  const nodePanelWidth = useStore(s => s.nodePanelWidth)

  const selectedNode = useReactFlowStore(useShallow((s) => {
    return s.getNodes().find(node => node.data.selected)
  }))

  const extractorNode = useReactFlowStore(useShallow((s) => {
    return s.getNodes().find(node => node.data.type === BlockEnum.LLM)
  }))

  const availableNodes = useMemo(() => {
    return extractorNode ? [extractorNode] : []
  }, [extractorNode])

  const availableVars = useMemo<NodeOutPutVar[]>(() => {
    if (!extractorNode)
      return []

    const vars = getNodeAvailableVars({
      beforeNodes: [extractorNode],
      isChatMode,
      filterVar: () => true,
    })
    return vars.filter(item => item.nodeId === extractorNode.id)
  }, [extractorNode, getNodeAvailableVars, isChatMode])

  const panelRight = useMemo(() => {
    if (selectedNode)
      return null

    return (
      <div className="relative mr-1 h-full">
        <div
          className="flex h-full flex-col rounded-2xl border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-lg"
          style={{ width: `${nodePanelWidth}px` }}
        >
          <ConfigPanel
            agentName={agentName}
            extractorNodeId={extractorNodeId}
            mentionConfig={mentionConfig}
            availableNodes={availableNodes}
            availableVars={availableVars}
            onMentionConfigChange={onMentionConfigChange}
          />
        </div>
      </div>
    )
  }, [agentName, availableNodes, availableVars, extractorNodeId, mentionConfig, nodePanelWidth, onMentionConfigChange, selectedNode])

  return (
    <Panel
      withHeader={false}
      components={{
        right: panelRight,
      }}
    />
  )
}

export default memo(SubGraphChildren)
109	web/app/components/sub-graph/components/sub-graph-main.tsx	Normal file
@@ -0,0 +1,109 @@
import type { FC } from 'react'
import type { Viewport } from 'reactflow'
import type { SyncWorkflowDraft, SyncWorkflowDraftCallback } from '../types'
import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store'
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
import type { Edge, Node } from '@/app/components/workflow/types'
import { useCallback, useMemo } from 'react'
import { useStoreApi } from 'reactflow'
import { WorkflowWithInnerContext } from '@/app/components/workflow'
import { useSetWorkflowVarsWithValue } from '@/app/components/workflow/hooks/use-fetch-workflow-inspect-vars'
import { useInspectVarsCrudCommon } from '@/app/components/workflow/hooks/use-inspect-vars-crud-common'
import { FlowType } from '@/types/common'
import { useAvailableNodesMetaData } from '../hooks'
import SubGraphChildren from './sub-graph-children'

type SubGraphMainProps = {
  nodes: Node[]
  edges: Edge[]
  viewport: Viewport
  agentName: string
  extractorNodeId: string
  configsMap?: HooksStoreShape['configsMap']
  mentionConfig: MentionConfig
  onMentionConfigChange: (config: MentionConfig) => void
  onSave?: (nodes: Node[], edges: Edge[]) => void
  onSyncWorkflowDraft?: SyncWorkflowDraft
}

const SubGraphMain: FC<SubGraphMainProps> = ({
  nodes,
  edges,
  viewport,
  agentName,
  extractorNodeId,
  configsMap,
  mentionConfig,
  onMentionConfigChange,
  onSave,
  onSyncWorkflowDraft,
}) => {
  const reactFlowStore = useStoreApi()
  const availableNodesMetaData = useAvailableNodesMetaData()
  const flowType = configsMap?.flowType ?? FlowType.appFlow
  const flowId = configsMap?.flowId ?? ''
  const { fetchInspectVars } = useSetWorkflowVarsWithValue({
    flowType,
    flowId,
  })
  const inspectVarsCrud = useInspectVarsCrudCommon({
    flowType,
    flowId,
  })

  const handleSyncSubGraphDraft = useCallback(async () => {
    const { getNodes, edges } = reactFlowStore.getState()
    await onSave?.(getNodes() as Node[], edges as Edge[])
  }, [onSave, reactFlowStore])

  const handleSyncWorkflowDraft = useCallback(async (
    notRefreshWhenSyncError?: boolean,
    callback?: SyncWorkflowDraftCallback,
  ) => {
    try {
      await handleSyncSubGraphDraft()
      if (onSyncWorkflowDraft) {
        await onSyncWorkflowDraft(notRefreshWhenSyncError, callback)
        return
      }
      callback?.onSuccess?.()
    }
    catch {
      callback?.onError?.()
    }
    finally {
      callback?.onSettled?.()
    }
  }, [handleSyncSubGraphDraft, onSyncWorkflowDraft])

  const hooksStore = useMemo(() => ({
    interactionMode: 'subgraph',
    availableNodesMetaData,
    configsMap,
    fetchInspectVars,
    ...inspectVarsCrud,
    doSyncWorkflowDraft: handleSyncWorkflowDraft,
    syncWorkflowDraftWhenPageClose: handleSyncSubGraphDraft,
  }), [availableNodesMetaData, configsMap, fetchInspectVars, handleSyncSubGraphDraft, handleSyncWorkflowDraft, inspectVarsCrud])

  return (
    <WorkflowWithInnerContext
      nodes={nodes}
      edges={edges}
      viewport={viewport}
      hooksStore={hooksStore as any}
      allowSelectionWhenReadOnly
      canvasReadOnly
      interactionMode="subgraph"
    >
      <SubGraphChildren
        agentName={agentName}
        extractorNodeId={extractorNodeId}
        mentionConfig={mentionConfig}
        onMentionConfigChange={onMentionConfigChange}
      />
    </WorkflowWithInnerContext>
  )
}

export default SubGraphMain
2	web/app/components/sub-graph/hooks/index.ts	Normal file
@@ -0,0 +1,2 @@
export { useAvailableNodesMetaData } from './use-available-nodes-meta-data'
export { useSubGraphNodes } from './use-sub-graph-nodes'
@@ -0,0 +1,43 @@
import type { AvailableNodesMetaData } from '@/app/components/workflow/hooks-store/store'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { WORKFLOW_COMMON_NODES } from '@/app/components/workflow/constants/node'
import { BlockEnum } from '@/app/components/workflow/types'

export const useAvailableNodesMetaData = () => {
  const { t } = useTranslation()

  const availableNodesMetaData = useMemo(() => WORKFLOW_COMMON_NODES.map((node) => {
    const { metaData } = node
    const title = t(`blocks.${metaData.type}`, { ns: 'workflow' })
    const description = t(`blocksAbout.${metaData.type}`, { ns: 'workflow' })
    return {
      ...node,
      metaData: {
        ...metaData,
        title,
        description,
      },
      defaultValue: {
        ...node.defaultValue,
        type: metaData.type,
        title,
      },
    }
  }), [t])

  const availableNodesMetaDataMap = useMemo(() => availableNodesMetaData.reduce((acc, node) => {
    acc![node.metaData.type] = node
    return acc
  }, {} as AvailableNodesMetaData['nodesMap']), [availableNodesMetaData])

  return useMemo(() => {
    return {
      nodes: availableNodesMetaData,
      nodesMap: {
        ...availableNodesMetaDataMap,
        [BlockEnum.VariableAssigner]: availableNodesMetaDataMap?.[BlockEnum.VariableAggregator],
      },
    }
  }, [availableNodesMetaData, availableNodesMetaDataMap])
}
20	web/app/components/sub-graph/hooks/use-sub-graph-nodes.ts	Normal file
@@ -0,0 +1,20 @@
import type { Edge, Node } from '@/app/components/workflow/types'
import { useMemo } from 'react'
import { initialEdges, initialNodes } from '@/app/components/workflow/utils'

export const useSubGraphNodes = (nodes: Node[], edges: Edge[]) => {
  const processedNodes = useMemo(
    () => initialNodes(nodes, edges),
    [nodes, edges],
  )

  const processedEdges = useMemo(
    () => initialEdges(edges, nodes),
    [edges, nodes],
  )

  return {
    nodes: processedNodes,
    edges: processedEdges,
  }
}
212	web/app/components/sub-graph/index.tsx	Normal file
@@ -0,0 +1,212 @@
import type { FC } from 'react'
import type { Viewport } from 'reactflow'
import type { SubGraphProps } from './types'
import type { InjectWorkflowStoreSliceFn } from '@/app/components/workflow/store'
import type { PromptItem, PromptTemplateItem } from '@/app/components/workflow/types'
import { memo, useEffect, useMemo } from 'react'
import WorkflowWithDefaultContext from '@/app/components/workflow'
import { NODE_WIDTH_X_OFFSET, START_INITIAL_POSITION } from '@/app/components/workflow/constants'
import { WorkflowContextProvider } from '@/app/components/workflow/context'
import { useStore } from '@/app/components/workflow/store'
import { BlockEnum, EditionType, isPromptMessageContext, PromptRole } from '@/app/components/workflow/types'
import SubGraphMain from './components/sub-graph-main'
import { useSubGraphNodes } from './hooks'
import { createSubGraphSlice } from './store'

const SUB_GRAPH_EDGE_GAP = 160
const SUB_GRAPH_ENTRY_POSITION = {
  x: START_INITIAL_POSITION.x,
  y: 150,
}
const SUB_GRAPH_LLM_POSITION = {
  x: SUB_GRAPH_ENTRY_POSITION.x + NODE_WIDTH_X_OFFSET - SUB_GRAPH_EDGE_GAP,
  y: SUB_GRAPH_ENTRY_POSITION.y,
}

const defaultViewport: Viewport = {
  x: SUB_GRAPH_EDGE_GAP,
  y: 50,
  zoom: 1.3,
}

const SubGraphContent: FC<SubGraphProps> = (props) => {
  const {
    toolNodeId,
    paramKey,
    agentName,
    agentNodeId,
    mentionConfig,
    onMentionConfigChange,
    extractorNode,
    toolParamValue,
    parentAvailableNodes,
    parentAvailableVars,
    configsMap,
    onSave,
    onSyncWorkflowDraft,
  } = props

  const setParentAvailableVars = useStore(state => state.setParentAvailableVars)
  const setParentAvailableNodes = useStore(state => state.setParentAvailableNodes)

  useEffect(() => {
    setParentAvailableVars?.(parentAvailableVars || [])
    setParentAvailableNodes?.(parentAvailableNodes || [])
  }, [parentAvailableNodes, parentAvailableVars, setParentAvailableNodes, setParentAvailableVars])

  const promptText = useMemo(() => {
    if (!toolParamValue)
      return ''
    // Reason: escape agent id before building a regex pattern.
    const escapedAgentId = agentNodeId.replace(/[.*+?^${}()|[\]\\]/g, '\\$&')
    const leadingPattern = new RegExp(`^\\{\\{[@#]${escapedAgentId}\\.context[@#]\\}\\}`)
    return toolParamValue.replace(leadingPattern, '')
  }, [agentNodeId, toolParamValue])

  const startNode = useMemo(() => {
    return {
      id: 'subgraph-source',
      type: 'custom',
      position: SUB_GRAPH_ENTRY_POSITION,
      data: {
        type: BlockEnum.Start,
        title: agentName,
        desc: '',
        _connectedSourceHandleIds: ['source'],
        _connectedTargetHandleIds: [],
        _subGraphEntry: true,
        _iconTypeOverride: BlockEnum.Agent,
        selected: false,
        variables: [],
      },
      selected: false,
      selectable: false,
      draggable: false,
      connectable: false,
      focusable: false,
      deletable: false,
    }
  }, [agentName])

  const extractorDisplayNode = useMemo(() => {
    if (!extractorNode)
      return null

    const applyPromptText = (item: PromptItem) => {
      if (item.edition_type === EditionType.jinja2) {
        return {
          ...item,
          text: promptText,
          jinja2_text: promptText,
        }
      }
      return { ...item, text: promptText }
    }

    const nextPromptTemplate = (() => {
      const template = extractorNode.data.prompt_template
      if (!Array.isArray(template))
        return applyPromptText(template as PromptItem)

      const userIndex = template.findIndex(
        item => !isPromptMessageContext(item) && (item as PromptItem).role === PromptRole.user,
      )
      if (userIndex >= 0) {
        return template.map((item, index) => {
          if (index !== userIndex)
            return item
          return applyPromptText(item as PromptItem)
        }) as PromptTemplateItem[]
      }

      const useJinja = template.some(
        item => !isPromptMessageContext(item) && (item as PromptItem).edition_type === EditionType.jinja2,
      )
      const defaultUserPrompt: PromptItem = useJinja
        ? {
          role: PromptRole.user,
          text: promptText,
          jinja2_text: promptText,
          edition_type: EditionType.jinja2,
        }
        : { role: PromptRole.user, text: promptText }
      return [...template, defaultUserPrompt] as PromptTemplateItem[]
    })()

    return {
      ...extractorNode,
      hidden: false,
      selected: false,
      position: SUB_GRAPH_LLM_POSITION,
      data: {
        ...extractorNode.data,
        selected: false,
        prompt_template: nextPromptTemplate,
      },
    }
  }, [extractorNode, promptText])

  const nodesSource = useMemo(() => {
    if (!extractorDisplayNode)
      return [startNode]

    return [startNode, extractorDisplayNode]
  }, [extractorDisplayNode, startNode])

  const edgesSource = useMemo(() => {
    if (!extractorDisplayNode)
      return []

    return [
      {
        id: `${startNode.id}-${extractorDisplayNode.id}`,
        source: startNode.id,
        sourceHandle: 'source',
        target: extractorDisplayNode.id,
        targetHandle: 'target',
        type: 'custom',
        selectable: false,
        data: {
          sourceType: BlockEnum.Start,
          targetType: BlockEnum.LLM,
          _isTemp: true,
          _isSubGraphTemp: true,
        },
      },
    ]
  }, [extractorDisplayNode, startNode])

  const { nodes, edges } = useSubGraphNodes(nodesSource, edgesSource)

  return (
    <WorkflowWithDefaultContext
      nodes={nodes}
      edges={edges}
    >
      <SubGraphMain
        nodes={nodes}
        edges={edges}
        viewport={defaultViewport}
        agentName={agentName}
        extractorNodeId={`${toolNodeId}_ext_${paramKey}`}
        configsMap={configsMap}
        mentionConfig={mentionConfig}
        onMentionConfigChange={onMentionConfigChange}
        onSave={onSave}
        onSyncWorkflowDraft={onSyncWorkflowDraft}
      />
    </WorkflowWithDefaultContext>
  )
}

const SubGraph: FC<SubGraphProps> = (props) => {
  return (
    <WorkflowContextProvider
      injectWorkflowStoreSliceFn={createSubGraphSlice as InjectWorkflowStoreSliceFn}
    >
      <SubGraphContent {...props} />
    </WorkflowContextProvider>
  )
}

export default memo(SubGraph)
12	web/app/components/sub-graph/store/index.ts	Normal file
@@ -0,0 +1,12 @@
import type { CreateSubGraphSlice, SubGraphSliceShape } from '../types'

const initialState: Omit<SubGraphSliceShape, 'setParentAvailableVars' | 'setParentAvailableNodes'> = {
  parentAvailableVars: [],
  parentAvailableNodes: [],
}

export const createSubGraphSlice: CreateSubGraphSlice = set => ({
  ...initialState,
  setParentAvailableVars: vars => set(() => ({ parentAvailableVars: vars })),
  setParentAvailableNodes: nodes => set(() => ({ parentAvailableNodes: nodes })),
})
42	web/app/components/sub-graph/types.ts	Normal file
@@ -0,0 +1,42 @@
import type { StateCreator } from 'zustand'
import type { Shape as HooksStoreShape } from '@/app/components/workflow/hooks-store'
import type { MentionConfig } from '@/app/components/workflow/nodes/_base/types'
import type { LLMNodeType } from '@/app/components/workflow/nodes/llm/types'
import type { Edge, Node, NodeOutPutVar, ValueSelector } from '@/app/components/workflow/types'

export type SyncWorkflowDraftCallback = {
  onSuccess?: () => void
  onError?: () => void
  onSettled?: () => void
}

export type SyncWorkflowDraft = (
  notRefreshWhenSyncError?: boolean,
  callback?: SyncWorkflowDraftCallback,
) => Promise<void>

export type SubGraphProps = {
  toolNodeId: string
  paramKey: string
  sourceVariable: ValueSelector
  agentNodeId: string
  agentName: string
  configsMap?: HooksStoreShape['configsMap']
  mentionConfig: MentionConfig
  onMentionConfigChange: (config: MentionConfig) => void
  extractorNode?: Node<LLMNodeType>
  toolParamValue?: string
  parentAvailableNodes?: Node[]
  parentAvailableVars?: NodeOutPutVar[]
  onSave?: (nodes: Node[], edges: Edge[]) => void
  onSyncWorkflowDraft?: SyncWorkflowDraft
}

export type SubGraphSliceShape = {
  parentAvailableVars: NodeOutPutVar[]
  parentAvailableNodes: Node[]
  setParentAvailableVars: (vars: NodeOutPutVar[]) => void
  setParentAvailableNodes: (nodes: Node[]) => void
}

export type CreateSubGraphSlice = StateCreator<SubGraphSliceShape>
@@ -1,6 +1,7 @@
import type { FC } from 'react'
import { memo } from 'react'
import AppIcon from '@/app/components/base/app-icon'
import { Folder as FolderLine } from '@/app/components/base/icons/src/vender/line/files'
import {
  Agent,
  Answer,
@@ -54,6 +55,7 @@ const DEFAULT_ICON_MAP: Record<BlockEnum, React.ComponentType<{ className: strin
  [BlockEnum.TemplateTransform]: TemplatingTransform,
  [BlockEnum.VariableAssigner]: VariableX,
  [BlockEnum.VariableAggregator]: VariableX,
  [BlockEnum.Group]: FolderLine,
  [BlockEnum.Assigner]: Assigner,
  [BlockEnum.Tool]: VariableX,
  [BlockEnum.IterationStart]: VariableX,
@@ -97,6 +99,7 @@ const ICON_CONTAINER_BG_COLOR_MAP: Record<string, string> = {
  [BlockEnum.VariableAssigner]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.VariableAggregator]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.Tool]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.Group]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.Assigner]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.ParameterExtractor]: 'bg-util-colors-blue-blue-500',
  [BlockEnum.DocExtractor]: 'bg-util-colors-green-green-500',
@@ -25,7 +25,8 @@ import {
  useAvailableBlocks,
  useNodesInteractions,
} from './hooks'
import { NodeRunningStatus } from './types'
import { useHooksStore } from './hooks-store'
import { BlockEnum, NodeRunningStatus } from './types'
import { getEdgeColor } from './utils'

const CustomEdge = ({
@@ -56,6 +57,8 @@ const CustomEdge = ({
  })
  const [open, setOpen] = useState(false)
  const { handleNodeAdd } = useNodesInteractions()
  const interactionMode = useHooksStore(s => s.interactionMode)
  const allowGraphActions = interactionMode !== 'subgraph'
  const { availablePrevBlocks } = useAvailableBlocks((data as Edge['data'])!.targetType, (data as Edge['data'])?.isInIteration || (data as Edge['data'])?.isInLoop)
  const { availableNextBlocks } = useAvailableBlocks((data as Edge['data'])!.sourceType, (data as Edge['data'])?.isInIteration || (data as Edge['data'])?.isInLoop)
  const {
@@ -136,35 +139,37 @@ const CustomEdge = ({
          stroke,
          strokeWidth: 2,
          opacity: data._dimmed ? 0.3 : (data._waitingRun ? 0.7 : 1),
          strokeDasharray: data._isTemp ? '8 8' : undefined,
          strokeDasharray: (data._isTemp && !data._isSubGraphTemp && data.sourceType !== BlockEnum.Group && data.targetType !== BlockEnum.Group) ? '8 8' : undefined,
        }}
      />
      <EdgeLabelRenderer>
        <div
          className={cn(
            'nopan nodrag hover:scale-125',
            data?._hovering ? 'block' : 'hidden',
            open && '!block',
            data.isInIteration && `z-[${ITERATION_CHILDREN_Z_INDEX}]`,
            data.isInLoop && `z-[${LOOP_CHILDREN_Z_INDEX}]`,
          )}
          style={{
            position: 'absolute',
            transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
            pointerEvents: 'all',
            opacity: data._waitingRun ? 0.7 : 1,
          }}
        >
          <BlockSelector
            open={open}
            onOpenChange={handleOpenChange}
            asChild
            onSelect={handleInsert}
            availableBlocksTypes={intersection(availablePrevBlocks, availableNextBlocks)}
            triggerClassName={() => 'hover:scale-150 transition-all'}
          />
        </div>
      </EdgeLabelRenderer>
      {allowGraphActions && (
        <EdgeLabelRenderer>
          <div
            className={cn(
              'nopan nodrag hover:scale-125',
              data?._hovering ? 'block' : 'hidden',
              open && '!block',
              data.isInIteration && `z-[${ITERATION_CHILDREN_Z_INDEX}]`,
              data.isInLoop && `z-[${LOOP_CHILDREN_Z_INDEX}]`,
            )}
            style={{
              position: 'absolute',
              transform: `translate(-50%, -50%) translate(${labelX}px, ${labelY}px)`,
              pointerEvents: 'all',
              opacity: data._waitingRun ? 0.7 : 1,
            }}
          >
            <BlockSelector
              open={open}
              onOpenChange={handleOpenChange}
              asChild
              onSelect={handleInsert}
              availableBlocksTypes={intersection(availablePrevBlocks, availableNextBlocks)}
              triggerClassName={() => 'hover:scale-150 transition-all'}
            />
          </div>
        </EdgeLabelRenderer>
      )}
    </>
  )
}

11	web/app/components/workflow/custom-group-node/constants.ts	Normal file
@@ -0,0 +1,11 @@
export const CUSTOM_GROUP_NODE = 'custom-group'
export const CUSTOM_GROUP_INPUT_NODE = 'custom-group-input'
export const CUSTOM_GROUP_EXIT_PORT_NODE = 'custom-group-exit-port'

export const GROUP_CHILDREN_Z_INDEX = 1002

export const UI_ONLY_GROUP_NODE_TYPES = new Set([
  CUSTOM_GROUP_NODE,
  CUSTOM_GROUP_INPUT_NODE,
  CUSTOM_GROUP_EXIT_PORT_NODE,
])
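A plausible use of `UI_ONLY_GROUP_NODE_TYPES` is filtering display-only group nodes out of a graph before persisting it; the `stripUiOnlyNodes` helper below is an illustrative sketch, not code from this changeset:

import { UI_ONLY_GROUP_NODE_TYPES } from '@/app/components/workflow/custom-group-node/constants'

type MinimalNode = { id: string, type?: string }

// Drop group wrapper/input/exit-port nodes, which exist only for rendering.
function stripUiOnlyNodes<T extends MinimalNode>(nodes: T[]): T[] {
  return nodes.filter(node => !UI_ONLY_GROUP_NODE_TYPES.has(node.type ?? ''))
}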
|
||||
@@ -0,0 +1,54 @@
'use client'

import type { FC } from 'react'
import type { CustomGroupExitPortNodeData } from './types'
import { memo } from 'react'
import { Handle, Position } from 'reactflow'
import { cn } from '@/utils/classnames'

type CustomGroupExitPortNodeProps = {
  id: string
  data: CustomGroupExitPortNodeData
}

const CustomGroupExitPortNode: FC<CustomGroupExitPortNodeProps> = ({ id: _id, data }) => {
  return (
    <div
      className={cn(
        'flex items-center justify-center',
        'h-8 w-8 rounded-full',
        'bg-util-colors-green-green-500 shadow-md',
        data.selected && 'ring-2 ring-primary-400',
      )}
    >
      {/* Target handle - receives internal connections from leaf nodes */}
      <Handle
        id="target"
        type="target"
        position={Position.Left}
        className="!h-2 !w-2 !border-0 !bg-white"
      />

      {/* Source handle - connects to external nodes */}
      <Handle
        id="source"
        type="source"
        position={Position.Right}
        className="!h-2 !w-2 !border-0 !bg-white"
      />

      {/* Icon */}
      <svg
        className="h-4 w-4 text-white"
        viewBox="0 0 24 24"
        fill="none"
        stroke="currentColor"
        strokeWidth={2}
      >
        <path d="M5 12h14M12 5l7 7-7 7" />
      </svg>
    </div>
  )
}

export default memo(CustomGroupExitPortNode)
@@ -0,0 +1,55 @@
'use client'

import type { FC } from 'react'
import type { CustomGroupInputNodeData } from './types'
import { memo } from 'react'
import { Handle, Position } from 'reactflow'
import { cn } from '@/utils/classnames'

type CustomGroupInputNodeProps = {
  id: string
  data: CustomGroupInputNodeData
}

const CustomGroupInputNode: FC<CustomGroupInputNodeProps> = ({ id: _id, data }) => {
  return (
    <div
      className={cn(
        'flex items-center justify-center',
        'h-8 w-8 rounded-full',
        'bg-util-colors-blue-blue-500 shadow-md',
        data.selected && 'ring-2 ring-primary-400',
      )}
    >
      {/* Target handle - receives external connections */}
      <Handle
        id="target"
        type="target"
        position={Position.Left}
        className="!h-2 !w-2 !border-0 !bg-white"
      />

      {/* Source handle - connects to entry nodes */}
      <Handle
        id="source"
        type="source"
        position={Position.Right}
        className="!h-2 !w-2 !border-0 !bg-white"
      />

      {/* Icon */}
      <svg
        className="h-4 w-4 text-white"
        viewBox="0 0 24 24"
        fill="none"
        stroke="currentColor"
        strokeWidth={2}
      >
        <path d="M9 12l2 2 4-4" />
        <circle cx="12" cy="12" r="10" />
      </svg>
    </div>
  )
}

export default memo(CustomGroupInputNode)
@@ -0,0 +1,94 @@
'use client'

import type { FC } from 'react'
import type { CustomGroupNodeData } from './types'
import { memo } from 'react'
import { Handle, Position } from 'reactflow'
import { Plus02 } from '@/app/components/base/icons/src/vender/line/general'
import { cn } from '@/utils/classnames'

type CustomGroupNodeProps = {
  id: string
  data: CustomGroupNodeData
}

const CustomGroupNode: FC<CustomGroupNodeProps> = ({ data }) => {
  const { group } = data
  const exitPorts = group.exitPorts ?? []
  const connectedSourceHandleIds = data._connectedSourceHandleIds ?? []

  return (
    <div
      className={cn(
        'bg-workflow-block-parma-bg/50 group relative rounded-2xl border-2 border-dashed border-components-panel-border',
        data.selected && 'border-primary-400',
      )}
      style={{
        width: data.width || 280,
        height: data.height || 200,
      }}
    >
      {/* Group Header */}
      <div className="absolute -top-7 left-0 flex items-center gap-1 px-2">
        <span className="text-xs font-medium text-text-tertiary">
          {group.title}
        </span>
      </div>

      {/* Target handle for incoming connections */}
      <Handle
        id="target"
        type="target"
        position={Position.Left}
        className={cn(
          '!h-4 !w-4 !rounded-none !border-none !bg-transparent !outline-none',
          'after:absolute after:left-1.5 after:top-1 after:h-2 after:w-0.5 after:bg-workflow-link-line-handle',
          'transition-all hover:scale-125',
        )}
        style={{ top: '50%' }}
      />

      <div className="px-3 pt-3">
        {exitPorts.map((port, index) => {
          const connected = connectedSourceHandleIds.includes(port.portNodeId)

          return (
            <div key={port.portNodeId} className="relative flex h-6 items-center px-1">
              <div className="w-full text-right text-xs font-semibold text-text-secondary">
                {port.name}
              </div>

              <Handle
                id={port.portNodeId}
                type="source"
                position={Position.Right}
                className={cn(
                  'group/handle z-[1] !h-4 !w-4 !rounded-none !border-none !bg-transparent !outline-none',
                  'after:absolute after:right-1.5 after:top-1 after:h-2 after:w-0.5 after:bg-workflow-link-line-handle',
                  'transition-all hover:scale-125',
                  !connected && 'after:opacity-0',
                  '!-right-[21px] !top-1/2 !-translate-y-1/2',
                )}
                isConnectable
              />

              {/* Visual "+" indicator (styling aligned with existing branch handles) */}
              <div
                className={cn(
                  'pointer-events-none absolute z-10 hidden h-4 w-4 items-center justify-center rounded-full bg-components-button-primary-bg text-text-primary-on-surface',
                  '-right-[21px] top-1/2 -translate-y-1/2',
                  'group-hover:flex',
                  data.selected && '!flex',
                )}
              >
                <Plus02 className="h-2.5 w-2.5" />
              </div>
            </div>
          )
        })}
      </div>
    </div>
  )
}

export default memo(CustomGroupNode)
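
Side note (not part of the diff): each exit port above renders a source Handle keyed by its portNodeId, so `_connectedSourceHandleIds` can in principle be derived from the edge list. A sketch under that assumption; the `MiniEdge` shape and function are illustrative, not the repo's actual wiring:

// Hypothetical: collect which exit-port handles of a group currently have edges.
type MiniEdge = { source: string, sourceHandle?: string | null }

function connectedPortIds(groupNodeId: string, edges: MiniEdge[]): string[] {
  return edges
    .filter(edge => edge.source === groupNodeId && !!edge.sourceHandle)
    .map(edge => edge.sourceHandle as string)
}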
web/app/components/workflow/custom-group-node/index.ts (new file)
@@ -0,0 +1,19 @@
export {
  CUSTOM_GROUP_EXIT_PORT_NODE,
  CUSTOM_GROUP_INPUT_NODE,
  CUSTOM_GROUP_NODE,
  GROUP_CHILDREN_Z_INDEX,
  UI_ONLY_GROUP_NODE_TYPES,
} from './constants'

export { default as CustomGroupExitPortNode } from './custom-group-exit-port-node'

export { default as CustomGroupInputNode } from './custom-group-input-node'
export { default as CustomGroupNode } from './custom-group-node'
export type {
  CustomGroupExitPortNodeData,
  CustomGroupInputNodeData,
  CustomGroupNodeData,
  ExitPortInfo,
  GroupMember,
} from './types'
web/app/components/workflow/custom-group-node/types.ts (new file)
@@ -0,0 +1,82 @@
import type { BlockEnum } from '../types'

/**
 * Exit port info stored in Group node
 */
export type ExitPortInfo = {
  portNodeId: string
  leafNodeId: string
  sourceHandle: string
  name: string
}

/**
 * Group node data structure
 * node.type = 'custom-group'
 * node.data.type = '' (empty string to bypass backend NodeType validation)
 */
export type CustomGroupNodeData = {
  type: '' // Empty string bypasses backend NodeType validation
  title: string
  desc?: string
  _connectedSourceHandleIds?: string[]
  _connectedTargetHandleIds?: string[]
  group: {
    groupId: string
    title: string
    memberNodeIds: string[]
    entryNodeIds: string[]
    inputNodeId: string
    exitPorts: ExitPortInfo[]
    collapsed: boolean
  }
  width?: number
  height?: number
  selected?: boolean
  _isTempNode?: boolean
}

/**
 * Group Input node data structure
 * node.type = 'custom-group-input'
 * node.data.type = ''
 */
export type CustomGroupInputNodeData = {
  type: ''
  title: string
  desc?: string
  groupInput: {
    groupId: string
    title: string
  }
  selected?: boolean
  _isTempNode?: boolean
}

/**
 * Exit Port node data structure
 * node.type = 'custom-group-exit-port'
 * node.data.type = ''
 */
export type CustomGroupExitPortNodeData = {
  type: ''
  title: string
  desc?: string
  exitPort: {
    groupId: string
    leafNodeId: string
    sourceHandle: string
    name: string
  }
  selected?: boolean
  _isTempNode?: boolean
}

/**
 * Member node info for display
 */
export type GroupMember = {
  id: string
  type: BlockEnum
  label?: string
}
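
Side note (not part of the diff): a made-up CustomGroupNodeData literal, just to show how the pieces above fit together; every id and title is illustrative:

import type { CustomGroupNodeData } from './types'

const exampleGroup: CustomGroupNodeData = {
  type: '', // empty on purpose — see the comment in the type definition
  title: 'Retrieval branch',
  group: {
    groupId: 'group_1',
    title: 'Retrieval branch',
    memberNodeIds: ['retriever_1', 'rerank_1'],
    entryNodeIds: ['retriever_1'],
    inputNodeId: 'group_1_input',
    exitPorts: [
      { portNodeId: 'group_1_exit_0', leafNodeId: 'rerank_1', sourceHandle: 'source', name: 'rerank_1' },
    ],
    collapsed: false,
  },
}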
@@ -23,6 +23,7 @@ export type AvailableNodesMetaData = {
  nodesMap?: Record<BlockEnum, NodeDefault<any>>
}
export type CommonHooksFnMap = {
  interactionMode?: 'default' | 'subgraph'
  doSyncWorkflowDraft: (
    notRefreshWhenSyncError?: boolean,
    callback?: {
@@ -76,6 +77,7 @@ export type Shape = {
} & CommonHooksFnMap

export const createHooksStore = ({
  interactionMode = 'default',
  doSyncWorkflowDraft = async () => noop(),
  syncWorkflowDraftWhenPageClose = noop,
  handleRefreshWorkflowDraft = noop,
@@ -118,6 +120,7 @@ export const createHooksStore = ({
}: Partial<Shape>) => {
  return createStore<Shape>(set => ({
    refreshAll: props => set(state => ({ ...state, ...props })),
    interactionMode,
    doSyncWorkflowDraft,
    syncWorkflowDraftWhenPageClose,
    handleRefreshWorkflowDraft,
@@ -197,7 +197,8 @@ export const useChecklist = (nodes: Node[], edges: Edge[]) => {
      // Start nodes and Trigger nodes should not show unConnected error if they have validation errors
      // or if they are valid start nodes (even without incoming connections)
      const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
      const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
      const isSubGraphNode = Boolean((node.data as { parent_node_id?: string }).parent_node_id)
      const canSkipConnectionCheck = isSubGraphNode || (shouldCheckStartNode ? isStartNodeMeta : true)

      const isUnconnected = !validNodes.find(n => n.id === node.id)
      const shouldShowError = errorMessage || (isUnconnected && !canSkipConnectionCheck)
@@ -390,7 +391,8 @@ export const useChecklistBeforePublish = () => {
    }

    const isStartNodeMeta = nodesExtraData?.[node.data.type as BlockEnum]?.metaData.isStart ?? false
    const canSkipConnectionCheck = shouldCheckStartNode ? isStartNodeMeta : true
    const isSubGraphNode = Boolean((node.data as { parent_node_id?: string }).parent_node_id)
    const canSkipConnectionCheck = isSubGraphNode || (shouldCheckStartNode ? isStartNodeMeta : true)
    const isUnconnected = !validNodes.find(n => n.id === node.id)

    if (isUnconnected && !canSkipConnectionCheck) {
@@ -10,6 +10,7 @@ import { useCallback } from 'react'
import {
  useStoreApi,
} from 'reactflow'
import { BlockEnum } from '../types'
import { getNodesConnectedSourceOrTargetHandleIdsMap } from '../utils'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import { useNodesReadOnly } from './use-workflow'
@@ -108,6 +109,50 @@ export const useEdgesInteractions = () => {
      return
    const currentEdge = edges[currentEdgeIndex]
    const nodes = getNodes()

    // collect edges to delete (including corresponding real edges for temp edges)
    const edgesToDelete: Set<string> = new Set([currentEdge.id])

    // if deleting a temp edge connected to a group, also delete the corresponding real hidden edge
    if (currentEdge.data?._isTemp) {
      const groupNode = nodes.find(n =>
        n.data.type === BlockEnum.Group
        && (n.id === currentEdge.source || n.id === currentEdge.target),
      )

      if (groupNode) {
        const memberIds = new Set((groupNode.data.members || []).map((m: { id: string }) => m.id))

        if (currentEdge.target === groupNode.id) {
          // inbound temp edge: find real edge with same source, target is a head node
          edges.forEach((edge) => {
            if (edge.source === currentEdge.source
              && memberIds.has(edge.target)
              && edge.sourceHandle === currentEdge.sourceHandle) {
              edgesToDelete.add(edge.id)
            }
          })
        }
        else if (currentEdge.source === groupNode.id) {
          // outbound temp edge: sourceHandle format is "leafNodeId-originalHandle"
          const sourceHandle = currentEdge.sourceHandle || ''
          const lastDashIndex = sourceHandle.lastIndexOf('-')
          if (lastDashIndex > 0) {
            const leafNodeId = sourceHandle.substring(0, lastDashIndex)
            const originalHandle = sourceHandle.substring(lastDashIndex + 1)

            edges.forEach((edge) => {
              if (edge.source === leafNodeId
                && edge.target === currentEdge.target
                && (edge.sourceHandle || 'source') === originalHandle) {
                edgesToDelete.add(edge.id)
              }
            })
          }
        }
      }
    }

    const nodesConnectedSourceOrTargetHandleIdsMap = getNodesConnectedSourceOrTargetHandleIdsMap(
      [
        { type: 'remove', edge: currentEdge },
@@ -126,7 +171,10 @@ export const useEdgesInteractions = () => {
    })
    setNodes(newNodes)
    const newEdges = produce(edges, (draft) => {
      draft.splice(currentEdgeIndex, 1)
      for (let i = draft.length - 1; i >= 0; i--) {
        if (edgesToDelete.has(draft[i].id))
          draft.splice(i, 1)
      }
    })
    setEdges(newEdges)
    handleSyncWorkflowDraft()
web/app/components/workflow/hooks/use-make-group.ts (new file)
@@ -0,0 +1,138 @@
import type { PredecessorHandle } from '../utils'
import { useMemo } from 'react'
import { useStore as useReactFlowStore } from 'reactflow'
import { shallow } from 'zustand/shallow'
import { BlockEnum } from '../types'
import { getCommonPredecessorHandles } from '../utils'

export type MakeGroupAvailability = {
  canMakeGroup: boolean
  branchEntryNodeIds: string[]
  commonPredecessorHandle?: PredecessorHandle
}

type MinimalEdge = {
  id: string
  source: string
  sourceHandle: string
  target: string
}

/**
 * Pure function to check if the selected nodes can be grouped.
 * Can be called both from React hooks and imperatively.
 */
export const checkMakeGroupAvailability = (
  selectedNodeIds: string[],
  edges: MinimalEdge[],
  hasGroupNode = false,
): MakeGroupAvailability => {
  if (selectedNodeIds.length <= 1 || hasGroupNode) {
    return {
      canMakeGroup: false,
      branchEntryNodeIds: [],
      commonPredecessorHandle: undefined,
    }
  }

  const selectedNodeIdSet = new Set(selectedNodeIds)
  const inboundFromOutsideTargets = new Set<string>()
  const incomingEdgeCounts = new Map<string, number>()
  const incomingFromSelectedTargets = new Set<string>()

  edges.forEach((edge) => {
    // Only consider edges whose target is inside the selected subgraph.
    if (!selectedNodeIdSet.has(edge.target))
      return

    incomingEdgeCounts.set(edge.target, (incomingEdgeCounts.get(edge.target) ?? 0) + 1)

    if (selectedNodeIdSet.has(edge.source))
      incomingFromSelectedTargets.add(edge.target)
    else
      inboundFromOutsideTargets.add(edge.target)
  })

  // Branch head (entry) definition:
  // - has at least one incoming edge
  // - and all its incoming edges come from outside the selected subgraph
  const branchEntryNodeIds = selectedNodeIds.filter((nodeId) => {
    const incomingEdgeCount = incomingEdgeCounts.get(nodeId) ?? 0
    if (incomingEdgeCount === 0)
      return false

    return !incomingFromSelectedTargets.has(nodeId)
  })

  // No branch head means we cannot tell how many branches are represented by this selection.
  if (branchEntryNodeIds.length === 0) {
    return {
      canMakeGroup: false,
      branchEntryNodeIds,
      commonPredecessorHandle: undefined,
    }
  }

  // Guardrail: disallow side entrances into the selected subgraph.
  // If an outside node connects to a non-entry node inside the selection, the grouping boundary is ambiguous.
  const branchEntryNodeIdSet = new Set(branchEntryNodeIds)
  const hasInboundToNonEntryNode = Array.from(inboundFromOutsideTargets).some(nodeId => !branchEntryNodeIdSet.has(nodeId))

  if (hasInboundToNonEntryNode) {
    return {
      canMakeGroup: false,
      branchEntryNodeIds,
      commonPredecessorHandle: undefined,
    }
  }

  // Compare the branch heads by their common predecessor "handler" (source node + sourceHandle).
  // This is required for multi-handle nodes like If-Else / Classifier where different branches use different handles.
  const commonPredecessorHandles = getCommonPredecessorHandles(
    branchEntryNodeIds,
    // Only look at edges coming from outside the selected subgraph when determining the "pre" handler.
    edges.filter(edge => !selectedNodeIdSet.has(edge.source)),
  )

  if (commonPredecessorHandles.length !== 1) {
    return {
      canMakeGroup: false,
      branchEntryNodeIds,
      commonPredecessorHandle: undefined,
    }
  }

  return {
    canMakeGroup: true,
    branchEntryNodeIds,
    commonPredecessorHandle: commonPredecessorHandles[0],
  }
}

export const useMakeGroupAvailability = (selectedNodeIds: string[]): MakeGroupAvailability => {
  const edgeKeys = useReactFlowStore((state) => {
    const delimiter = '\u0000'
    const keys = state.edges.map(edge => `${edge.source}${delimiter}${edge.sourceHandle || 'source'}${delimiter}${edge.target}`)
    keys.sort()
    return keys
  }, shallow)

  const hasGroupNode = useReactFlowStore((state) => {
    return state.getNodes().some(node => node.selected && node.data.type === BlockEnum.Group)
  })

  return useMemo(() => {
    const delimiter = '\u0000'
    const edges = edgeKeys.map((key) => {
      const [source, handleId, target] = key.split(delimiter)
      return {
        id: key,
        source,
        sourceHandle: handleId || 'source',
        target,
      }
    })

    return checkMakeGroupAvailability(selectedNodeIds, edges, hasGroupNode)
  }, [edgeKeys, selectedNodeIds, hasGroupNode])
}
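
Side note (not part of the diff): a usage sketch for the pure checker, assuming getCommonPredecessorHandles resolves a single (if_else_1, 'true') handle for this input; all ids are made up:

import { checkMakeGroupAvailability } from './use-make-group'

// A true-branch pair selected behind an If-Else node.
const edges = [
  { id: 'e1', source: 'if_else_1', sourceHandle: 'true', target: 'llm_1' },
  { id: 'e2', source: 'llm_1', sourceHandle: 'source', target: 'code_1' },
]

const result = checkMakeGroupAvailability(['llm_1', 'code_1'], edges)
// result.branchEntryNodeIds === ['llm_1']: the only selected node whose incoming
// edges all come from outside the selection; code_1 is fed from inside, so it is
// not an entry, and there is no side entrance, so grouping is allowed.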
@@ -8,6 +8,7 @@ import type {
  ResizeParamsWithDirection,
} from 'reactflow'
import type { PluginDefaultValue } from '../block-selector/types'
import type { GroupHandler, GroupMember, GroupNodeData } from '../nodes/group/types'
import type { IterationNodeType } from '../nodes/iteration/types'
import type { LoopNodeType } from '../nodes/loop/types'
import type { VariableAssignerNodeType } from '../nodes/variable-assigner/types'
@@ -52,6 +53,7 @@ import { useWorkflowHistoryStore } from '../workflow-history-store'
import { useAutoGenerateWebhookUrl } from './use-auto-generate-webhook-url'
import { useHelpline } from './use-helpline'
import useInspectVarsCrud from './use-inspect-vars-crud'
import { checkMakeGroupAvailability } from './use-make-group'
import { useNodesMetaData } from './use-nodes-meta-data'
import { useNodesSyncDraft } from './use-nodes-sync-draft'
import {
@@ -73,6 +75,151 @@ const ENTRY_NODE_WRAPPER_OFFSET = {
  y: 21, // Adjusted based on visual testing feedback
} as const

/**
 * Parse group handler id to get original node id and sourceHandle
 * Handler id format: `${nodeId}-${sourceHandle}`
 */
function parseGroupHandlerId(handlerId: string): { originalNodeId: string, originalSourceHandle: string } {
  const lastDashIndex = handlerId.lastIndexOf('-')
  return {
    originalNodeId: handlerId.substring(0, lastDashIndex),
    originalSourceHandle: handlerId.substring(lastDashIndex + 1),
  }
}

/**
 * Create a pair of edges for group node connections:
 * - realEdge: hidden edge from original node to target (persisted to backend)
 * - uiEdge: visible temp edge from group to target (UI-only, not persisted)
 */
function createGroupEdgePair(params: {
  groupNodeId: string
  handlerId: string
  targetNodeId: string
  targetHandle: string
  nodes: Node[]
  baseEdgeData?: Partial<Edge['data']>
  zIndex?: number
}): { realEdge: Edge, uiEdge: Edge } | null {
  const { groupNodeId, handlerId, targetNodeId, targetHandle, nodes, baseEdgeData = {}, zIndex = 0 } = params

  const groupNode = nodes.find(node => node.id === groupNodeId)
  const groupData = groupNode?.data as GroupNodeData | undefined
  const handler = groupData?.handlers?.find(h => h.id === handlerId)

  let originalNodeId: string
  let originalSourceHandle: string

  if (handler?.nodeId && handler?.sourceHandle) {
    originalNodeId = handler.nodeId
    originalSourceHandle = handler.sourceHandle
  }
  else {
    const parsed = parseGroupHandlerId(handlerId)
    originalNodeId = parsed.originalNodeId
    originalSourceHandle = parsed.originalSourceHandle
  }

  const originalNode = nodes.find(node => node.id === originalNodeId)
  const targetNode = nodes.find(node => node.id === targetNodeId)

  if (!originalNode || !targetNode)
    return null

  // Create the real edge (from original node to target) - hidden because original node is in group
  const realEdge: Edge = {
    id: `${originalNodeId}-${originalSourceHandle}-${targetNodeId}-${targetHandle}`,
    type: CUSTOM_EDGE,
    source: originalNodeId,
    sourceHandle: originalSourceHandle,
    target: targetNodeId,
    targetHandle,
    hidden: true,
    data: {
      ...baseEdgeData,
      sourceType: originalNode.data.type,
      targetType: targetNode.data.type,
      _hiddenInGroupId: groupNodeId,
    },
    zIndex,
  }

  // Create the UI edge (from group to target) - temporary, not persisted to backend
  const uiEdge: Edge = {
    id: `${groupNodeId}-${handlerId}-${targetNodeId}-${targetHandle}`,
    type: CUSTOM_EDGE,
    source: groupNodeId,
    sourceHandle: handlerId,
    target: targetNodeId,
    targetHandle,
    data: {
      ...baseEdgeData,
      sourceType: BlockEnum.Group,
      targetType: targetNode.data.type,
      _isTemp: true,
    },
    zIndex,
  }

  return { realEdge, uiEdge }
}

function createGroupInboundEdges(params: {
  sourceNodeId: string
  sourceHandle: string
  groupNodeId: string
  groupData: GroupNodeData
  nodes: Node[]
  baseEdgeData?: Partial<Edge['data']>
  zIndex?: number
}): { realEdges: Edge[], uiEdge: Edge } | null {
  const { sourceNodeId, sourceHandle, groupNodeId, groupData, nodes, baseEdgeData = {}, zIndex = 0 } = params

  const sourceNode = nodes.find(node => node.id === sourceNodeId)
  const headNodeIds = groupData.headNodeIds || []

  if (!sourceNode || headNodeIds.length === 0)
    return null

  const realEdges: Edge[] = headNodeIds.map((headNodeId) => {
    const headNode = nodes.find(node => node.id === headNodeId)
    return {
      id: `${sourceNodeId}-${sourceHandle}-${headNodeId}-target`,
      type: CUSTOM_EDGE,
      source: sourceNodeId,
      sourceHandle,
      target: headNodeId,
      targetHandle: 'target',
      hidden: true,
      data: {
        ...baseEdgeData,
        sourceType: sourceNode.data.type,
        targetType: headNode?.data.type,
        _hiddenInGroupId: groupNodeId,
      },
      zIndex,
    } as Edge
  })

  const uiEdge: Edge = {
    id: `${sourceNodeId}-${sourceHandle}-${groupNodeId}-target`,
    type: CUSTOM_EDGE,
    source: sourceNodeId,
    sourceHandle,
    target: groupNodeId,
    targetHandle: 'target',
    data: {
      ...baseEdgeData,
      sourceType: sourceNode.data.type,
      targetType: BlockEnum.Group,
      _isTemp: true,
    },
    zIndex,
  }

  return { realEdges, uiEdge }
}

export const useNodesInteractions = () => {
  const { t } = useTranslation()
  const store = useStoreApi()
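
Side note (not part of the diff): the handler-id round trip splits at the LAST dash, which matters because node ids themselves may contain dashes while handle ids (e.g. 'source', 'true', 'false') do not. Illustrated with made-up ids (parseGroupHandlerId is module-private, so this is notation rather than a call site):

// parseGroupHandlerId('node-123-true')
//   → { originalNodeId: 'node-123', originalSourceHandle: 'true' }
// createGroupEdgePair prefers the handler's stored nodeId/sourceHandle and only
// falls back to this string split when the handler record is missing.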
@@ -448,6 +595,146 @@ export const useNodesInteractions = () => {
      return
    }

    // Check if source is a group node - need special handling
    const isSourceGroup = sourceNode?.data.type === BlockEnum.Group

    if (isSourceGroup && sourceHandle && target && targetHandle) {
      const { originalNodeId, originalSourceHandle } = parseGroupHandlerId(sourceHandle)

      // Check if real edge already exists
      if (edges.find(edge =>
        edge.source === originalNodeId
        && edge.sourceHandle === originalSourceHandle
        && edge.target === target
        && edge.targetHandle === targetHandle,
      )) {
        return
      }

      const parentNode = nodes.find(node => node.id === targetNode?.parentId)
      const isInIteration = parentNode && parentNode.data.type === BlockEnum.Iteration
      const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop

      const edgePair = createGroupEdgePair({
        groupNodeId: source!,
        handlerId: sourceHandle,
        targetNodeId: target,
        targetHandle,
        nodes,
        baseEdgeData: {
          isInIteration,
          iteration_id: isInIteration ? targetNode?.parentId : undefined,
          isInLoop,
          loop_id: isInLoop ? targetNode?.parentId : undefined,
        },
      })

      if (!edgePair)
        return

      const { realEdge, uiEdge } = edgePair

      // Update connected handle ids for the original node
      const nodesConnectedSourceOrTargetHandleIdsMap
        = getNodesConnectedSourceOrTargetHandleIdsMap(
          [{ type: 'add', edge: realEdge }],
          nodes,
        )
      const newNodes = produce(nodes, (draft: Node[]) => {
        draft.forEach((node) => {
          if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
            node.data = {
              ...node.data,
              ...nodesConnectedSourceOrTargetHandleIdsMap[node.id],
            }
          }
        })
      })
      const newEdges = produce(edges, (draft) => {
        draft.push(realEdge)
        draft.push(uiEdge)
      })

      setNodes(newNodes)
      setEdges(newEdges)

      handleSyncWorkflowDraft()
      saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
        nodeId: targetNode?.id,
      })
      return
    }

    const isTargetGroup = targetNode?.data.type === BlockEnum.Group

    if (isTargetGroup && source && sourceHandle) {
      const groupData = targetNode.data as GroupNodeData
      const headNodeIds = groupData.headNodeIds || []

      if (edges.find(edge =>
        edge.source === source
        && edge.sourceHandle === sourceHandle
        && edge.target === target
        && edge.targetHandle === targetHandle,
      )) {
        return
      }

      const parentNode = nodes.find(node => node.id === sourceNode?.parentId)
      const isInIteration = parentNode && parentNode.data.type === BlockEnum.Iteration
      const isInLoop = !!parentNode && parentNode.data.type === BlockEnum.Loop

      const inboundResult = createGroupInboundEdges({
        sourceNodeId: source,
        sourceHandle,
        groupNodeId: target!,
        groupData,
        nodes,
        baseEdgeData: {
          isInIteration,
          iteration_id: isInIteration ? sourceNode?.parentId : undefined,
          isInLoop,
          loop_id: isInLoop ? sourceNode?.parentId : undefined,
        },
      })

      if (!inboundResult)
        return

      const { realEdges, uiEdge } = inboundResult

      const edgeChanges = realEdges.map(edge => ({ type: 'add' as const, edge }))
      const nodesConnectedSourceOrTargetHandleIdsMap
        = getNodesConnectedSourceOrTargetHandleIdsMap(edgeChanges, nodes)

      const newNodes = produce(nodes, (draft: Node[]) => {
        draft.forEach((node) => {
          if (nodesConnectedSourceOrTargetHandleIdsMap[node.id]) {
            node.data = {
              ...node.data,
              ...nodesConnectedSourceOrTargetHandleIdsMap[node.id],
            }
          }
        })
      })

      const newEdges = produce(edges, (draft) => {
        realEdges.forEach((edge) => {
          draft.push(edge)
        })
        draft.push(uiEdge)
      })

      setNodes(newNodes)
      setEdges(newEdges)

      handleSyncWorkflowDraft()
      saveStateToHistory(WorkflowHistoryEvent.NodeConnect, {
        nodeId: headNodeIds[0],
      })
      return
    }

    if (
      edges.find(
        edge =>
@@ -909,8 +1196,34 @@ export const useNodesInteractions = () => {
      }
    }

    let newEdge = null
    if (nodeType !== BlockEnum.DataSource) {
    // Check if prevNode is a group node - need special handling
    const isPrevNodeGroup = prevNode.data.type === BlockEnum.Group
    let newEdge: Edge | null = null
    let newUiEdge: Edge | null = null

    if (isPrevNodeGroup && prevNodeSourceHandle && nodeType !== BlockEnum.DataSource) {
      const edgePair = createGroupEdgePair({
        groupNodeId: prevNodeId,
        handlerId: prevNodeSourceHandle,
        targetNodeId: newNode.id,
        targetHandle,
        nodes: [...nodes, newNode],
        baseEdgeData: {
          isInIteration,
          isInLoop,
          iteration_id: isInIteration ? prevNode.parentId : undefined,
          loop_id: isInLoop ? prevNode.parentId : undefined,
          _connectedNodeIsSelected: true,
        },
      })

      if (edgePair) {
        newEdge = edgePair.realEdge
        newUiEdge = edgePair.uiEdge
      }
    }
    else if (nodeType !== BlockEnum.DataSource) {
      // Normal case: prevNode is not a group
      newEdge = {
        id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
        type: CUSTOM_EDGE,
@@ -935,9 +1248,10 @@ export const useNodesInteractions = () => {
      }
    }

    const edgesToAdd = [newEdge, newUiEdge].filter(Boolean).map(edge => ({ type: 'add' as const, edge: edge! }))
    const nodesConnectedSourceOrTargetHandleIdsMap
      = getNodesConnectedSourceOrTargetHandleIdsMap(
        (newEdge ? [{ type: 'add', edge: newEdge }] : []),
        edgesToAdd,
        nodes,
      )
    const newNodes = produce(nodes, (draft: Node[]) => {
@@ -1006,6 +1320,8 @@ export const useNodesInteractions = () => {
      })
      if (newEdge)
        draft.push(newEdge)
      if (newUiEdge)
        draft.push(newUiEdge)
    })

    setNodes(newNodes)
@@ -1090,7 +1406,7 @@ export const useNodesInteractions = () => {

    const afterNodesInSameBranch = getAfterNodesInSameBranch(nextNodeId!)
    const afterNodesInSameBranchIds = afterNodesInSameBranch.map(
      node => node.id,
      (node: Node) => node.id,
    )
    const newNodes = produce(nodes, (draft) => {
      draft.forEach((node) => {
@@ -1200,37 +1516,113 @@ export const useNodesInteractions = () => {
      }
    }

    const currentEdgeIndex = edges.findIndex(
      edge => edge.source === prevNodeId && edge.target === nextNodeId,
    )
    let newPrevEdge = null
    // Check if prevNode is a group node - need special handling
    const isPrevNodeGroup = prevNode.data.type === BlockEnum.Group
    let newPrevEdge: Edge | null = null
    let newPrevUiEdge: Edge | null = null
    const edgesToRemove: string[] = []

    if (nodeType !== BlockEnum.DataSource) {
      newPrevEdge = {
        id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
        type: CUSTOM_EDGE,
        source: prevNodeId,
        sourceHandle: prevNodeSourceHandle,
        target: newNode.id,
    if (isPrevNodeGroup && prevNodeSourceHandle && nodeType !== BlockEnum.DataSource) {
      const { originalNodeId, originalSourceHandle } = parseGroupHandlerId(prevNodeSourceHandle)

      // Find edges to remove: both hidden real edge and UI temp edge from group to nextNode
      const hiddenEdge = edges.find(
        edge => edge.source === originalNodeId
          && edge.sourceHandle === originalSourceHandle
          && edge.target === nextNodeId,
      )
      const uiTempEdge = edges.find(
        edge => edge.source === prevNodeId
          && edge.sourceHandle === prevNodeSourceHandle
          && edge.target === nextNodeId,
      )
      if (hiddenEdge)
        edgesToRemove.push(hiddenEdge.id)
      if (uiTempEdge)
        edgesToRemove.push(uiTempEdge.id)

      const edgePair = createGroupEdgePair({
        groupNodeId: prevNodeId,
        handlerId: prevNodeSourceHandle,
        targetNodeId: newNode.id,
        targetHandle,
        data: {
          sourceType: prevNode.data.type,
          targetType: newNode.data.type,
        nodes: [...nodes, newNode],
        baseEdgeData: {
          isInIteration,
          isInLoop,
          iteration_id: isInIteration ? prevNode.parentId : undefined,
          loop_id: isInLoop ? prevNode.parentId : undefined,
          _connectedNodeIsSelected: true,
        },
        zIndex: prevNode.parentId
          ? isInIteration
            ? ITERATION_CHILDREN_Z_INDEX
            : LOOP_CHILDREN_Z_INDEX
          : 0,
      })

      if (edgePair) {
        newPrevEdge = edgePair.realEdge
        newPrevUiEdge = edgePair.uiEdge
      }
    }
    else {
      const isNextNodeGroupForRemoval = nextNode.data.type === BlockEnum.Group

      if (isNextNodeGroupForRemoval) {
        const groupData = nextNode.data as GroupNodeData
        const headNodeIds = groupData.headNodeIds || []

        headNodeIds.forEach((headNodeId) => {
          const realEdge = edges.find(
            edge => edge.source === prevNodeId
              && edge.sourceHandle === prevNodeSourceHandle
              && edge.target === headNodeId,
          )
          if (realEdge)
            edgesToRemove.push(realEdge.id)
        })

        const uiEdge = edges.find(
          edge => edge.source === prevNodeId
            && edge.sourceHandle === prevNodeSourceHandle
            && edge.target === nextNodeId,
        )
        if (uiEdge)
          edgesToRemove.push(uiEdge.id)
      }
      else {
        const currentEdge = edges.find(
          edge => edge.source === prevNodeId && edge.target === nextNodeId,
        )
        if (currentEdge)
          edgesToRemove.push(currentEdge.id)
      }

      if (nodeType !== BlockEnum.DataSource) {
        newPrevEdge = {
          id: `${prevNodeId}-${prevNodeSourceHandle}-${newNode.id}-${targetHandle}`,
          type: CUSTOM_EDGE,
          source: prevNodeId,
          sourceHandle: prevNodeSourceHandle,
          target: newNode.id,
          targetHandle,
          data: {
            sourceType: prevNode.data.type,
            targetType: newNode.data.type,
            isInIteration,
            isInLoop,
            iteration_id: isInIteration ? prevNode.parentId : undefined,
            loop_id: isInLoop ? prevNode.parentId : undefined,
            _connectedNodeIsSelected: true,
          },
          zIndex: prevNode.parentId
            ? isInIteration
              ? ITERATION_CHILDREN_Z_INDEX
              : LOOP_CHILDREN_Z_INDEX
            : 0,
        }
      }
    }

    let newNextEdge: Edge | null = null
    let newNextUiEdge: Edge | null = null
    const newNextRealEdges: Edge[] = []

    const nextNodeParentNode
      = nodes.find(node => node.id === nextNode.parentId) || null
@@ -1241,49 +1633,113 @@ export const useNodesInteractions = () => {
      = !!nextNodeParentNode
      && nextNodeParentNode.data.type === BlockEnum.Loop

    const isNextNodeGroup = nextNode.data.type === BlockEnum.Group

    if (
      nodeType !== BlockEnum.IfElse
      && nodeType !== BlockEnum.QuestionClassifier
      && nodeType !== BlockEnum.LoopEnd
    ) {
      newNextEdge = {
        id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
        type: CUSTOM_EDGE,
        source: newNode.id,
        sourceHandle,
        target: nextNodeId,
        targetHandle: nextNodeTargetHandle,
        data: {
          sourceType: newNode.data.type,
          targetType: nextNode.data.type,
          isInIteration: isNextNodeInIteration,
          isInLoop: isNextNodeInLoop,
          iteration_id: isNextNodeInIteration
            ? nextNode.parentId
            : undefined,
          loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
          _connectedNodeIsSelected: true,
        },
        zIndex: nextNode.parentId
          ? isNextNodeInIteration
            ? ITERATION_CHILDREN_Z_INDEX
            : LOOP_CHILDREN_Z_INDEX
          : 0,
      if (isNextNodeGroup) {
        const groupData = nextNode.data as GroupNodeData
        const headNodeIds = groupData.headNodeIds || []

        headNodeIds.forEach((headNodeId) => {
          const headNode = nodes.find(node => node.id === headNodeId)
          newNextRealEdges.push({
            id: `${newNode.id}-${sourceHandle}-${headNodeId}-target`,
            type: CUSTOM_EDGE,
            source: newNode.id,
            sourceHandle,
            target: headNodeId,
            targetHandle: 'target',
            hidden: true,
            data: {
              sourceType: newNode.data.type,
              targetType: headNode?.data.type,
              isInIteration: isNextNodeInIteration,
              isInLoop: isNextNodeInLoop,
              iteration_id: isNextNodeInIteration ? nextNode.parentId : undefined,
              loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
              _hiddenInGroupId: nextNodeId,
              _connectedNodeIsSelected: true,
            },
            zIndex: nextNode.parentId
              ? isNextNodeInIteration
                ? ITERATION_CHILDREN_Z_INDEX
                : LOOP_CHILDREN_Z_INDEX
              : 0,
          } as Edge)
        })

        newNextUiEdge = {
          id: `${newNode.id}-${sourceHandle}-${nextNodeId}-target`,
          type: CUSTOM_EDGE,
          source: newNode.id,
          sourceHandle,
          target: nextNodeId,
          targetHandle: 'target',
          data: {
            sourceType: newNode.data.type,
            targetType: BlockEnum.Group,
            isInIteration: isNextNodeInIteration,
            isInLoop: isNextNodeInLoop,
            iteration_id: isNextNodeInIteration ? nextNode.parentId : undefined,
            loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
            _isTemp: true,
            _connectedNodeIsSelected: true,
          },
          zIndex: nextNode.parentId
            ? isNextNodeInIteration
              ? ITERATION_CHILDREN_Z_INDEX
              : LOOP_CHILDREN_Z_INDEX
            : 0,
        }
      }
      else {
        newNextEdge = {
          id: `${newNode.id}-${sourceHandle}-${nextNodeId}-${nextNodeTargetHandle}`,
          type: CUSTOM_EDGE,
          source: newNode.id,
          sourceHandle,
          target: nextNodeId,
          targetHandle: nextNodeTargetHandle,
          data: {
            sourceType: newNode.data.type,
            targetType: nextNode.data.type,
            isInIteration: isNextNodeInIteration,
            isInLoop: isNextNodeInLoop,
            iteration_id: isNextNodeInIteration
              ? nextNode.parentId
              : undefined,
            loop_id: isNextNodeInLoop ? nextNode.parentId : undefined,
            _connectedNodeIsSelected: true,
          },
          zIndex: nextNode.parentId
            ? isNextNodeInIteration
              ? ITERATION_CHILDREN_Z_INDEX
              : LOOP_CHILDREN_Z_INDEX
            : 0,
        }
      }
    }
    const edgeChanges = [
      ...edgesToRemove.map(id => ({ type: 'remove' as const, edge: edges.find(e => e.id === id)! })).filter(c => c.edge),
      ...(newPrevEdge ? [{ type: 'add' as const, edge: newPrevEdge }] : []),
      ...(newPrevUiEdge ? [{ type: 'add' as const, edge: newPrevUiEdge }] : []),
      ...(newNextEdge ? [{ type: 'add' as const, edge: newNextEdge }] : []),
      ...newNextRealEdges.map(edge => ({ type: 'add' as const, edge })),
      ...(newNextUiEdge ? [{ type: 'add' as const, edge: newNextUiEdge }] : []),
    ]
    const nodesConnectedSourceOrTargetHandleIdsMap
      = getNodesConnectedSourceOrTargetHandleIdsMap(
        [
          { type: 'remove', edge: edges[currentEdgeIndex] },
          ...(newPrevEdge ? [{ type: 'add', edge: newPrevEdge }] : []),
          ...(newNextEdge ? [{ type: 'add', edge: newNextEdge }] : []),
        ],
        edgeChanges,
        [...nodes, newNode],
      )

    const afterNodesInSameBranch = getAfterNodesInSameBranch(nextNodeId!)
    const afterNodesInSameBranchIds = afterNodesInSameBranch.map(
      node => node.id,
      (node: Node) => node.id,
    )
    const newNodes = produce(nodes, (draft) => {
      draft.forEach((node) => {
@@ -1342,7 +1798,10 @@ export const useNodesInteractions = () => {
      })
    }
    const newEdges = produce(edges, (draft) => {
      draft.splice(currentEdgeIndex, 1)
      const filteredDraft = draft.filter(edge => !edgesToRemove.includes(edge.id))
      draft.length = 0
      draft.push(...filteredDraft)

      draft.forEach((item) => {
        item.data = {
          ...item.data,
@@ -1351,9 +1810,15 @@ export const useNodesInteractions = () => {
      })
      if (newPrevEdge)
        draft.push(newPrevEdge)

      if (newPrevUiEdge)
        draft.push(newPrevUiEdge)
      if (newNextEdge)
        draft.push(newNextEdge)
      newNextRealEdges.forEach((edge) => {
        draft.push(edge)
      })
      if (newNextUiEdge)
        draft.push(newNextUiEdge)
    })
    setEdges(newEdges)
  }
@@ -2087,6 +2552,302 @@ export const useNodesInteractions = () => {
    setEdges(newEdges)
  }, [store])

  // Check if there are any nodes selected via box selection
  const hasBundledNodes = useCallback(() => {
    const { getNodes } = store.getState()
    const nodes = getNodes()
    return nodes.some(node => node.data._isBundled)
  }, [store])

  const getCanMakeGroup = useCallback(() => {
    const { getNodes, edges } = store.getState()
    const nodes = getNodes()
    const bundledNodes = nodes.filter(node => node.data._isBundled)

    if (bundledNodes.length <= 1)
      return false

    const bundledNodeIds = bundledNodes.map(node => node.id)
    const minimalEdges = edges.map(edge => ({
      id: edge.id,
      source: edge.source,
      sourceHandle: edge.sourceHandle || 'source',
      target: edge.target,
    }))
    const hasGroupNode = bundledNodes.some(node => node.data.type === BlockEnum.Group)

    const { canMakeGroup } = checkMakeGroupAvailability(bundledNodeIds, minimalEdges, hasGroupNode)
    return canMakeGroup
  }, [store])

  const handleMakeGroup = useCallback(() => {
    const { getNodes, setNodes, edges, setEdges } = store.getState()
    const nodes = getNodes()
    const bundledNodes = nodes.filter(node => node.data._isBundled)

    if (bundledNodes.length <= 1)
      return

    const bundledNodeIds = bundledNodes.map(node => node.id)
    const minimalEdges = edges.map(edge => ({
      id: edge.id,
      source: edge.source,
      sourceHandle: edge.sourceHandle || 'source',
      target: edge.target,
    }))
    const hasGroupNode = bundledNodes.some(node => node.data.type === BlockEnum.Group)

    const { canMakeGroup } = checkMakeGroupAvailability(bundledNodeIds, minimalEdges, hasGroupNode)
    if (!canMakeGroup)
      return

    const bundledNodeIdSet = new Set(bundledNodeIds)
    const bundledNodeIdIsLeaf = new Set<string>()
    const inboundEdges = edges.filter(edge => !bundledNodeIdSet.has(edge.source) && bundledNodeIdSet.has(edge.target))
    const outboundEdges = edges.filter(edge => bundledNodeIdSet.has(edge.source) && !bundledNodeIdSet.has(edge.target))

    // leaf node: no outbound edges to other nodes in the selection
    const handlers: GroupHandler[] = []
    const leafNodeIdSet = new Set<string>()

    bundledNodes.forEach((node: Node) => {
      const targetBranches = node.data._targetBranches || [{ id: 'source', name: node.data.title }]
      targetBranches.forEach((branch) => {
        // A branch should be a handler if it's either:
        // 1. Connected to a node OUTSIDE the group
        // 2. NOT connected to any node INSIDE the group
        const isConnectedInside = edges.some(edge =>
          edge.source === node.id
          && (edge.sourceHandle === branch.id || (!edge.sourceHandle && branch.id === 'source'))
          && bundledNodeIdSet.has(edge.target),
        )
        const isConnectedOutside = edges.some(edge =>
          edge.source === node.id
          && (edge.sourceHandle === branch.id || (!edge.sourceHandle && branch.id === 'source'))
          && !bundledNodeIdSet.has(edge.target),
        )

        if (isConnectedOutside || !isConnectedInside) {
          const handlerId = `${node.id}-${branch.id}`
          handlers.push({
            id: handlerId,
            label: branch.name || node.data.title || node.id,
            nodeId: node.id,
            sourceHandle: branch.id,
          })
          leafNodeIdSet.add(node.id)
        }
      })
    })

    const leafNodeIds = Array.from(leafNodeIdSet)
    leafNodeIds.forEach(id => bundledNodeIdIsLeaf.add(id))

    const members: GroupMember[] = bundledNodes.map((node) => {
      return {
        id: node.id,
        type: node.data.type,
        label: node.data.title,
      }
    })

    // head nodes: nodes that receive input from outside the group
    const headNodeIds = [...new Set(inboundEdges.map(edge => edge.target))]

    // put the group node at the top-left corner of the selection, slightly offset
    const { x: minX, y: minY } = getTopLeftNodePosition(bundledNodes)

    const groupNodeData: GroupNodeData = {
      title: t('operator.makeGroup', { ns: 'workflow' }),
      desc: '',
      type: BlockEnum.Group,
      members,
      handlers,
      headNodeIds,
      leafNodeIds,
      selected: true,
      _targetBranches: handlers.map(handler => ({
        id: handler.id,
        name: handler.label || handler.id,
      })),
    }

    const { newNode: groupNode } = generateNewNode({
      data: groupNodeData,
      position: {
        x: minX - 20,
        y: minY - 20,
      },
    })

    const nodeTypeMap = new Map(nodes.map(node => [node.id, node.data.type]))

    const newNodes = produce(nodes, (draft) => {
      draft.forEach((node) => {
        if (bundledNodeIdSet.has(node.id)) {
          node.data._isBundled = false
          node.selected = false
          node.hidden = true
          node.data._hiddenInGroupId = groupNode.id
        }
        else {
          node.data._isBundled = false
        }
      })
      draft.push(groupNode)
    })

    const newEdges = produce(edges, (draft) => {
      draft.forEach((edge) => {
        if (bundledNodeIdSet.has(edge.source) || bundledNodeIdSet.has(edge.target)) {
          edge.hidden = true
          edge.data = {
            ...edge.data,
            _hiddenInGroupId: groupNode.id,
            _isBundled: false,
          }
        }
        else if (edge.data?._isBundled) {
          edge.data._isBundled = false
        }
      })

      // re-add the external inbound edges to the group node as UI-only edges (not persisted to backend)
      inboundEdges.forEach((edge) => {
        draft.push({
          id: `${edge.id}__to-${groupNode.id}`,
          type: edge.type || CUSTOM_EDGE,
          source: edge.source,
          target: groupNode.id,
          sourceHandle: edge.sourceHandle,
          targetHandle: 'target',
          data: {
            ...edge.data,
            sourceType: nodeTypeMap.get(edge.source)!,
            targetType: BlockEnum.Group,
            _hiddenInGroupId: undefined,
            _isBundled: false,
            _isTemp: true, // UI-only edge, not persisted to backend
          },
          zIndex: edge.zIndex,
        })
      })

      // outbound edges of the group node as UI-only edges (not persisted to backend)
      outboundEdges.forEach((edge) => {
        if (!bundledNodeIdIsLeaf.has(edge.source))
          return

        // Use the same handler id format: nodeId-sourceHandle
        const originalSourceHandle = edge.sourceHandle || 'source'
        const handlerId = `${edge.source}-${originalSourceHandle}`

        draft.push({
          id: `${groupNode.id}-${edge.target}-${edge.targetHandle || 'target'}-${handlerId}`,
          type: edge.type || CUSTOM_EDGE,
          source: groupNode.id,
          target: edge.target,
          sourceHandle: handlerId,
          targetHandle: edge.targetHandle,
          data: {
            ...edge.data,
            sourceType: BlockEnum.Group,
            targetType: nodeTypeMap.get(edge.target)!,
            _hiddenInGroupId: undefined,
            _isBundled: false,
            _isTemp: true,
          },
          zIndex: edge.zIndex,
        })
      })
    })

    setNodes(newNodes)
    setEdges(newEdges)
    workflowStore.setState({
      selectionMenu: undefined,
    })
    handleSyncWorkflowDraft()
    saveStateToHistory(WorkflowHistoryEvent.NodeAdd, {
      nodeId: groupNode.id,
    })
  }, [handleSyncWorkflowDraft, saveStateToHistory, store, t, workflowStore])

  // check if the current selection can be ungrouped (single selected Group node)
  const getCanUngroup = useCallback(() => {
    const { getNodes } = store.getState()
    const nodes = getNodes()
    const selectedNodes = nodes.filter(node => node.selected)

    if (selectedNodes.length !== 1)
      return false

    return selectedNodes[0].data.type === BlockEnum.Group
  }, [store])

  // get the selected group node id for ungroup operation
  const getSelectedGroupId = useCallback(() => {
    const { getNodes } = store.getState()
    const nodes = getNodes()
    const selectedNodes = nodes.filter(node => node.selected)

    if (selectedNodes.length === 1 && selectedNodes[0].data.type === BlockEnum.Group)
      return selectedNodes[0].id

    return undefined
  }, [store])

  const handleUngroup = useCallback((groupId: string) => {
    const { getNodes, setNodes, edges, setEdges } = store.getState()
    const nodes = getNodes()
    const groupNode = nodes.find(n => n.id === groupId)

    if (!groupNode || groupNode.data.type !== BlockEnum.Group)
      return

    const memberIds = new Set((groupNode.data.members || []).map((m: { id: string }) => m.id))

    // restore hidden member nodes
    const newNodes = produce(nodes, (draft) => {
      draft.forEach((node) => {
        if (memberIds.has(node.id)) {
          node.hidden = false
          delete node.data._hiddenInGroupId
        }
      })
      // remove group node
      const groupIndex = draft.findIndex(n => n.id === groupId)
      if (groupIndex !== -1)
        draft.splice(groupIndex, 1)
    })

    // restore hidden edges and remove temp edges in single pass O(E)
    const newEdges = produce(edges, (draft) => {
      const indicesToRemove: number[] = []

      for (let i = 0; i < draft.length; i++) {
        const edge = draft[i]
        // restore hidden edges that involve member nodes
        if (edge.hidden && (memberIds.has(edge.source) || memberIds.has(edge.target)))
          edge.hidden = false
        // collect temp edges connected to group for removal
        if (edge.data?._isTemp && (edge.source === groupId || edge.target === groupId))
          indicesToRemove.push(i)
      }

      // remove collected indices in reverse order to avoid index shift
      for (let i = indicesToRemove.length - 1; i >= 0; i--)
        draft.splice(indicesToRemove[i], 1)
    })

    setNodes(newNodes)
    setEdges(newEdges)
    handleSyncWorkflowDraft()
    saveStateToHistory(WorkflowHistoryEvent.NodeDelete, {
      nodeId: groupId,
    })
  }, [handleSyncWorkflowDraft, saveStateToHistory, store])

  return {
    handleNodeDragStart,
    handleNodeDrag,
@@ -2107,11 +2868,17 @@ export const useNodesInteractions = () => {
    handleNodesPaste,
    handleNodesDuplicate,
    handleNodesDelete,
    handleMakeGroup,
    handleUngroup,
    handleNodeResize,
    handleNodeDisconnect,
    handleHistoryBack,
    handleHistoryForward,
    dimOtherNodes,
    undimAllNodes,
    hasBundledNodes,
    getCanMakeGroup,
    getCanUngroup,
    getSelectedGroupId,
  }
}
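
Side note (not part of the diff): the `_isTemp` flag is what separates UI-only group edges from the hidden real edges, so a draft-sync layer would presumably persist the real (possibly hidden) edges and drop the temp ones. A sketch of that split, under that assumption; `EdgeLike` and the helper name are made up:

type EdgeLike = { data?: { _isTemp?: boolean } }

function edgesToPersist<T extends EdgeLike>(edges: T[]): T[] {
  // Hidden real edges (_hiddenInGroupId) stay; UI-only temp edges do not.
  return edges.filter(edge => !edge.data?._isTemp)
}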
@@ -1,8 +1,10 @@
import type { AvailableNodesMetaData } from '@/app/components/workflow/hooks-store'
import type { Node } from '@/app/components/workflow/types'
import { useMemo } from 'react'
import { useTranslation } from 'react-i18next'
import { CollectionType } from '@/app/components/tools/types'
import { useHooksStore } from '@/app/components/workflow/hooks-store'
import GroupDefault from '@/app/components/workflow/nodes/group/default'
import { useStore } from '@/app/components/workflow/store'
import { BlockEnum } from '@/app/components/workflow/types'
import { useGetLanguage } from '@/context/i18n'
@@ -25,6 +27,7 @@ export const useNodesMetaData = () => {
}

export const useNodeMetaData = (node: Node) => {
  const { t } = useTranslation()
  const language = useGetLanguage()
  const { data: buildInTools } = useAllBuiltInTools()
  const { data: customTools } = useAllCustomTools()
@@ -34,6 +37,9 @@ export const useNodeMetaData = (node: Node) => {
  const { data } = node
  const nodeMetaData = availableNodesMetaData.nodesMap?.[data.type]
  const author = useMemo(() => {
    if (data.type === BlockEnum.Group)
      return GroupDefault.metaData.author

    if (data.type === BlockEnum.DataSource)
      return dataSourceList?.find(dataSource => dataSource.plugin_id === data.plugin_id)?.author

@@ -48,6 +54,9 @@ export const useNodeMetaData = (node: Node) => {
  }, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList])

  const description = useMemo(() => {
    if (data.type === BlockEnum.Group)
      return t('blocksAbout.group', { ns: 'workflow' })

    if (data.type === BlockEnum.DataSource)
      return dataSourceList?.find(dataSource => dataSource.plugin_id === data.plugin_id)?.description[language]
    if (data.type === BlockEnum.Tool) {
@@ -58,7 +67,7 @@ export const useNodeMetaData = (node: Node) => {
      return customTools?.find(toolWithProvider => toolWithProvider.id === data.provider_id)?.description[language]
    }
    return nodeMetaData?.metaData.description
  }, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList, language])
  }, [data, buildInTools, customTools, workflowTools, nodeMetaData, dataSourceList, language, t])

  return useMemo(() => {
    return {
@@ -17,7 +17,7 @@ import {
 } from '../utils'
 import { useWorkflowHistoryStore } from '../workflow-history-store'

-export const useShortcuts = (): void => {
+export const useShortcuts = (enabled = true): void => {
   const {
     handleNodesCopy,
     handleNodesPaste,
@@ -27,6 +27,12 @@ export const useShortcuts = (): void => {
     handleHistoryForward,
     dimOtherNodes,
     undimAllNodes,
+    hasBundledNodes,
+    getCanMakeGroup,
+    handleMakeGroup,
+    getCanUngroup,
+    getSelectedGroupId,
+    handleUngroup,
   } = useNodesInteractions()
   const { shortcutsEnabled: workflowHistoryShortcutsEnabled } = useWorkflowHistoryStore()
   const { handleSyncWorkflowDraft } = useNodesSyncDraft()
@@ -60,13 +66,17 @@ export const useShortcuts = (): void => {
   }

   const shouldHandleShortcut = useCallback((e: KeyboardEvent) => {
+    if (!enabled)
+      return false
     return !isEventTargetInputArea(e.target as HTMLElement)
-  }, [])
+  }, [enabled])

   const shouldHandleCopy = useCallback(() => {
+    if (!enabled)
+      return false
     const selection = document.getSelection()
     return !selection || selection.isCollapsed
-  }, [])
+  }, [enabled])

   useKeyPress(['delete', 'backspace'], (e) => {
     if (shouldHandleShortcut(e)) {
@@ -78,7 +88,8 @@ export const useShortcuts = (): void => {

   useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.c`, (e) => {
     const { showDebugAndPreviewPanel } = workflowStore.getState()
-    if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel) {
+    // Only intercept when nodes are selected via box selection
+    if (shouldHandleShortcut(e) && shouldHandleCopy() && !showDebugAndPreviewPanel && hasBundledNodes()) {
       e.preventDefault()
       handleNodesCopy()
     }
@@ -99,6 +110,26 @@ export const useShortcuts = (): void => {
     }
   }, { exactMatch: true, useCapture: true })

+  useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.g`, (e) => {
+    // Only intercept when the selection can be grouped
+    if (shouldHandleShortcut(e) && getCanMakeGroup()) {
+      e.preventDefault()
+      // Close selection context menu if open
+      workflowStore.setState({ selectionMenu: undefined })
+      handleMakeGroup()
+    }
+  }, { exactMatch: true, useCapture: true })
+
+  useKeyPress(`${getKeyboardKeyCodeBySystem('ctrl')}.shift.g`, (e) => {
+    // Only intercept when the selection can be ungrouped
+    if (shouldHandleShortcut(e) && getCanUngroup()) {
+      e.preventDefault()
+      const groupId = getSelectedGroupId()
+      if (groupId)
+        handleUngroup(groupId)
+    }
+  }, { exactMatch: true, useCapture: true })
+
   useKeyPress(`${getKeyboardKeyCodeBySystem('alt')}.r`, (e) => {
     if (shouldHandleShortcut(e)) {
       e.preventDefault()
@@ -255,6 +286,8 @@ export const useShortcuts = (): void => {

   // Listen for zen toggle event from /zen command
   useEffect(() => {
+    if (!enabled)
+      return
     const handleZenToggle = () => {
       handleToggleMaximizeCanvas()
     }
@@ -263,5 +296,5 @@ export const useShortcuts = (): void => {
     return () => {
       window.removeEventListener(ZEN_TOGGLE_EVENT, handleZenToggle)
     }
-  }, [handleToggleMaximizeCanvas])
+  }, [enabled, handleToggleMaximizeCanvas])
 }
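
The new `enabled` flag follows a common gating pattern: the hook stays mounted (to satisfy the rules of hooks) while every handler re-checks the flag. A minimal sketch of the same idea, assuming the ahooks `useKeyPress` helper the file already uses; `useGuardedShortcut` is a hypothetical name, not part of the diff:

// Hypothetical sketch: a shortcut that can be switched off without unmounting.
import { useCallback } from 'react'
import { useKeyPress } from 'ahooks'

export const useGuardedShortcut = (key: string, action: () => void, enabled = true) => {
  const shouldHandle = useCallback((e: KeyboardEvent) => {
    if (!enabled)
      return false
    // Ignore keystrokes typed into inputs, textareas, or contenteditables
    const target = e.target as HTMLElement
    return !['INPUT', 'TEXTAREA'].includes(target.tagName) && !target.isContentEditable
  }, [enabled])

  useKeyPress(key, (e) => {
    if (shouldHandle(e)) {
      e.preventDefault()
      action()
    }
  }, { exactMatch: true, useCapture: true })
}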

@@ -37,7 +37,10 @@ export const useWorkflowNodeFinished = () => {
   }))

   const newNodes = produce(nodes, (draft) => {
-    const currentNode = draft.find(node => node.id === data.node_id)!
+    const currentNode = draft.find(node => node.id === data.node_id)
+    // Skip if node not found (e.g., virtual extraction nodes)
+    if (!currentNode)
+      return
     currentNode.data._runningStatus = data.status
     if (data.status === NodeRunningStatus.Exception) {
       if (data.execution_metadata?.error_strategy === ErrorHandleTypeEnum.failBranch)

@@ -45,6 +45,11 @@ export const useWorkflowNodeStarted = () => {
   } = reactflow
   const currentNodeIndex = nodes.findIndex(node => node.id === data.node_id)
   const currentNode = nodes[currentNodeIndex]
+
+  // Skip if node not found (e.g., virtual extraction nodes)
+  if (!currentNode)
+    return
+
   const position = currentNode.position
   const zoom = transform[2]
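
Both handlers above replace a non-null assertion (`find(...)!`) with an early return, so events for nodes that have no canvas representation (such as virtual extraction nodes) are ignored instead of throwing. A minimal sketch of that guard, assuming immer's `produce`; the types and function name are illustrative:

// Hypothetical sketch: mark a node's running status, tolerating unknown ids.
import { produce } from 'immer'

type NodeLike = { id: string, data: { _runningStatus?: string } }

function markNodeStatus(nodes: NodeLike[], nodeId: string, status: string): NodeLike[] {
  return produce(nodes, (draft) => {
    const currentNode = draft.find(node => node.id === nodeId)
    if (!currentNode)
      return // node not on the canvas; leave state untouched
    currentNode.data._runningStatus = status
  })
}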

@@ -1,10 +1,10 @@
 import type {
   Connection,
 } from 'reactflow'
+import type { GroupNodeData } from '../nodes/group/types'
 import type { IterationNodeType } from '../nodes/iteration/types'
 import type { LoopNodeType } from '../nodes/loop/types'
 import type {
-  BlockEnum,
   Edge,
   Node,
   ValueSelector,
@@ -28,14 +28,12 @@ import {
 } from '../constants'
 import { findUsedVarNodes, getNodeOutputVars, updateNodeVars } from '../nodes/_base/components/variable/utils'
 import { CUSTOM_NOTE_NODE } from '../note-node/constants'

 import {
   useStore,
   useWorkflowStore,
 } from '../store'
-import {
-  WorkflowRunningStatus,
-} from '../types'
+import { BlockEnum, WorkflowRunningStatus } from '../types'
 import {
   getWorkflowEntryNode,
   isWorkflowEntryNode,
@@ -381,7 +379,7 @@ export const useWorkflow = () => {
     return startNodes
   }, [nodesMap, getRootNodesById])

-  const isValidConnection = useCallback(({ source, sourceHandle: _sourceHandle, target }: Connection) => {
+  const isValidConnection = useCallback(({ source, sourceHandle, target }: Connection) => {
     const {
       edges,
       getNodes,
@@ -396,15 +394,42 @@ export const useWorkflow = () => {
     if (sourceNode.parentId !== targetNode.parentId)
       return false

+    // For Group nodes, use the leaf node's type for validation
+    // sourceHandle format: "${leafNodeId}-${originalSourceHandle}"
+    let actualSourceType = sourceNode.data.type
+    if (sourceNode.data.type === BlockEnum.Group && sourceHandle) {
+      const lastDashIndex = sourceHandle.lastIndexOf('-')
+      if (lastDashIndex > 0) {
+        const leafNodeId = sourceHandle.substring(0, lastDashIndex)
+        const leafNode = nodes.find(node => node.id === leafNodeId)
+        if (leafNode)
+          actualSourceType = leafNode.data.type
+      }
+    }
+
     if (sourceNode && targetNode) {
-      const sourceNodeAvailableNextNodes = getAvailableBlocks(sourceNode.data.type, !!sourceNode.parentId).availableNextBlocks
+      const sourceNodeAvailableNextNodes = getAvailableBlocks(actualSourceType, !!sourceNode.parentId).availableNextBlocks
       const targetNodeAvailablePrevNodes = getAvailableBlocks(targetNode.data.type, !!targetNode.parentId).availablePrevBlocks

-      if (!sourceNodeAvailableNextNodes.includes(targetNode.data.type))
-        return false
+      if (targetNode.data.type === BlockEnum.Group) {
+        const groupData = targetNode.data as GroupNodeData
+        const headNodeIds = groupData.headNodeIds || []
+        if (headNodeIds.length > 0) {
+          const headNode = nodes.find(node => node.id === headNodeIds[0])
+          if (headNode) {
+            const headNodeAvailablePrevNodes = getAvailableBlocks(headNode.data.type, !!targetNode.parentId).availablePrevBlocks
+            if (!headNodeAvailablePrevNodes.includes(actualSourceType))
+              return false
+          }
+        }
+      }
+      else {
+        if (!sourceNodeAvailableNextNodes.includes(targetNode.data.type))
+          return false

-      if (!targetNodeAvailablePrevNodes.includes(sourceNode.data.type))
-        return false
+        if (!targetNodeAvailablePrevNodes.includes(actualSourceType))
+          return false
+      }
     }

     const hasCycle = (node: Node, visited = new Set()) => {
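
The hunk above leans on a handle-id convention: a Group node exposes one source handle per leaf node, encoded as `${leafNodeId}-${originalSourceHandle}`, and splitting on the last dash keeps leaf ids that themselves contain dashes intact. A small sketch of that parsing; `parseGroupSourceHandle` is a hypothetical helper, not part of the diff:

// Hypothetical sketch of the documented handle-id format.
function parseGroupSourceHandle(sourceHandle: string): { leafNodeId: string, originalHandle: string } | null {
  const lastDashIndex = sourceHandle.lastIndexOf('-')
  if (lastDashIndex <= 0)
    return null // no encoded leaf id
  return {
    leafNodeId: sourceHandle.substring(0, lastDashIndex),
    originalHandle: sourceHandle.substring(lastDashIndex + 1),
  }
}

// e.g. parseGroupSourceHandle('node-123-source')
//   -> { leafNodeId: 'node-123', originalHandle: 'source' }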

@@ -473,13 +498,9 @@ export const useNodesReadOnly = () => {
   const isRestoring = useStore(s => s.isRestoring)

   const getNodesReadOnly = useCallback((): boolean => {
-    const {
-      workflowRunningData,
-      historyWorkflowData,
-      isRestoring,
-    } = workflowStore.getState()
+    const state = workflowStore.getState()

-    return !!(workflowRunningData?.result.status === WorkflowRunningStatus.Running || historyWorkflowData || isRestoring)
+    return !!(state.workflowRunningData?.result.status === WorkflowRunningStatus.Running || state.historyWorkflowData || state.isRestoring)
   }, [workflowStore])

   return {
@@ -525,6 +546,7 @@ export const useIsNodeInLoop = (loopId: string) => {
       return false

     if (node.parentId === loopId)
       return true

     return false

@@ -2,6 +2,7 @@

 import type { FC } from 'react'
 import type {
+  NodeMouseHandler,
   Viewport,
 } from 'reactflow'
 import type { Shape as HooksStoreShape } from './hooks-store'
@@ -54,6 +55,14 @@ import {
 } from './constants'
 import CustomConnectionLine from './custom-connection-line'
 import CustomEdge from './custom-edge'
+import {
+  CUSTOM_GROUP_EXIT_PORT_NODE,
+  CUSTOM_GROUP_INPUT_NODE,
+  CUSTOM_GROUP_NODE,
+  CustomGroupExitPortNode,
+  CustomGroupInputNode,
+  CustomGroupNode,
+} from './custom-group-node'
 import DatasetsDetailProvider from './datasets-detail-store/provider'
 import HelpLine from './help-line'
 import {
@@ -94,6 +103,7 @@ import {
 } from './store'
 import SyncingDataModal from './syncing-data-modal'
 import {
+  BlockEnum,
   ControlMode,
 } from './types'
 import { setupScrollToNodeListener } from './utils/node-navigation'
@@ -112,6 +122,9 @@ const nodeTypes = {
   [CUSTOM_ITERATION_START_NODE]: CustomIterationStartNode,
   [CUSTOM_LOOP_START_NODE]: CustomLoopStartNode,
   [CUSTOM_DATA_SOURCE_EMPTY_NODE]: CustomDataSourceEmptyNode,
+  [CUSTOM_GROUP_NODE]: CustomGroupNode,
+  [CUSTOM_GROUP_INPUT_NODE]: CustomGroupInputNode,
+  [CUSTOM_GROUP_EXIT_PORT_NODE]: CustomGroupExitPortNode,
 }
 const edgeTypes = {
   [CUSTOM_EDGE]: CustomEdge,
@@ -123,6 +136,9 @@ export type WorkflowProps = {
   viewport?: Viewport
   children?: React.ReactNode
   onWorkflowDataUpdate?: (v: any) => void
+  allowSelectionWhenReadOnly?: boolean
+  canvasReadOnly?: boolean
+  interactionMode?: 'default' | 'subgraph'
 }
 export const Workflow: FC<WorkflowProps> = memo(({
   nodes: originalNodes,
@@ -130,6 +146,9 @@ export const Workflow: FC<WorkflowProps> = memo(({
   viewport,
   children,
   onWorkflowDataUpdate,
+  allowSelectionWhenReadOnly = false,
+  canvasReadOnly = false,
+  interactionMode = 'default',
 }) => {
   const workflowContainerRef = useRef<HTMLDivElement>(null)
   const workflowStore = useWorkflowStore()
@@ -182,9 +201,10 @@ export const Workflow: FC<WorkflowProps> = memo(({
       id: node.id,
       data: node.data,
     }))
-    if (!isEqual(oldData, nodesData))
+    if (!isEqual(oldData, nodesData)) {
       setNodesInStore(nodes)
-  }, [setNodesInStore, workflowStore])
+    }
+  }, [setNodesInStore])
   useEffect(() => {
     setNodesOnlyChangeWithData(currentNodes as Node[])
   }, [currentNodes, setNodesOnlyChangeWithData])
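
The `isEqual` guard above compares only an `{ id, data }` projection of each node, so position or selection churn does not trigger store writes. A reduced sketch of that idea, assuming lodash's deep `isEqual`; the helper name and types are illustrative:

// Hypothetical sketch: only sync when the data projection actually changed.
import { isEqual } from 'lodash-es'

type NodeLike = { id: string, data: Record<string, unknown>, position?: { x: number, y: number } }

function shouldSyncNodes(prev: NodeLike[], next: NodeLike[]): boolean {
  const project = (nodes: NodeLike[]) => nodes.map(node => ({ id: node.id, data: node.data }))
  return !isEqual(project(prev), project(next))
}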

@@ -316,7 +336,8 @@ export const Workflow: FC<WorkflowProps> = memo(({
     },
   })

-  useShortcuts()
+  const isSubGraph = interactionMode === 'subgraph'
+  useShortcuts(!isSubGraph)
   // Initialize workflow node search functionality
   useWorkflowSearch()

@@ -370,6 +391,16 @@ export const Workflow: FC<WorkflowProps> = memo(({
     }
   }

+  const handleNodeClickInMode = useCallback<NodeMouseHandler>(
+    (event, node) => {
+      if (isSubGraph && node.data.type !== BlockEnum.LLM)
+        return
+
+      handleNodeClick(event, node)
+    },
+    [handleNodeClick, isSubGraph],
+  )
+
   return (
     <div
       id="workflow-container"
@@ -381,18 +412,18 @@ export const Workflow: FC<WorkflowProps> = memo(({
       ref={workflowContainerRef}
     >
       <SyncingDataModal />
-      <CandidateNode />
+      {!isSubGraph && <CandidateNode />}
       <div
         className="pointer-events-none absolute left-0 top-0 z-10 flex w-12 items-center justify-center p-1 pl-2"
         style={{ height: controlHeight }}
       >
-        <Control />
+        {!isSubGraph && <Control />}
       </div>
       <Operator handleRedo={handleHistoryForward} handleUndo={handleHistoryBack} />
-      <PanelContextmenu />
-      <NodeContextmenu />
-      <SelectionContextmenu />
-      <HelpLine />
+      {!isSubGraph && <PanelContextmenu />}
+      {!isSubGraph && <NodeContextmenu />}
+      {!isSubGraph && <SelectionContextmenu />}
+      {!isSubGraph && <HelpLine />}
       {
         !!showConfirm && (
           <Confirm
@@ -415,38 +446,38 @@ export const Workflow: FC<WorkflowProps> = memo(({
         onNodeDragStop={handleNodeDragStop}
         onNodeMouseEnter={handleNodeEnter}
         onNodeMouseLeave={handleNodeLeave}
-        onNodeClick={handleNodeClick}
-        onNodeContextMenu={handleNodeContextMenu}
-        onConnect={handleNodeConnect}
-        onConnectStart={handleNodeConnectStart}
-        onConnectEnd={handleNodeConnectEnd}
+        onNodeClick={handleNodeClickInMode}
+        onNodeContextMenu={isSubGraph ? undefined : handleNodeContextMenu}
+        onConnect={isSubGraph ? undefined : handleNodeConnect}
+        onConnectStart={isSubGraph ? undefined : handleNodeConnectStart}
+        onConnectEnd={isSubGraph ? undefined : handleNodeConnectEnd}
         onEdgeMouseEnter={handleEdgeEnter}
         onEdgeMouseLeave={handleEdgeLeave}
         onEdgesChange={handleEdgesChange}
-        onSelectionStart={handleSelectionStart}
-        onSelectionChange={handleSelectionChange}
-        onSelectionDrag={handleSelectionDrag}
-        onPaneContextMenu={handlePaneContextMenu}
-        onSelectionContextMenu={handleSelectionContextMenu}
+        onSelectionStart={isSubGraph ? undefined : handleSelectionStart}
+        onSelectionChange={isSubGraph ? undefined : handleSelectionChange}
+        onSelectionDrag={isSubGraph ? undefined : handleSelectionDrag}
+        onPaneContextMenu={isSubGraph ? undefined : handlePaneContextMenu}
+        onSelectionContextMenu={isSubGraph ? undefined : handleSelectionContextMenu}
         connectionLineComponent={CustomConnectionLine}
         // NOTE: For LOOP node, how to distinguish between ITERATION and LOOP here? Maybe both are the same?
         connectionLineContainerStyle={{ zIndex: ITERATION_CHILDREN_Z_INDEX }}
         defaultViewport={viewport}
         multiSelectionKeyCode={null}
         deleteKeyCode={null}
-        nodesDraggable={!nodesReadOnly}
-        nodesConnectable={!nodesReadOnly}
-        nodesFocusable={!nodesReadOnly}
-        edgesFocusable={!nodesReadOnly}
-        panOnScroll={controlMode === ControlMode.Pointer && !workflowReadOnly}
-        panOnDrag={controlMode === ControlMode.Hand || [1]}
-        zoomOnPinch={true}
-        zoomOnScroll={true}
-        zoomOnDoubleClick={true}
+        nodesDraggable={!(nodesReadOnly || canvasReadOnly || isSubGraph)}
+        nodesConnectable={!(nodesReadOnly || canvasReadOnly || isSubGraph)}
+        nodesFocusable={allowSelectionWhenReadOnly ? true : !nodesReadOnly}
+        edgesFocusable={isSubGraph ? false : (allowSelectionWhenReadOnly ? true : !nodesReadOnly)}
+        panOnScroll={!isSubGraph && controlMode === ControlMode.Pointer && !workflowReadOnly}
+        panOnDrag={!isSubGraph && (controlMode === ControlMode.Hand || [1])}
+        selectionOnDrag={!isSubGraph && controlMode === ControlMode.Pointer && !workflowReadOnly && !canvasReadOnly}
+        zoomOnPinch={!isSubGraph}
+        zoomOnScroll={!isSubGraph}
+        zoomOnDoubleClick={!isSubGraph}
         isValidConnection={isValidConnection}
         selectionKeyCode={null}
         selectionMode={SelectionMode.Partial}
-        selectionOnDrag={controlMode === ControlMode.Pointer && !workflowReadOnly}
         minZoom={0.25}
       >
         <Background
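
A hypothetical usage sketch of the new props (the prop names come from the diff above; the wrapper component and import path do not): mounting the same canvas as an inert subgraph preview, where shortcuts, context menus, selection, panning and zooming are disabled and only LLM nodes respond to clicks:

// Hypothetical sketch; the import path is assumed, not confirmed by the diff.
import { Workflow } from '@/app/components/workflow'

const SubGraphPreview = ({ nodes, edges }: { nodes: any[], edges: any[] }) => (
  <Workflow
    nodes={nodes}
    edges={edges}
    interactionMode="subgraph" // gates useShortcuts, handlers, and pan/zoom
    canvasReadOnly // additionally blocks dragging and connecting
  />
)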

@@ -0,0 +1,135 @@
+'use client'
+import type { FC } from 'react'
+import type { BlockEnum } from '@/app/components/workflow/types'
+import * as React from 'react'
+import { useState } from 'react'
+import { useTranslation } from 'react-i18next'
+import Input from '@/app/components/base/input'
+import BlockIcon from '@/app/components/workflow/block-icon'
+import { cn } from '@/utils/classnames'
+
+export type AgentNode = {
+  id: string
+  title: string
+  type: BlockEnum
+}
+
+type ItemProps = {
+  node: AgentNode
+  onSelect: (node: AgentNode) => void
+}
+
+const Item: FC<ItemProps> = ({ node, onSelect }) => {
+  const [isHovering, setIsHovering] = useState(false)
+
+  return (
+    <button
+      type="button"
+      className={cn(
+        'relative flex h-6 w-full cursor-pointer items-center rounded-md border-none bg-transparent px-3 text-left',
+        isHovering && 'bg-state-base-hover',
+      )}
+      onMouseEnter={() => setIsHovering(true)}
+      onMouseLeave={() => setIsHovering(false)}
+      onClick={() => onSelect(node)}
+      onMouseDown={e => e.preventDefault()}
+    >
+      <BlockIcon
+        className="mr-1 shrink-0"
+        type={node.type}
+        size="xs"
+      />
+      <span
+        className="system-sm-medium truncate text-text-secondary"
+        title={node.title}
+      >
+        {node.title}
+      </span>
+    </button>
+  )
+}
+
+type Props = {
+  nodes: AgentNode[]
+  onSelect: (node: AgentNode) => void
+  onClose?: () => void
+  onBlur?: () => void
+  hideSearch?: boolean
+  searchBoxClassName?: string
+  maxHeightClass?: string
+  autoFocus?: boolean
+}
+
+const AgentNodeList: FC<Props> = ({
+  nodes,
+  onSelect,
+  onClose,
+  onBlur,
+  hideSearch,
+  searchBoxClassName,
+  maxHeightClass,
+  autoFocus = true,
+}) => {
+  const { t } = useTranslation()
+  const [searchText, setSearchText] = useState('')
+
+  const handleKeyDown = (e: React.KeyboardEvent) => {
+    if (e.key === 'Escape') {
+      e.preventDefault()
+      onClose?.()
+    }
+  }
+
+  const filteredNodes = nodes.filter((node) => {
+    if (!searchText)
+      return true
+    return node.title.toLowerCase().includes(searchText.toLowerCase())
+  })
+
+  return (
+    <>
+      {!hideSearch && (
+        <>
+          <div className={cn('mx-2 mb-2 mt-2', searchBoxClassName)}>
+            <Input
+              showLeftIcon
+              showClearIcon
+              value={searchText}
+              placeholder={t('common.searchAgent', { ns: 'workflow' })}
+              onChange={e => setSearchText(e.target.value)}
+              onClick={e => e.stopPropagation()}
+              onKeyDown={handleKeyDown}
+              onClear={() => setSearchText('')}
+              onBlur={onBlur}
+              autoFocus={autoFocus}
+            />
+          </div>
+          <div
+            className="relative left-[-4px] h-[0.5px] bg-black/5"
+            style={{ width: 'calc(100% + 8px)' }}
+          />
+        </>
+      )}
+
+      {filteredNodes.length > 0
+        ? (
+          <div className={cn('max-h-[85vh] overflow-y-auto py-1', maxHeightClass)}>
+            {filteredNodes.map(node => (
+              <Item
+                key={node.id}
+                node={node}
+                onSelect={onSelect}
+              />
+            ))}
+          </div>
+        )
+        : (
+          <div className="py-2 pl-3 text-xs font-medium text-text-tertiary">
+            {t('common.noAgentNodes', { ns: 'workflow' })}
+          </div>
+        )}
+    </>
+  )
+}
+
+export default React.memo(AgentNodeList)
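
A hypothetical usage sketch for the new component above; the node data, file path, and enum member are made up for illustration:

// Hypothetical sketch; './agent-node-list' is an assumed path.
import AgentNodeList, { type AgentNode } from './agent-node-list'
import { BlockEnum } from '@/app/components/workflow/types'

const demoNodes: AgentNode[] = [
  { id: 'agent-1', title: 'Research agent', type: BlockEnum.LLM },
  { id: 'agent-2', title: 'Summarizer', type: BlockEnum.LLM },
]

const Demo = () => (
  <AgentNodeList
    nodes={demoNodes}
    onSelect={node => console.log('selected', node.id)}
    onClose={() => console.log('closed')}
  />
)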

@@ -1,6 +1,6 @@
 'use client'
 import type { FC } from 'react'
-import type { ResourceVarInputs } from '../types'
+import type { MentionConfig, ResourceVarInputs } from '../types'
 import type { CredentialFormSchema, FormOption } from '@/app/components/header/account-setting/model-provider-page/declarations'
 import type { Event, Tool } from '@/app/components/tools/types'
 import type { TriggerWithProvider } from '@/app/components/workflow/block-selector/types'
@@ -233,13 +233,25 @@ const FormInputItem: FC<Props> = ({
     }
   }

-  const handleValueChange = (newValue: any) => {
+  const handleValueChange = (newValue: any, newType?: VarKindType, mentionConfig?: MentionConfig | null) => {
+    const normalizedValue = isNumber ? Number.parseFloat(newValue) : newValue
+    const resolvedType = newType ?? (varInput?.type === VarKindType.mention ? VarKindType.mention : getVarKindType())
+    const resolvedMentionConfig = resolvedType === VarKindType.mention
+      ? (mentionConfig ?? varInput?.mention_config ?? {
+        extractor_node_id: '',
+        output_selector: [],
+        null_strategy: 'use_default',
+        default_value: '',
+      })
+      : undefined
+
     onChange({
       ...value,
       [variable]: {
         ...varInput,
-        type: getVarKindType(),
-        value: isNumber ? Number.parseFloat(newValue) : newValue,
+        type: resolvedType,
+        value: normalizedValue,
+        mention_config: resolvedMentionConfig,
       },
     })
   }
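
The fallback chain above resolves a mention config with a clear precedence: an explicit config passed by the caller, then the config already stored on the input, then a blank default. A reduced sketch of that chain, with a hypothetical local type standing in for the real `MentionConfig`:

// Hypothetical sketch of the nullish-coalescing precedence.
type MentionConfigLike = {
  extractor_node_id: string
  output_selector: string[]
  null_strategy: string
  default_value: string
}

const EMPTY_MENTION_CONFIG: MentionConfigLike = {
  extractor_node_id: '',
  output_selector: [],
  null_strategy: 'use_default',
  default_value: '',
}

function resolveMentionConfig(
  isMention: boolean,
  incoming?: MentionConfigLike | null,
  stored?: MentionConfigLike,
): MentionConfigLike | undefined {
  if (!isMention)
    return undefined // non-mention inputs carry no mention_config
  // Caller-supplied config wins, then the stored one, then a blank default.
  return incoming ?? stored ?? EMPTY_MENTION_CONFIG
}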

@@ -337,6 +349,8 @@ const FormInputItem: FC<Props> = ({
           showManageInputField={showManageInputField}
           onManageInputField={onManageInputField}
           disableVariableInsertion={disableVariableInsertion}
+          toolNodeId={nodeId}
+          paramKey={variable}
         />
       )}
       {isNumber && isConstant && (

@@ -13,6 +13,7 @@ import {
   Stop,
 } from '@/app/components/base/icons/src/vender/line/mediaAndDevices'
 import Tooltip from '@/app/components/base/tooltip'
+import { useHooksStore } from '@/app/components/workflow/hooks-store'
 import { useWorkflowStore } from '@/app/components/workflow/store'
 import {
   useNodesInteractions,
@@ -30,12 +31,18 @@ const NodeControl: FC<NodeControlProps> = ({
   const [open, setOpen] = useState(false)
   const { handleNodeSelect } = useNodesInteractions()
   const workflowStore = useWorkflowStore()
+  const interactionMode = useHooksStore(s => s.interactionMode)
   const isSingleRunning = data._singleRunningStatus === NodeRunningStatus.Running
   const handleOpenChange = useCallback((newOpen: boolean) => {
     setOpen(newOpen)
   }, [])

   const isChildNode = !!(data.isInIteration || data.isInLoop)
+  const allowNodeMenu = interactionMode !== 'subgraph'
+  const canSingleRun = canRunBySingle(data.type, isChildNode)
+
+  if (!allowNodeMenu && !canSingleRun)
+    return null
   return (
     <div
       className={`
@@ -50,7 +57,7 @@ const NodeControl: FC<NodeControlProps> = ({
       onClick={e => e.stopPropagation()}
     >
       {
-        canRunBySingle(data.type, isChildNode) && (
+        canSingleRun && (
           <div
             className={`flex h-5 w-5 items-center justify-center rounded-md ${isSingleRunning && 'cursor-pointer hover:bg-state-base-hover'}`}
             onClick={() => {
@@ -80,13 +87,15 @@ const NodeControl: FC<NodeControlProps> = ({
          </div>
        )
      }
-      <PanelOperator
-        id={id}
-        data={data}
-        offset={0}
-        onOpenChange={handleOpenChange}
-        triggerClassName="!w-5 !h-5"
-      />
+      {allowNodeMenu && (
+        <PanelOperator
+          id={id}
+          data={data}
+          offset={0}
+          onOpenChange={handleOpenChange}
+          triggerClassName="!w-5 !h-5"
+        />
+      )}
     </div>
   </div>
 )

@@ -12,6 +12,7 @@ import {
   Handle,
   Position,
 } from 'reactflow'
+import { useHooksStore } from '@/app/components/workflow/hooks-store'
 import { cn } from '@/utils/classnames'
 import BlockSelector from '../../../block-selector'
 import {
@@ -46,6 +47,8 @@ export const NodeTargetHandle = memo(({
   const [open, setOpen] = useState(false)
   const { handleNodeAdd } = useNodesInteractions()
   const { getNodesReadOnly } = useNodesReadOnly()
+  const interactionMode = useHooksStore(s => s.interactionMode)
+  const allowGraphActions = interactionMode !== 'subgraph'
   const connected = data._connectedTargetHandleIds?.includes(handleId)
   const { availablePrevBlocks } = useAvailableBlocks(data.type, data.isInIteration || data.isInLoop)
   const isConnectable = !!availablePrevBlocks.length
@@ -55,9 +58,9 @@ export const NodeTargetHandle = memo(({
   }, [])
   const handleHandleClick = useCallback((e: MouseEvent) => {
     e.stopPropagation()
-    if (!connected)
+    if (!connected && allowGraphActions)
       setOpen(v => !v)
-  }, [connected])
+  }, [allowGraphActions, connected])
   const handleSelect = useCallback((type: BlockEnum, pluginDefaultValue?: PluginDefaultValue) => {
     handleNodeAdd(
       {
@@ -91,11 +94,11 @@ export const NodeTargetHandle = memo(({
         || data.type === BlockEnum.TriggerPlugin) && 'opacity-0',
         handleClassName,
       )}
-      isConnectable={isConnectable}
-      onClick={handleHandleClick}
+      isConnectable={allowGraphActions && isConnectable}
+      onClick={allowGraphActions ? handleHandleClick : undefined}
     >
       {
-        !connected && isConnectable && !getNodesReadOnly() && (
+        allowGraphActions && !connected && isConnectable && !getNodesReadOnly() && (
           <BlockSelector
             open={open}
             onOpenChange={handleOpenChange}
@@ -135,6 +138,8 @@ export const NodeSourceHandle = memo(({
   const [open, setOpen] = useState(false)
   const { handleNodeAdd } = useNodesInteractions()
   const { getNodesReadOnly } = useNodesReadOnly()
+  const interactionMode = useHooksStore(s => s.interactionMode)
+  const allowGraphActions = interactionMode !== 'subgraph'
   const { availableNextBlocks } = useAvailableBlocks(data.type, data.isInIteration || data.isInLoop)
   const isConnectable = !!availableNextBlocks.length
   const isChatMode = useIsChatMode()
@@ -145,8 +150,9 @@ export const NodeSourceHandle = memo(({
   }, [])
   const handleHandleClick = useCallback((e: MouseEvent) => {
     e.stopPropagation()
-    setOpen(v => !v)
-  }, [])
+    if (allowGraphActions)
+      setOpen(v => !v)
+  }, [allowGraphActions])
   const handleSelect = useCallback((type: BlockEnum, pluginDefaultValue?: PluginDefaultValue) => {
     handleNodeAdd(
       {
@@ -161,7 +167,7 @@ export const NodeSourceHandle = memo(({
   }, [handleNodeAdd, id, handleId])

   useEffect(() => {
-    if (!shouldAutoOpenStartNodeSelector)
+    if (!shouldAutoOpenStartNodeSelector || !allowGraphActions)
       return

     if (isChatMode) {
@@ -198,8 +204,8 @@ export const NodeSourceHandle = memo(({
         !connected && 'after:opacity-0',
         handleClassName,
       )}
-      isConnectable={isConnectable}
-      onClick={handleHandleClick}
+      isConnectable={allowGraphActions && isConnectable}
+      onClick={allowGraphActions ? handleHandleClick : undefined}
     >
       <div className="absolute -top-1 left-1/2 hidden -translate-x-1/2 -translate-y-full rounded-lg border-[0.5px] border-components-panel-border bg-components-tooltip-bg p-1.5 shadow-lg group-hover/handle:block">
         <div className="system-xs-regular text-text-tertiary">
@@ -214,7 +220,7 @@ export const NodeSourceHandle = memo(({
         </div>
       </div>
       {
-        isConnectable && !getNodesReadOnly() && (
+        allowGraphActions && isConnectable && !getNodesReadOnly() && (
           <BlockSelector
             open={open}
             onOpenChange={handleOpenChange}

@@ -41,13 +41,14 @@ const PanelOperatorPopup = ({
     handleNodesDuplicate,
     handleNodeSelect,
     handleNodesCopy,
+    handleUngroup,
   } = useNodesInteractions()
   const { handleNodeDataUpdate } = useNodeDataUpdate()
   const { handleSyncWorkflowDraft } = useNodesSyncDraft()
   const { nodesReadOnly } = useNodesReadOnly()
   const edge = edges.find(edge => edge.target === id)
   const nodeMetaData = useNodeMetaData({ id, data } as Node)
-  const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly
+  const showChangeBlock = !nodeMetaData.isTypeFixed && !nodesReadOnly && data.type !== BlockEnum.Group
   const isChildNode = !!(data.isInIteration || data.isInLoop)

   const { data: workflowTools } = useAllWorkflowTools()
@@ -61,6 +62,25 @@ const PanelOperatorPopup = ({

   return (
     <div className="w-[240px] rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg shadow-xl">
+      {
+        !nodesReadOnly && data.type === BlockEnum.Group && (
+          <>
+            <div className="p-1">
+              <div
+                className="flex h-8 cursor-pointer items-center justify-between rounded-lg px-3 text-sm text-text-secondary hover:bg-state-base-hover"
+                onClick={() => {
+                  onClosePopup()
+                  handleUngroup(id)
+                }}
+              >
+                {t('panel.ungroup', { ns: 'workflow' })}
+                <ShortcutsName keys={['ctrl', 'shift', 'g']} />
+              </div>
+            </div>
+            <div className="h-px bg-divider-regular"></div>
+          </>
+        )
+      }
       {
         (showChangeBlock || canRunBySingle(data.type, isChildNode)) && (
           <>

@@ -84,6 +84,7 @@ type Props = {
   currentTool?: Tool
   currentProvider?: ToolWithProvider | TriggerWithProvider
   preferSchemaType?: boolean
+  hideSearch?: boolean
 }

 const DEFAULT_VALUE_SELECTOR: Props['value'] = []
@@ -117,6 +118,7 @@ const VarReferencePicker: FC<Props> = ({
   currentTool,
   currentProvider,
   preferSchemaType,
+  hideSearch,
 }) => {
   const { t } = useTranslation()
   const store = useStoreApi()
@@ -636,6 +638,7 @@ const VarReferencePicker: FC<Props> = ({
             isSupportFileVar={isSupportFileVar}
             zIndex={zIndex}
             preferSchemaType={preferSchemaType}
+            hideSearch={hideSearch}
           />
         )}
       </PortalToFollowElemContent>

@@ -15,6 +15,7 @@ type Props = {
   onChange: (value: ValueSelector, varDetail: Var) => void
   itemWidth?: number
   isSupportFileVar?: boolean
+  hideSearch?: boolean
   zIndex?: number
   preferSchemaType?: boolean
 }
@@ -24,6 +25,7 @@ const VarReferencePopup: FC<Props> = ({
   onChange,
   itemWidth,
   isSupportFileVar = true,
+  hideSearch,
   zIndex,
   preferSchemaType,
 }) => {
@@ -35,7 +37,7 @@ const VarReferencePopup: FC<Props> = ({
   // max-h-[300px] overflow-y-auto todo: use portal to handle long list
   return (
     <div
-      className="space-y-1 rounded-lg border border-components-panel-border bg-components-panel-bg p-1 shadow-lg"
+      className="space-y-1 rounded-lg border-[0.5px] border-components-panel-border bg-components-panel-bg-blur p-1 shadow-lg"
      style={{
        width: itemWidth || 228,
      }}
@@ -84,6 +86,7 @@ const VarReferencePopup: FC<Props> = ({
           showManageInputField={showManageRagInputFields}
           onManageInputField={() => setShowInputFieldPanel?.(true)}
           preferSchemaType={preferSchemaType}
+          hideSearch={hideSearch}
         />
       )}
     </div>

@@ -231,6 +231,8 @@ const BasePanel: FC<BasePanelProps> = ({
   } = useNodesMetaData()

   const configsMap = useHooksStore(s => s.configsMap)
+  const interactionMode = useHooksStore(s => s.interactionMode)
+  const allowGraphActions = interactionMode !== 'subgraph'
   const {
     isShowSingleRun,
     hideSingleRun,
@@ -514,9 +516,9 @@ const BasePanel: FC<BasePanelProps> = ({
           </Tooltip>
         )
       }
-      <HelpLink nodeType={data.type} />
-      <PanelOperator id={id} data={data} showHelpLink={false} />
-      <div className="mx-3 h-3.5 w-[1px] bg-divider-regular" />
+      {allowGraphActions && <HelpLink nodeType={data.type} />}
+      {allowGraphActions && <PanelOperator id={id} data={data} showHelpLink={false} />}
+      {allowGraphActions && <div className="mx-3 h-3.5 w-[1px] bg-divider-regular" />}
       <div
         className="flex h-6 w-6 cursor-pointer items-center justify-center"
         onClick={() => handleNodeSelect(id, true)}
@@ -594,7 +596,7 @@ const BasePanel: FC<BasePanelProps> = ({
         )
       }
       {
-        !needsToolAuth && !currentDataSource && !currentTriggerPlugin && (
+        !needsToolAuth && !currentDataSource && !currentTriggerPlugin && data.type !== BlockEnum.Group && (
           <div className="flex items-center justify-between pl-4 pr-3">
             <Tab
               value={tabType}
@@ -603,9 +605,9 @@ const BasePanel: FC<BasePanelProps> = ({
          </div>
        )
      }
-      <Split />
+      {data.type !== BlockEnum.Group && <Split />}
     </div>
-    {tabType === TabType.settings && (
+    {(tabType === TabType.settings || data.type === BlockEnum.Group) && (
       <div className="flex flex-1 flex-col overflow-y-auto">
         <div>
           {cloneElement(children as any, {
@@ -623,7 +625,7 @@ const BasePanel: FC<BasePanelProps> = ({
         </div>
         <Split />
         {
-          hasRetryNode(data.type) && (
+          allowGraphActions && hasRetryNode(data.type) && (
             <RetryOnPanel
               id={id}
               data={data}
@@ -631,7 +633,7 @@ const BasePanel: FC<BasePanelProps> = ({
          )
        }
        {
-         hasErrorHandleNode(data.type) && (
+         allowGraphActions && hasErrorHandleNode(data.type) && (
             <ErrorHandleOnPanel
               id={id}
               data={data}
@@ -639,7 +641,7 @@ const BasePanel: FC<BasePanelProps> = ({
          )
        }
        {
-         !!availableNextBlocks.length && (
+         allowGraphActions && !!availableNextBlocks.length && (
             <div className="border-t-[0.5px] border-divider-regular p-4">
               <div className="system-sm-semibold-uppercase mb-1 flex items-center text-text-secondary">
                 {t('panel.nextStep', { ns: 'workflow' }).toLocaleUpperCase()}
@@ -651,7 +653,7 @@ const BasePanel: FC<BasePanelProps> = ({
           </div>
         )
       }
-      {readmeEntranceComponent}
+      {allowGraphActions ? readmeEntranceComponent : null}
     </div>
   )}

@@ -56,6 +56,7 @@ const singleRunFormParamsHooks: Record<BlockEnum, any> = {
   [BlockEnum.VariableAggregator]: useVariableAggregatorSingleRunFormParams,
   [BlockEnum.Assigner]: useVariableAssignerSingleRunFormParams,
   [BlockEnum.KnowledgeBase]: useKnowledgeBaseSingleRunFormParams,
+  [BlockEnum.Group]: undefined,
   [BlockEnum.VariableAssigner]: undefined,
   [BlockEnum.End]: undefined,
   [BlockEnum.Answer]: undefined,
@@ -103,6 +104,7 @@ const getDataForCheckMoreHooks: Record<BlockEnum, any> = {
   [BlockEnum.DataSource]: undefined,
   [BlockEnum.DataSourceEmpty]: undefined,
   [BlockEnum.KnowledgeBase]: undefined,
+  [BlockEnum.Group]: undefined,
   [BlockEnum.TriggerWebhook]: undefined,
   [BlockEnum.TriggerSchedule]: undefined,
   [BlockEnum.TriggerPlugin]: useTriggerPluginGetDataForCheckMore,
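
Typing these maps as `Record<BlockEnum, any>` with explicit `undefined` entries buys compile-time exhaustiveness: adding a new enum member (like `Group` here) fails to type-check until every map declares an entry for it. A tiny sketch with an illustrative enum, not the real `BlockEnum`:

// Hypothetical sketch of the exhaustive-map pattern.
enum DemoBlock {
  LLM = 'llm',
  Group = 'group',
}

// Removing the Group line makes this object fail to type-check,
// because Record<DemoBlock, ...> requires a key for every member.
const demoHooks: Record<DemoBlock, (() => void) | undefined> = {
  [DemoBlock.LLM]: () => { /* real hook would go here */ },
  [DemoBlock.Group]: undefined,
}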

@@ -63,6 +63,11 @@ const BaseNode: FC<BaseNodeProps> = ({
   const { t } = useTranslation()
   const nodeRef = useRef<HTMLDivElement>(null)
   const { nodesReadOnly } = useNodesReadOnly()
+  const { _subGraphEntry, _iconTypeOverride } = data as {
+    _subGraphEntry?: boolean
+    _iconTypeOverride?: BlockEnum
+  }
+  const iconType = _iconTypeOverride ?? data.type

   const { handleNodeIterationChildSizeChange } = useNodeIterationInteractions()
   const { handleNodeLoopChildSizeChange } = useNodeLoopInteractions()
@@ -138,6 +143,48 @@ const BaseNode: FC<BaseNodeProps> = ({
     return null
   }, [data._loopIndex, data._runningStatus, t])

+  if (_subGraphEntry) {
+    return (
+      <div
+        className="relative"
+        ref={nodeRef}
+      >
+        <NodeSourceHandle
+          id={id}
+          data={data}
+          handleClassName="!top-1/2 !-right-[9px] !-translate-y-1/2 opacity-0 pointer-events-none after:opacity-0"
+          handleId="source"
+        />
+        <div
+          className={cn(
+            'flex rounded-2xl border p-0.5',
+            showSelectedBorder ? 'border-components-option-card-option-selected-border' : 'border-workflow-block-border',
+            data._waitingRun && 'opacity-70',
+            showRunningBorder && '!border-state-accent-solid',
+            showSuccessBorder && '!border-state-success-solid',
+            showFailedBorder && '!border-state-destructive-solid',
+            showExceptionBorder && '!border-state-warning-solid',
+          )}
+        >
+          <div className="flex items-center gap-2 rounded-[15px] bg-workflow-block-bg px-3 py-2 shadow-xs">
+            <BlockIcon
+              className="shrink-0"
+              type={iconType}
+              size="md"
+              toolIcon={toolIcon}
+            />
+            <div
+              title={data.title}
+              className="system-sm-semibold-uppercase text-text-primary"
+            >
+              {data.title}
+            </div>
+          </div>
+        </div>
+      </div>
+    )
+  }
+
   const nodeContent = (
     <div
       className={cn(
@@ -221,7 +268,7 @@ const BaseNode: FC<BaseNodeProps> = ({
         )
       }
       {
-        data.type !== BlockEnum.IfElse && data.type !== BlockEnum.QuestionClassifier && !data._isCandidate && (
+        data.type !== BlockEnum.IfElse && data.type !== BlockEnum.QuestionClassifier && data.type !== BlockEnum.Group && !data._isCandidate && (
           <NodeSourceHandle
             id={id}
             data={data}
@@ -245,7 +292,7 @@ const BaseNode: FC<BaseNodeProps> = ({
     >
       <BlockIcon
         className="mr-2 shrink-0"
-        type={data.type}
+        type={iconType}
         size="md"
         toolIcon={toolIcon}
       />
@@ -344,8 +391,9 @@ const BaseNode: FC<BaseNodeProps> = ({

   const isStartNode = data.type === BlockEnum.Start
   const isEntryNode = isTriggerNode(data.type as any) || isStartNode
+  const shouldWrapEntryNode = isEntryNode && !(isStartNode && _subGraphEntry)

-  return isEntryNode
+  return shouldWrapEntryNode
     ? (
       <EntryNodeContainer
         nodeType={isStartNode ? StartNodeTypeEnum.Start : StartNodeTypeEnum.Trigger}

Some files were not shown because too many files have changed in this diff.