mirror of
https://github.com/langgenius/dify.git
synced 2026-04-14 20:42:39 +00:00
Compare commits
40 Commits
feat/tidb-
...
feat/new-a
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e9ee897973 | ||
|
|
971828615e | ||
|
|
b804c7ed47 | ||
|
|
c7a7c73034 | ||
|
|
94b3087b98 | ||
|
|
3e0578a1c6 | ||
|
|
5f87239abc | ||
|
|
c03b25a940 | ||
|
|
90cce7693f | ||
|
|
77c182f738 | ||
|
|
e04f00d29b | ||
|
|
bbed99a4cb | ||
|
|
df6c1064c6 | ||
|
|
f4e04fc872 | ||
|
|
59b9221501 | ||
|
|
218c10ba4f | ||
|
|
4c878da9e6 | ||
|
|
698af54c4f | ||
|
|
10bb276e97 | ||
|
|
73fd439541 | ||
|
|
5cdae671d5 | ||
|
|
e50c36526e | ||
|
|
2de2a8fd3a | ||
|
|
e2e16772a1 | ||
|
|
b21a443d56 | ||
|
|
4f010cd4f5 | ||
|
|
3d4be88d97 | ||
|
|
482a004efe | ||
|
|
7052257c8d | ||
|
|
edfcab6455 | ||
|
|
66212e3575 | ||
|
|
96374d7f6a | ||
|
|
44491e427c | ||
|
|
d3d9f21cdf | ||
|
|
0c7e7e0c4e | ||
|
|
d9d1e9b63a | ||
|
|
bebafaa346 | ||
|
|
1835a1dc5d | ||
|
|
8f3a3ea03e | ||
|
|
96641a93f6 |
@@ -287,6 +287,27 @@ class MarketplaceConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class CreatorsPlatformConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for creators platform
|
||||
"""
|
||||
|
||||
CREATORS_PLATFORM_FEATURES_ENABLED: bool = Field(
|
||||
description="Enable or disable creators platform features",
|
||||
default=True,
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_API_URL: HttpUrl = Field(
|
||||
description="Creators Platform API URL",
|
||||
default=HttpUrl("https://creators.dify.ai"),
|
||||
)
|
||||
|
||||
CREATORS_PLATFORM_OAUTH_CLIENT_ID: str = Field(
|
||||
description="OAuth client_id for the Creators Platform app registered in Dify",
|
||||
default="",
|
||||
)
|
||||
|
||||
|
||||
class EndpointConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for various application endpoints and URLs
|
||||
@@ -341,6 +362,15 @@ class FileAccessConfig(BaseSettings):
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_API_URL: str = Field(
|
||||
description="Base URL for storage file ticket API endpoints."
|
||||
" Used by sandbox containers (internal or external like e2b) that need"
|
||||
" an absolute, routable address to upload/download files via the API."
|
||||
" For all-in-one Docker deployments, set to http://localhost."
|
||||
" For public sandbox environments, set to a public domain or IP.",
|
||||
default="",
|
||||
)
|
||||
|
||||
FILES_ACCESS_TIMEOUT: int = Field(
|
||||
description="Expiration time in seconds for file access URLs",
|
||||
default=300,
|
||||
@@ -1274,6 +1304,52 @@ class PositionConfig(BaseSettings):
|
||||
return {item.strip() for item in self.POSITION_TOOL_EXCLUDES.split(",") if item.strip() != ""}
|
||||
|
||||
|
||||
class CollaborationConfig(BaseSettings):
|
||||
ENABLE_COLLABORATION_MODE: bool = Field(
|
||||
description="Whether to enable collaboration mode features across the workspace",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class AgentV2UpgradeConfig(BaseSettings):
|
||||
"""Feature flags for transparent Agent V2 upgrade."""
|
||||
|
||||
AGENT_V2_TRANSPARENT_UPGRADE: bool = Field(
|
||||
description="Transparently run old apps (chat/completion/agent-chat) through the Agent V2 workflow engine. "
|
||||
"When enabled, old apps synthesize a virtual workflow at runtime instead of using legacy runners.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
AGENT_V2_REPLACES_LLM: bool = Field(
|
||||
description="Transparently replace LLM nodes in workflows with Agent V2 nodes at runtime. "
|
||||
"LLMNodeData is remapped to AgentV2NodeData with tools=[] (identical behavior).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class LoginConfig(BaseSettings):
|
||||
ENABLE_EMAIL_CODE_LOGIN: bool = Field(
|
||||
description="whether to enable email code login",
|
||||
@@ -1343,29 +1419,6 @@ class TenantIsolatedTaskQueueConfig(BaseSettings):
|
||||
)
|
||||
|
||||
|
||||
class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD: NonNegativeInt = Field(
|
||||
description="Graceful period in days for sandbox records clean after subscription expiration",
|
||||
default=21,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE: PositiveInt = Field(
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL: PositiveInt = Field(
|
||||
description="Lock TTL for sandbox expired records clean task in seconds",
|
||||
default=90000,
|
||||
)
|
||||
|
||||
|
||||
class FeatureConfig(
|
||||
# place the configs in alphabet order
|
||||
AppExecutionConfig,
|
||||
@@ -1376,6 +1429,7 @@ class FeatureConfig(
|
||||
AsyncWorkflowConfig,
|
||||
PluginConfig,
|
||||
MarketplaceConfig,
|
||||
CreatorsPlatformConfig,
|
||||
DataSetConfig,
|
||||
EndpointConfig,
|
||||
FileAccessConfig,
|
||||
@@ -1391,7 +1445,6 @@ class FeatureConfig(
|
||||
PositionConfig,
|
||||
RagEtlConfig,
|
||||
RepositoryConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
SecurityConfig,
|
||||
TenantIsolatedTaskQueueConfig,
|
||||
ToolConfig,
|
||||
@@ -1399,6 +1452,9 @@ class FeatureConfig(
|
||||
WorkflowConfig,
|
||||
WorkflowNodeExecutionConfig,
|
||||
WorkspaceConfig,
|
||||
CollaborationConfig,
|
||||
AgentV2UpgradeConfig,
|
||||
SandboxExpiredRecordsCleanConfig,
|
||||
LoginConfig,
|
||||
AccountConfig,
|
||||
SwaggerUIConfig,
|
||||
|
||||
@@ -81,4 +81,20 @@ default_app_templates: Mapping[AppMode, Mapping] = {
|
||||
},
|
||||
},
|
||||
},
|
||||
# agent default mode (new agent backed by single-node workflow)
|
||||
AppMode.AGENT: {
|
||||
"app": {
|
||||
"mode": AppMode.AGENT,
|
||||
"enable_site": True,
|
||||
"enable_api": True,
|
||||
},
|
||||
"model_config": {
|
||||
"model": {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
@@ -52,7 +52,7 @@ from services.entities.knowledge_entities.knowledge_entities import (
|
||||
)
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion"]
|
||||
ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"]
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
@@ -62,7 +62,7 @@ _logger = logging.getLogger(__name__)
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
limit: int = Field(default=20, ge=1, le=100, description="Page size (1-100)")
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "channel", "all"] = Field(
|
||||
mode: Literal["completion", "chat", "advanced-chat", "workflow", "agent-chat", "agent", "channel", "all"] = Field(
|
||||
default="all", description="App mode filter"
|
||||
)
|
||||
name: str | None = Field(default=None, description="Filter by app name")
|
||||
@@ -94,7 +94,9 @@ class AppListQuery(BaseModel):
|
||||
class CreateAppPayload(BaseModel):
|
||||
name: str = Field(..., min_length=1, description="App name")
|
||||
description: str | None = Field(default=None, description="App description (max 400 chars)", max_length=400)
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion"] = Field(..., description="App mode")
|
||||
mode: Literal["chat", "agent-chat", "advanced-chat", "workflow", "completion", "agent"] = Field(
|
||||
..., description="App mode"
|
||||
)
|
||||
icon_type: IconType | None = Field(default=None, description="Icon type")
|
||||
icon: str | None = Field(default=None, description="Icon")
|
||||
icon_background: str | None = Field(default=None, description="Icon background color")
|
||||
|
||||
@@ -161,7 +161,7 @@ class ChatMessageApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model):
|
||||
args_model = ChatMessagePayload.model_validate(console_ns.payload)
|
||||
@@ -215,7 +215,7 @@ class ChatMessageStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def post(self, app_model, task_id):
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -237,7 +237,7 @@ class ChatMessageListApi(Resource):
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@setup_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(message_infinite_scroll_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model):
|
||||
@@ -393,7 +393,7 @@ class MessageSuggestedQuestionApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
def get(self, app_model, message_id):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
message_id = str(message_id)
|
||||
|
||||
@@ -206,7 +206,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -226,7 +226,7 @@ class DraftWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@console_ns.doc("sync_draft_workflow")
|
||||
@console_ns.doc(description="Sync draft workflow configuration")
|
||||
@console_ns.expect(console_ns.models[SyncDraftWorkflowPayload.__name__])
|
||||
@@ -310,7 +310,7 @@ class AdvancedChatDraftWorkflowRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -356,7 +356,7 @@ class AdvancedChatDraftRunIterationNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -432,7 +432,7 @@ class AdvancedChatDraftRunLoopNodeApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -534,7 +534,7 @@ class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -563,7 +563,7 @@ class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
@@ -718,7 +718,7 @@ class WorkflowTaskStopApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, task_id: str):
|
||||
"""
|
||||
@@ -746,7 +746,7 @@ class DraftWorkflowNodeRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
@@ -792,7 +792,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -810,7 +810,7 @@ class PublishedWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App):
|
||||
"""
|
||||
@@ -854,7 +854,7 @@ class DefaultBlockConfigsApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -876,7 +876,7 @@ class DefaultBlockConfigApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App, block_type: str):
|
||||
"""
|
||||
@@ -941,7 +941,7 @@ class PublishedAllWorkflowApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_pagination_model)
|
||||
@edit_permission_required
|
||||
def get(self, app_model: App):
|
||||
@@ -990,7 +990,7 @@ class DraftWorkflowRestoreApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, workflow_id: str):
|
||||
current_user, _ = current_account_with_tenant()
|
||||
@@ -1028,7 +1028,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_model)
|
||||
@edit_permission_required
|
||||
def patch(self, app_model: App, workflow_id: str):
|
||||
@@ -1068,7 +1068,7 @@ class WorkflowByIdApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@edit_permission_required
|
||||
def delete(self, app_model: App, workflow_id: str):
|
||||
"""
|
||||
@@ -1103,7 +1103,7 @@ class DraftWorkflowNodeLastRunApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_model)
|
||||
def get(self, app_model: App, node_id: str):
|
||||
srv = WorkflowService()
|
||||
|
||||
322
api/controllers/console/app/workflow_comment.py
Normal file
322
api/controllers/console/app/workflow_comment.py
Normal file
@@ -0,0 +1,322 @@
|
||||
import logging
|
||||
|
||||
from flask_restx import Resource, marshal_with
|
||||
from pydantic import BaseModel, Field, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from fields.member_fields import AccountWithRole
|
||||
from fields.workflow_comment_fields import (
|
||||
workflow_comment_basic_fields,
|
||||
workflow_comment_create_fields,
|
||||
workflow_comment_detail_fields,
|
||||
workflow_comment_reply_create_fields,
|
||||
workflow_comment_reply_update_fields,
|
||||
workflow_comment_resolve_fields,
|
||||
workflow_comment_update_fields,
|
||||
)
|
||||
from libs.login import current_user, login_required
|
||||
from models import App
|
||||
from services.account_service import TenantService
|
||||
from services.workflow_comment_service import WorkflowCommentService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowCommentCreatePayload(BaseModel):
|
||||
position_x: float = Field(..., description="Comment X position")
|
||||
position_y: float = Field(..., description="Comment Y position")
|
||||
content: str = Field(..., description="Comment content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Comment content")
|
||||
position_x: float | None = Field(default=None, description="Comment X position")
|
||||
position_y: float | None = Field(default=None, description="Comment Y position")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyCreatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentReplyUpdatePayload(BaseModel):
|
||||
content: str = Field(..., description="Reply content")
|
||||
mentioned_user_ids: list[str] = Field(default_factory=list, description="Mentioned user IDs")
|
||||
|
||||
|
||||
class WorkflowCommentMentionUsersResponse(BaseModel):
|
||||
users: list[AccountWithRole] = Field(description="Mentionable users")
|
||||
|
||||
|
||||
for model in (
|
||||
WorkflowCommentCreatePayload,
|
||||
WorkflowCommentUpdatePayload,
|
||||
WorkflowCommentReplyCreatePayload,
|
||||
WorkflowCommentReplyUpdatePayload,
|
||||
):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
for model in (AccountWithRole, WorkflowCommentMentionUsersResponse):
|
||||
console_ns.schema_model(model.__name__, model.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0))
|
||||
|
||||
workflow_comment_basic_model = console_ns.model("WorkflowCommentBasic", workflow_comment_basic_fields)
|
||||
workflow_comment_detail_model = console_ns.model("WorkflowCommentDetail", workflow_comment_detail_fields)
|
||||
workflow_comment_create_model = console_ns.model("WorkflowCommentCreate", workflow_comment_create_fields)
|
||||
workflow_comment_update_model = console_ns.model("WorkflowCommentUpdate", workflow_comment_update_fields)
|
||||
workflow_comment_resolve_model = console_ns.model("WorkflowCommentResolve", workflow_comment_resolve_fields)
|
||||
workflow_comment_reply_create_model = console_ns.model(
|
||||
"WorkflowCommentReplyCreate", workflow_comment_reply_create_fields
|
||||
)
|
||||
workflow_comment_reply_update_model = console_ns.model(
|
||||
"WorkflowCommentReplyUpdate", workflow_comment_reply_update_fields
|
||||
)
|
||||
workflow_comment_mention_users_model = console_ns.models[WorkflowCommentMentionUsersResponse.__name__]
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments")
|
||||
class WorkflowCommentListApi(Resource):
|
||||
"""API for listing and creating workflow comments."""
|
||||
|
||||
@console_ns.doc("list_workflow_comments")
|
||||
@console_ns.doc(description="Get all comments for a workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Comments retrieved successfully", workflow_comment_basic_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_basic_model, envelope="data")
|
||||
def get(self, app_model: App):
|
||||
"""Get all comments for a workflow."""
|
||||
comments = WorkflowCommentService.get_comments(tenant_id=current_user.current_tenant_id, app_id=app_model.id)
|
||||
|
||||
return comments
|
||||
|
||||
@console_ns.doc("create_workflow_comment")
|
||||
@console_ns.doc(description="Create a new workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentCreatePayload.__name__])
|
||||
@console_ns.response(201, "Comment created successfully", workflow_comment_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_create_model)
|
||||
def post(self, app_model: App):
|
||||
"""Create a new workflow comment."""
|
||||
payload = WorkflowCommentCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
created_by=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>")
|
||||
class WorkflowCommentDetailApi(Resource):
|
||||
"""API for managing individual workflow comments."""
|
||||
|
||||
@console_ns.doc("get_workflow_comment")
|
||||
@console_ns.doc(description="Get a specific workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment retrieved successfully", workflow_comment_detail_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_detail_model)
|
||||
def get(self, app_model: App, comment_id: str):
|
||||
"""Get a specific workflow comment."""
|
||||
comment = WorkflowCommentService.get_comment(
|
||||
tenant_id=current_user.current_tenant_id, app_id=app_model.id, comment_id=comment_id
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
@console_ns.doc("update_workflow_comment")
|
||||
@console_ns.doc(description="Update a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Comment updated successfully", workflow_comment_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_update_model)
|
||||
def put(self, app_model: App, comment_id: str):
|
||||
"""Update a workflow comment."""
|
||||
payload = WorkflowCommentUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.update_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
position_x=payload.position_x,
|
||||
position_y=payload.position_y,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
@console_ns.doc("delete_workflow_comment")
|
||||
@console_ns.doc(description="Delete a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(204, "Comment deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str):
|
||||
"""Delete a workflow comment."""
|
||||
WorkflowCommentService.delete_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/resolve")
|
||||
class WorkflowCommentResolveApi(Resource):
|
||||
"""API for resolving and reopening workflow comments."""
|
||||
|
||||
@console_ns.doc("resolve_workflow_comment")
|
||||
@console_ns.doc(description="Resolve a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.response(200, "Comment resolved successfully", workflow_comment_resolve_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_resolve_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Resolve a workflow comment."""
|
||||
comment = WorkflowCommentService.resolve_comment(
|
||||
tenant_id=current_user.current_tenant_id,
|
||||
app_id=app_model.id,
|
||||
comment_id=comment_id,
|
||||
user_id=current_user.id,
|
||||
)
|
||||
|
||||
return comment
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies")
|
||||
class WorkflowCommentReplyApi(Resource):
|
||||
"""API for managing comment replies."""
|
||||
|
||||
@console_ns.doc("create_workflow_comment_reply")
|
||||
@console_ns.doc(description="Add a reply to a workflow comment")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyCreatePayload.__name__])
|
||||
@console_ns.response(201, "Reply created successfully", workflow_comment_reply_create_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_create_model)
|
||||
def post(self, app_model: App, comment_id: str):
|
||||
"""Add a reply to a workflow comment."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyCreatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
result = WorkflowCommentService.create_reply(
|
||||
comment_id=comment_id,
|
||||
content=payload.content,
|
||||
created_by=current_user.id,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return result, 201
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/<string:comment_id>/replies/<string:reply_id>")
|
||||
class WorkflowCommentReplyDetailApi(Resource):
|
||||
"""API for managing individual comment replies."""
|
||||
|
||||
@console_ns.doc("update_workflow_comment_reply")
|
||||
@console_ns.doc(description="Update a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.expect(console_ns.models[WorkflowCommentReplyUpdatePayload.__name__])
|
||||
@console_ns.response(200, "Reply updated successfully", workflow_comment_reply_update_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
@marshal_with(workflow_comment_reply_update_model)
|
||||
def put(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Update a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
payload = WorkflowCommentReplyUpdatePayload.model_validate(console_ns.payload or {})
|
||||
|
||||
reply = WorkflowCommentService.update_reply(
|
||||
reply_id=reply_id,
|
||||
user_id=current_user.id,
|
||||
content=payload.content,
|
||||
mentioned_user_ids=payload.mentioned_user_ids,
|
||||
)
|
||||
|
||||
return reply
|
||||
|
||||
@console_ns.doc("delete_workflow_comment_reply")
|
||||
@console_ns.doc(description="Delete a comment reply")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "comment_id": "Comment ID", "reply_id": "Reply ID"})
|
||||
@console_ns.response(204, "Reply deleted successfully")
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def delete(self, app_model: App, comment_id: str, reply_id: str):
|
||||
"""Delete a comment reply."""
|
||||
# Validate comment access first
|
||||
WorkflowCommentService.validate_comment_access(
|
||||
comment_id=comment_id, tenant_id=current_user.current_tenant_id, app_id=app_model.id
|
||||
)
|
||||
|
||||
WorkflowCommentService.delete_reply(reply_id=reply_id, user_id=current_user.id)
|
||||
|
||||
return {"result": "success"}, 204
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflow/comments/mention-users")
|
||||
class WorkflowCommentMentionUsersApi(Resource):
|
||||
"""API for getting mentionable users for workflow comments."""
|
||||
|
||||
@console_ns.doc("workflow_comment_mention_users")
|
||||
@console_ns.doc(description="Get all users in current tenant for mentions")
|
||||
@console_ns.doc(params={"app_id": "Application ID"})
|
||||
@console_ns.response(200, "Mentionable users retrieved successfully", workflow_comment_mention_users_model)
|
||||
@login_required
|
||||
@setup_required
|
||||
@account_initialization_required
|
||||
@get_app_model()
|
||||
def get(self, app_model: App):
|
||||
"""Get all users in current tenant for mentions."""
|
||||
members = TenantService.get_tenant_members(current_user.current_tenant)
|
||||
member_models = TypeAdapter(list[AccountWithRole]).validate_python(members, from_attributes=True)
|
||||
response = WorkflowCommentMentionUsersResponse(users=member_models)
|
||||
return response.model_dump(mode="json"), 200
|
||||
@@ -216,7 +216,7 @@ def _api_prerequisite[**P, R](f: Callable[P, R]) -> Callable[P, R | Response]:
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@edit_permission_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@wraps(f)
|
||||
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R | Response:
|
||||
return f(*args, **kwargs)
|
||||
|
||||
@@ -207,7 +207,7 @@ class AdvancedChatAppWorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(advanced_chat_workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -305,7 +305,7 @@ class AdvancedChatAppWorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -349,7 +349,7 @@ class WorkflowRunListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_pagination_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -397,7 +397,7 @@ class WorkflowRunCountApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_count_model)
|
||||
def get(self, app_model: App):
|
||||
"""
|
||||
@@ -434,7 +434,7 @@ class WorkflowRunDetailApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_detail_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
@@ -458,7 +458,7 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW])
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT])
|
||||
@marshal_with(workflow_run_node_execution_list_model)
|
||||
def get(self, app_model: App, run_id):
|
||||
"""
|
||||
|
||||
1
api/controllers/console/socketio/__init__.py
Normal file
1
api/controllers/console/socketio/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
|
||||
119
api/controllers/console/socketio/workflow.py
Normal file
119
api/controllers/console/socketio/workflow.py
Normal file
@@ -0,0 +1,119 @@
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import cast
|
||||
|
||||
from flask import Request as FlaskRequest
|
||||
|
||||
from extensions.ext_socketio import sio
|
||||
from libs.passport import PassportService
|
||||
from libs.token import extract_access_token
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository
|
||||
from services.account_service import AccountService
|
||||
from services.workflow_collaboration_service import WorkflowCollaborationService
|
||||
|
||||
repository = WorkflowCollaborationRepository()
|
||||
collaboration_service = WorkflowCollaborationService(repository, sio)
|
||||
|
||||
|
||||
def _sio_on(event: str) -> Callable[[Callable[..., object]], Callable[..., object]]:
|
||||
return cast(Callable[[Callable[..., object]], Callable[..., object]], sio.on(event))
|
||||
|
||||
|
||||
@_sio_on("connect")
|
||||
def socket_connect(sid, environ, auth):
|
||||
"""
|
||||
WebSocket connect event, do authentication here.
|
||||
"""
|
||||
try:
|
||||
request_environ = FlaskRequest(environ)
|
||||
token = extract_access_token(request_environ)
|
||||
except Exception:
|
||||
logging.exception("Failed to extract token")
|
||||
token = None
|
||||
|
||||
if not token:
|
||||
logging.warning("Socket connect rejected: missing token (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
try:
|
||||
decoded = PassportService().verify(token)
|
||||
user_id = decoded.get("user_id")
|
||||
if not user_id:
|
||||
logging.warning("Socket connect rejected: missing user_id (sid=%s)", sid)
|
||||
return False
|
||||
|
||||
with sio.app.app_context():
|
||||
user = AccountService.load_logged_in_account(account_id=user_id)
|
||||
if not user:
|
||||
logging.warning("Socket connect rejected: user not found (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
if not user.has_edit_permission:
|
||||
logging.warning("Socket connect rejected: no edit permission (user_id=%s, sid=%s)", user_id, sid)
|
||||
return False
|
||||
|
||||
collaboration_service.save_session(sid, user)
|
||||
return True
|
||||
|
||||
except Exception:
|
||||
logging.exception("Socket authentication failed")
|
||||
return False
|
||||
|
||||
|
||||
@_sio_on("user_connect")
|
||||
def handle_user_connect(sid, data):
|
||||
"""
|
||||
Handle user connect event. Each session (tab) is treated as an independent collaborator.
|
||||
"""
|
||||
workflow_id = data.get("workflow_id")
|
||||
if not workflow_id:
|
||||
return {"msg": "workflow_id is required"}, 400
|
||||
|
||||
result = collaboration_service.register_session(workflow_id, sid)
|
||||
if not result:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
user_id, is_leader = result
|
||||
return {"msg": "connected", "user_id": user_id, "sid": sid, "isLeader": is_leader}
|
||||
|
||||
|
||||
@_sio_on("disconnect")
|
||||
def handle_disconnect(sid):
|
||||
"""
|
||||
Handle session disconnect event. Remove the specific session from online users.
|
||||
"""
|
||||
collaboration_service.disconnect_session(sid)
|
||||
|
||||
|
||||
@_sio_on("collaboration_event")
|
||||
def handle_collaboration_event(sid, data):
|
||||
"""
|
||||
Handle general collaboration events, include:
|
||||
1. mouse_move
|
||||
2. vars_and_features_update
|
||||
3. sync_request (ask leader to update graph)
|
||||
4. app_state_update
|
||||
5. mcp_server_update
|
||||
6. workflow_update
|
||||
7. comments_update
|
||||
8. node_panel_presence
|
||||
9. skill_file_active
|
||||
10. skill_sync_request
|
||||
11. skill_resync_request
|
||||
"""
|
||||
return collaboration_service.relay_collaboration_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("graph_event")
|
||||
def handle_graph_event(sid, data):
|
||||
"""
|
||||
Handle graph events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_graph_event(sid, data)
|
||||
|
||||
|
||||
@_sio_on("skill_event")
|
||||
def handle_skill_event(sid, data):
|
||||
"""
|
||||
Handle skill events - simple broadcast relay.
|
||||
"""
|
||||
return collaboration_service.relay_skill_event(sid, data)
|
||||
67
api/controllers/console/workspace/dsl.py
Normal file
67
api/controllers/console/workspace/dsl.py
Normal file
@@ -0,0 +1,67 @@
|
||||
import json
|
||||
|
||||
import httpx
|
||||
import yaml
|
||||
from flask import request
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from core.plugin.impl.exc import PluginPermissionDeniedError
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import App
|
||||
from models.workflow import Workflow
|
||||
from services.app_dsl_service import AppDslService
|
||||
|
||||
|
||||
class DSLPredictRequest(BaseModel):
|
||||
app_id: str
|
||||
current_node_id: str
|
||||
|
||||
|
||||
@console_ns.route("/workspaces/current/dsl/predict")
|
||||
class DSLPredictApi(Resource):
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
user, _ = current_account_with_tenant()
|
||||
if not user.is_admin_or_owner:
|
||||
raise Forbidden()
|
||||
|
||||
args = DSLPredictRequest.model_validate(request.get_json())
|
||||
|
||||
app_id: str = args.app_id
|
||||
current_node_id: str = args.current_node_id
|
||||
|
||||
with Session(db.engine) as session:
|
||||
app = session.query(App).filter_by(id=app_id).first()
|
||||
workflow = session.query(Workflow).filter_by(app_id=app_id, version=Workflow.VERSION_DRAFT).first()
|
||||
|
||||
if not app:
|
||||
raise ValueError("App not found")
|
||||
if not workflow:
|
||||
raise ValueError("Workflow not found")
|
||||
|
||||
try:
|
||||
i = 0
|
||||
for node_id, _ in workflow.walk_nodes():
|
||||
if node_id == current_node_id:
|
||||
break
|
||||
i += 1
|
||||
|
||||
dsl = yaml.safe_load(AppDslService.export_dsl(app_model=app))
|
||||
|
||||
response = httpx.post(
|
||||
"http://spark-832c:8000/predict",
|
||||
json={"graph_data": dsl, "source_node_index": i},
|
||||
)
|
||||
return {
|
||||
"nodes": json.loads(response.json()),
|
||||
}
|
||||
except PluginPermissionDeniedError as e:
|
||||
raise ValueError(e.description) from e
|
||||
80
api/controllers/files/storage_files.py
Normal file
80
api/controllers/files/storage_files.py
Normal file
@@ -0,0 +1,80 @@
|
||||
"""Token-based file proxy controller for storage operations.
|
||||
|
||||
This controller handles file download and upload operations using opaque UUID tokens.
|
||||
The token maps to the real storage key in Redis, so the actual storage path is never
|
||||
exposed in the URL.
|
||||
|
||||
Routes:
|
||||
GET /files/storage-files/{token} - Download a file
|
||||
PUT /files/storage-files/{token} - Upload a file
|
||||
|
||||
The operation type (download/upload) is determined by the ticket stored in Redis,
|
||||
not by the HTTP method. This ensures a download ticket cannot be used for upload
|
||||
and vice versa.
|
||||
"""
|
||||
|
||||
from urllib.parse import quote
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource
|
||||
from werkzeug.exceptions import Forbidden, NotFound, RequestEntityTooLarge
|
||||
|
||||
from controllers.files import files_ns
|
||||
from extensions.ext_storage import storage
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
|
||||
@files_ns.route("/storage-files/<string:token>")
|
||||
class StorageFilesApi(Resource):
|
||||
"""Handle file operations through token-based URLs."""
|
||||
|
||||
def get(self, token: str):
|
||||
"""Download a file using a token.
|
||||
|
||||
The ticket must have op="download", otherwise returns 403.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "download":
|
||||
raise Forbidden("This token is not valid for download")
|
||||
|
||||
try:
|
||||
generator = storage.load_stream(ticket.storage_key)
|
||||
except FileNotFoundError:
|
||||
raise NotFound("File not found")
|
||||
|
||||
filename = ticket.filename or ticket.storage_key.rsplit("/", 1)[-1]
|
||||
encoded_filename = quote(filename)
|
||||
|
||||
return Response(
|
||||
generator,
|
||||
mimetype="application/octet-stream",
|
||||
direct_passthrough=True,
|
||||
headers={
|
||||
"Content-Disposition": f"attachment; filename*=UTF-8''{encoded_filename}",
|
||||
},
|
||||
)
|
||||
|
||||
def put(self, token: str):
|
||||
"""Upload a file using a token.
|
||||
|
||||
The ticket must have op="upload", otherwise returns 403.
|
||||
If the request body exceeds max_bytes, returns 413.
|
||||
"""
|
||||
ticket = StorageTicketService.get_ticket(token)
|
||||
if ticket is None:
|
||||
raise Forbidden("Invalid or expired token")
|
||||
|
||||
if ticket.op != "upload":
|
||||
raise Forbidden("This token is not valid for upload")
|
||||
|
||||
content = request.get_data()
|
||||
|
||||
if ticket.max_bytes is not None and len(content) > ticket.max_bytes:
|
||||
raise RequestEntityTooLarge(f"Upload exceeds maximum size of {ticket.max_bytes} bytes")
|
||||
|
||||
storage.save(ticket.storage_key, content)
|
||||
|
||||
return Response(status=204)
|
||||
@@ -194,7 +194,7 @@ class ChatApi(Resource):
|
||||
Supports conversation management and both blocking and streaming response modes.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
payload = ChatRequestPayload.model_validate(service_api_ns.payload or {})
|
||||
@@ -258,7 +258,7 @@ class ChatStopApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""Stop a running chat message generation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
AppTaskService.stop_task(
|
||||
|
||||
@@ -98,7 +98,7 @@ class ConversationApi(Resource):
|
||||
Supports pagination using last_id and limit parameters.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = ConversationListQuery.model_validate(request.args.to_dict())
|
||||
@@ -142,7 +142,7 @@ class ConversationDetailApi(Resource):
|
||||
def delete(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Delete a specific conversation."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -171,7 +171,7 @@ class ConversationRenameApi(Resource):
|
||||
def post(self, app_model: App, end_user: EndUser, c_id):
|
||||
"""Rename a conversation or auto-generate a name."""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -213,7 +213,7 @@ class ConversationVariablesApi(Resource):
|
||||
"""
|
||||
# conversational variable only for chat app
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
@@ -252,7 +252,7 @@ class ConversationVariableDetailApi(Resource):
|
||||
The value must match the variable's expected type.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
conversation_id = str(c_id)
|
||||
|
||||
@@ -53,7 +53,7 @@ class MessageListApi(Resource):
|
||||
Retrieves messages with pagination support using first_id.
|
||||
"""
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
query_args = MessageListQuery.model_validate(request.args.to_dict())
|
||||
@@ -158,7 +158,7 @@ class MessageSuggestedApi(Resource):
|
||||
"""
|
||||
message_id = str(message_id)
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
try:
|
||||
|
||||
399
api/core/agent/agent_app_runner.py
Normal file
399
api/core/agent/agent_app_runner.py
Normal file
@@ -0,0 +1,399 @@
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from copy import deepcopy
|
||||
from typing import Any, cast
|
||||
|
||||
from core.agent.base_agent_runner import BaseAgentRunner
|
||||
from core.agent.entities import AgentEntity, AgentLog, AgentResult, ExecutionContext
|
||||
from core.agent.patterns.strategy_factory import StrategyFactory
|
||||
from core.app.apps.base_app_queue_manager import PublishFrom
|
||||
from core.app.entities.queue_entities import QueueAgentThoughtEvent, QueueMessageEndEvent, QueueMessageFileEvent
|
||||
from core.prompt.agent_history_prompt_transform import AgentHistoryPromptTransform
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageContentType,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import ImagePromptMessageContent, PromptMessageContentUnionTypes
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from models.model import Message
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentAppRunner(BaseAgentRunner):
|
||||
|
||||
@property
|
||||
def model_features(self) -> list[ModelFeature]:
|
||||
llm_model = cast(LargeLanguageModel, self.model_instance.model_type_instance)
|
||||
model_schema = llm_model.get_model_schema(self.model_instance.model_name, self.model_instance.credentials)
|
||||
if not model_schema:
|
||||
return []
|
||||
return list(model_schema.features or [])
|
||||
|
||||
def build_execution_context(self) -> ExecutionContext:
|
||||
return ExecutionContext(
|
||||
user_id=self.user_id,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
conversation_id=self.conversation.id if self.conversation else None,
|
||||
message_id=self.message.id if self.message else None,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
def _create_tool_invoke_hook(self, message: Message):
|
||||
"""
|
||||
Create a tool invoke hook that uses ToolEngine.agent_invoke.
|
||||
This hook handles file creation and returns proper meta information.
|
||||
"""
|
||||
# Get trace manager from app generate entity
|
||||
trace_manager = self.application_generate_entity.trace_manager
|
||||
|
||||
def tool_invoke_hook(
|
||||
tool: Tool, tool_args: dict[str, Any], tool_name: str
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Hook that uses agent_invoke for proper file and meta handling."""
|
||||
tool_invoke_response, message_files, tool_invoke_meta = ToolEngine.agent_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.user_id,
|
||||
tenant_id=self.tenant_id,
|
||||
message=message,
|
||||
invoke_from=self.application_generate_entity.invoke_from,
|
||||
agent_tool_callback=self.agent_callback,
|
||||
trace_manager=trace_manager,
|
||||
app_id=self.application_generate_entity.app_config.app_id,
|
||||
message_id=message.id,
|
||||
conversation_id=self.conversation.id,
|
||||
)
|
||||
|
||||
# Publish files and track IDs
|
||||
for message_file_id in message_files:
|
||||
self.queue_manager.publish(
|
||||
QueueMessageFileEvent(message_file_id=message_file_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
self._current_message_file_ids.append(message_file_id)
|
||||
|
||||
return tool_invoke_response, message_files, tool_invoke_meta
|
||||
|
||||
return tool_invoke_hook
|
||||
|
||||
def run(self, message: Message, query: str, **kwargs: Any) -> Generator[LLMResultChunk, None, None]:
|
||||
"""
|
||||
Run Agent application
|
||||
"""
|
||||
self.query = query
|
||||
app_generate_entity = self.application_generate_entity
|
||||
|
||||
app_config = self.app_config
|
||||
assert app_config is not None, "app_config is required"
|
||||
assert app_config.agent is not None, "app_config.agent is required"
|
||||
|
||||
# convert tools into ModelRuntime Tool format
|
||||
tool_instances, _ = self._init_prompt_tools()
|
||||
|
||||
assert app_config.agent
|
||||
|
||||
# Create tool invoke hook for agent_invoke
|
||||
tool_invoke_hook = self._create_tool_invoke_hook(message)
|
||||
|
||||
# Get instruction for ReAct strategy
|
||||
instruction = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
# Use factory to create appropriate strategy
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=self.model_features,
|
||||
model_instance=self.model_instance,
|
||||
tools=list(tool_instances.values()),
|
||||
files=list(self.files),
|
||||
max_iterations=app_config.agent.max_iteration,
|
||||
context=self.build_execution_context(),
|
||||
agent_strategy=self.config.strategy,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Initialize state variables
|
||||
current_agent_thought_id: str | None = None
|
||||
has_published_thought = False
|
||||
current_tool_name: str | None = None
|
||||
self._current_message_file_ids: list[str] = []
|
||||
|
||||
# organize prompt messages
|
||||
prompt_messages = self._organize_prompt_messages()
|
||||
|
||||
# Run strategy
|
||||
generator = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=app_generate_entity.model_conf.parameters,
|
||||
stop=app_generate_entity.model_conf.stop,
|
||||
stream=True,
|
||||
)
|
||||
|
||||
# Consume generator and collect result
|
||||
result: AgentResult | None = None
|
||||
try:
|
||||
while True:
|
||||
try:
|
||||
output = next(generator)
|
||||
except StopIteration as e:
|
||||
# Generator finished, get the return value
|
||||
result = e.value
|
||||
break
|
||||
|
||||
if isinstance(output, LLMResultChunk):
|
||||
# Handle LLM chunk
|
||||
if current_agent_thought_id and not has_published_thought:
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
has_published_thought = True
|
||||
|
||||
yield output
|
||||
|
||||
elif isinstance(output, AgentLog):
|
||||
# Handle Agent Log using log_type for type-safe dispatch
|
||||
if output.status == AgentLog.LogStatus.START:
|
||||
if output.log_type == AgentLog.LogType.ROUND:
|
||||
# Start of a new round
|
||||
message_file_ids: list[str] = []
|
||||
current_agent_thought_id = self.create_agent_thought(
|
||||
message_id=message.id,
|
||||
message="",
|
||||
tool_name="",
|
||||
tool_input="",
|
||||
messages_ids=message_file_ids,
|
||||
)
|
||||
has_published_thought = False
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call start - extract data from structured fields
|
||||
current_tool_name = output.data.get("tool_name", "")
|
||||
tool_input = output.data.get("tool_args", {})
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=current_tool_name,
|
||||
tool_input=tool_input,
|
||||
thought=None,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.status == AgentLog.LogStatus.SUCCESS:
|
||||
if output.log_type == AgentLog.LogType.THOUGHT:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
thought_text = output.data.get("thought")
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=thought_text,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=None,
|
||||
messages_ids=[],
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.TOOL_CALL:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Tool call finished
|
||||
tool_output = output.data.get("output")
|
||||
# Get meta from strategy output (now properly populated)
|
||||
tool_meta = output.data.get("meta")
|
||||
|
||||
# Wrap tool_meta with tool_name as key (required by agent_service)
|
||||
if tool_meta and current_tool_name:
|
||||
tool_meta = {current_tool_name: tool_meta}
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=None,
|
||||
observation=tool_output,
|
||||
tool_invoke_meta=tool_meta,
|
||||
answer=None,
|
||||
messages_ids=self._current_message_file_ids,
|
||||
)
|
||||
# Clear message file ids after saving
|
||||
self._current_message_file_ids = []
|
||||
current_tool_name = None
|
||||
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
elif output.log_type == AgentLog.LogType.ROUND:
|
||||
if current_agent_thought_id is None:
|
||||
continue
|
||||
|
||||
# Round finished - save LLM usage and answer
|
||||
llm_usage = output.metadata.get(AgentLog.LogMetadata.LLM_USAGE)
|
||||
llm_result = output.data.get("llm_result")
|
||||
final_answer = output.data.get("final_answer")
|
||||
|
||||
self.save_agent_thought(
|
||||
agent_thought_id=current_agent_thought_id,
|
||||
tool_name=None,
|
||||
tool_input=None,
|
||||
thought=llm_result,
|
||||
observation=None,
|
||||
tool_invoke_meta=None,
|
||||
answer=final_answer,
|
||||
messages_ids=[],
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
self.queue_manager.publish(
|
||||
QueueAgentThoughtEvent(agent_thought_id=current_agent_thought_id),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
except Exception:
|
||||
# Re-raise any other exceptions
|
||||
raise
|
||||
|
||||
# Process final result
|
||||
if isinstance(result, AgentResult):
|
||||
final_answer = result.text
|
||||
usage = result.usage or LLMUsage.empty_usage()
|
||||
|
||||
# Publish end event
|
||||
self.queue_manager.publish(
|
||||
QueueMessageEndEvent(
|
||||
llm_result=LLMResult(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
message=AssistantPromptMessage(content=final_answer),
|
||||
usage=usage,
|
||||
system_fingerprint="",
|
||||
)
|
||||
),
|
||||
PublishFrom.APPLICATION_MANAGER,
|
||||
)
|
||||
|
||||
def _init_system_message(self, prompt_template: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Initialize system message
|
||||
"""
|
||||
if not prompt_template:
|
||||
return prompt_messages or []
|
||||
|
||||
prompt_messages = prompt_messages or []
|
||||
|
||||
if prompt_messages and isinstance(prompt_messages[0], SystemPromptMessage):
|
||||
prompt_messages[0] = SystemPromptMessage(content=prompt_template)
|
||||
return prompt_messages
|
||||
|
||||
if not prompt_messages:
|
||||
return [SystemPromptMessage(content=prompt_template)]
|
||||
|
||||
prompt_messages.insert(0, SystemPromptMessage(content=prompt_template))
|
||||
return prompt_messages
|
||||
|
||||
def _organize_user_query(self, query: str, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
Organize user query
|
||||
"""
|
||||
if self.files:
|
||||
# get image detail config
|
||||
image_detail_config = (
|
||||
self.application_generate_entity.file_upload_config.image_config.detail
|
||||
if (
|
||||
self.application_generate_entity.file_upload_config
|
||||
and self.application_generate_entity.file_upload_config.image_config
|
||||
)
|
||||
else None
|
||||
)
|
||||
image_detail_config = image_detail_config or ImagePromptMessageContent.DETAIL.LOW
|
||||
|
||||
prompt_message_contents: list[PromptMessageContentUnionTypes] = []
|
||||
for file in self.files:
|
||||
prompt_message_contents.append(
|
||||
file_manager.to_prompt_message_content(
|
||||
file,
|
||||
image_detail_config=image_detail_config,
|
||||
)
|
||||
)
|
||||
prompt_message_contents.append(TextPromptMessageContent(data=query))
|
||||
|
||||
prompt_messages.append(UserPromptMessage(content=prompt_message_contents))
|
||||
else:
|
||||
prompt_messages.append(UserPromptMessage(content=query))
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _clear_user_prompt_image_messages(self, prompt_messages: list[PromptMessage]) -> list[PromptMessage]:
|
||||
"""
|
||||
As for now, gpt supports both fc and vision at the first iteration.
|
||||
We need to remove the image messages from the prompt messages at the first iteration.
|
||||
"""
|
||||
prompt_messages = deepcopy(prompt_messages)
|
||||
|
||||
for prompt_message in prompt_messages:
|
||||
if isinstance(prompt_message, UserPromptMessage):
|
||||
if isinstance(prompt_message.content, list):
|
||||
prompt_message.content = "\n".join(
|
||||
[
|
||||
content.data
|
||||
if content.type == PromptMessageContentType.TEXT
|
||||
else "[image]"
|
||||
if content.type == PromptMessageContentType.IMAGE
|
||||
else "[file]"
|
||||
for content in prompt_message.content
|
||||
]
|
||||
)
|
||||
|
||||
return prompt_messages
|
||||
|
||||
def _organize_prompt_messages(self):
|
||||
# For ReAct strategy, use the agent prompt template
|
||||
if self.config.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT and self.config.prompt:
|
||||
prompt_template = self.config.prompt.first_prompt
|
||||
else:
|
||||
prompt_template = self.app_config.prompt_template.simple_prompt_template or ""
|
||||
|
||||
self.history_prompt_messages = self._init_system_message(prompt_template, self.history_prompt_messages)
|
||||
query_prompt_messages = self._organize_user_query(self.query or "", [])
|
||||
|
||||
self.history_prompt_messages = AgentHistoryPromptTransform(
|
||||
model_config=self.model_config,
|
||||
prompt_messages=[*query_prompt_messages, *self._current_thoughts],
|
||||
history_messages=self.history_prompt_messages,
|
||||
memory=self.memory,
|
||||
).get_prompt()
|
||||
|
||||
prompt_messages = [*self.history_prompt_messages, *query_prompt_messages, *self._current_thoughts]
|
||||
if len(self._current_thoughts) != 0:
|
||||
# clear messages after the first iteration
|
||||
prompt_messages = self._clear_user_prompt_image_messages(prompt_messages)
|
||||
return prompt_messages
|
||||
@@ -1,3 +1,5 @@
|
||||
import uuid
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum
|
||||
from typing import Any, Union
|
||||
|
||||
@@ -92,3 +94,79 @@ class AgentInvokeMessage(ToolInvokeMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class ExecutionContext(BaseModel):
|
||||
"""Execution context containing trace and audit information.
|
||||
|
||||
Carries IDs and metadata needed for tracing, auditing, and correlation
|
||||
but not part of the core business logic.
|
||||
"""
|
||||
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
conversation_id: str | None = None
|
||||
message_id: str | None = None
|
||||
tenant_id: str | None = None
|
||||
|
||||
@classmethod
|
||||
def create_minimal(cls, user_id: str | None = None) -> "ExecutionContext":
|
||||
return cls(user_id=user_id)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
return {
|
||||
"user_id": self.user_id,
|
||||
"app_id": self.app_id,
|
||||
"conversation_id": self.conversation_id,
|
||||
"message_id": self.message_id,
|
||||
"tenant_id": self.tenant_id,
|
||||
}
|
||||
|
||||
def with_updates(self, **kwargs) -> "ExecutionContext":
|
||||
data = self.to_dict()
|
||||
data.update(kwargs)
|
||||
return ExecutionContext(**{k: v for k, v in data.items() if k in ExecutionContext.model_fields})
|
||||
|
||||
|
||||
class AgentLog(BaseModel):
|
||||
"""Structured log entry for agent execution tracing."""
|
||||
|
||||
class LogType(StrEnum):
|
||||
ROUND = "round"
|
||||
THOUGHT = "thought"
|
||||
TOOL_CALL = "tool_call"
|
||||
|
||||
class LogMetadata(StrEnum):
|
||||
STARTED_AT = "started_at"
|
||||
FINISHED_AT = "finished_at"
|
||||
ELAPSED_TIME = "elapsed_time"
|
||||
TOTAL_PRICE = "total_price"
|
||||
TOTAL_TOKENS = "total_tokens"
|
||||
PROVIDER = "provider"
|
||||
CURRENCY = "currency"
|
||||
LLM_USAGE = "llm_usage"
|
||||
ICON = "icon"
|
||||
ICON_DARK = "icon_dark"
|
||||
|
||||
class LogStatus(StrEnum):
|
||||
START = "start"
|
||||
ERROR = "error"
|
||||
SUCCESS = "success"
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid.uuid4()))
|
||||
label: str = Field(...)
|
||||
log_type: LogType = Field(...)
|
||||
parent_id: str | None = Field(default=None)
|
||||
error: str | None = Field(default=None)
|
||||
status: LogStatus = Field(...)
|
||||
data: Mapping[str, Any] = Field(...)
|
||||
metadata: Mapping[LogMetadata, Any] = Field(default={})
|
||||
|
||||
|
||||
class AgentResult(BaseModel):
|
||||
"""Agent execution result."""
|
||||
|
||||
text: str = Field(default="")
|
||||
files: list[Any] = Field(default_factory=list)
|
||||
usage: Any | None = Field(default=None)
|
||||
finish_reason: str | None = Field(default=None)
|
||||
|
||||
19
api/core/agent/patterns/__init__.py
Normal file
19
api/core/agent/patterns/__init__.py
Normal file
@@ -0,0 +1,19 @@
|
||||
"""Agent patterns module.
|
||||
|
||||
This module provides different strategies for agent execution:
|
||||
- FunctionCallStrategy: Uses native function/tool calling
|
||||
- ReActStrategy: Uses ReAct (Reasoning + Acting) approach
|
||||
- StrategyFactory: Factory for creating strategies based on model features
|
||||
"""
|
||||
|
||||
from .base import AgentPattern
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
from .strategy_factory import StrategyFactory
|
||||
|
||||
__all__ = [
|
||||
"AgentPattern",
|
||||
"FunctionCallStrategy",
|
||||
"ReActStrategy",
|
||||
"StrategyFactory",
|
||||
]
|
||||
506
api/core/agent/patterns/base.py
Normal file
506
api/core/agent/patterns/base.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""Base class for agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Callable, Generator
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage, ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
# Type alias for tool invoke hook
|
||||
# Returns: (response_content, message_file_ids, tool_invoke_meta)
|
||||
ToolInvokeHook = Callable[["Tool", dict[str, Any], str], tuple[str, list[str], ToolInvokeMeta]]
|
||||
|
||||
|
||||
class AgentPattern(ABC):
|
||||
"""Base class for agent execution strategies."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
):
|
||||
"""Initialize the agent strategy."""
|
||||
self.model_instance = model_instance
|
||||
self.tools = tools
|
||||
self.context = context
|
||||
self.max_iterations = min(max_iterations, 99) # Cap at 99 iterations
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self.files: list[File] = files
|
||||
self.tool_invoke_hook = tool_invoke_hook
|
||||
|
||||
@abstractmethod
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the agent strategy."""
|
||||
pass
|
||||
|
||||
def _accumulate_usage(self, total_usage: dict[str, Any], delta_usage: LLMUsage) -> None:
|
||||
"""Accumulate LLM usage statistics."""
|
||||
if not total_usage.get("usage"):
|
||||
# Create a copy to avoid modifying the original
|
||||
total_usage["usage"] = LLMUsage(
|
||||
prompt_tokens=delta_usage.prompt_tokens,
|
||||
prompt_unit_price=delta_usage.prompt_unit_price,
|
||||
prompt_price_unit=delta_usage.prompt_price_unit,
|
||||
prompt_price=delta_usage.prompt_price,
|
||||
completion_tokens=delta_usage.completion_tokens,
|
||||
completion_unit_price=delta_usage.completion_unit_price,
|
||||
completion_price_unit=delta_usage.completion_price_unit,
|
||||
completion_price=delta_usage.completion_price,
|
||||
total_tokens=delta_usage.total_tokens,
|
||||
total_price=delta_usage.total_price,
|
||||
currency=delta_usage.currency,
|
||||
latency=delta_usage.latency,
|
||||
)
|
||||
else:
|
||||
current: LLMUsage = total_usage["usage"]
|
||||
current.prompt_tokens += delta_usage.prompt_tokens
|
||||
current.completion_tokens += delta_usage.completion_tokens
|
||||
current.total_tokens += delta_usage.total_tokens
|
||||
current.prompt_price += delta_usage.prompt_price
|
||||
current.completion_price += delta_usage.completion_price
|
||||
current.total_price += delta_usage.total_price
|
||||
|
||||
def _extract_content(self, content: Any) -> str:
|
||||
"""Extract text content from message content."""
|
||||
if isinstance(content, list):
|
||||
# Content items are PromptMessageContentUnionTypes
|
||||
text_parts = []
|
||||
for c in content:
|
||||
# Check if it's a TextPromptMessageContent (which has data attribute)
|
||||
if isinstance(c, TextPromptMessageContent):
|
||||
text_parts.append(c.data)
|
||||
return "".join(text_parts)
|
||||
return str(content)
|
||||
|
||||
def _has_tool_calls(self, chunk: LLMResultChunk) -> bool:
|
||||
"""Check if chunk contains tool calls."""
|
||||
# LLMResultChunk always has delta attribute
|
||||
return bool(chunk.delta.message and chunk.delta.message.tool_calls)
|
||||
|
||||
def _has_tool_calls_result(self, result: LLMResult) -> bool:
|
||||
"""Check if result contains tool calls (non-streaming)."""
|
||||
# LLMResult always has message attribute
|
||||
return bool(result.message and result.message.tool_calls)
|
||||
|
||||
def _extract_tool_calls(self, chunk: LLMResultChunk) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from streaming chunk."""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
if chunk.delta.message and chunk.delta.message.tool_calls:
|
||||
for tool_call in chunk.delta.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_tool_calls_result(self, result: LLMResult) -> list[tuple[str, str, dict[str, Any]]]:
|
||||
"""Extract tool calls from non-streaming result."""
|
||||
tool_calls = []
|
||||
if result.message and result.message.tool_calls:
|
||||
for tool_call in result.message.tool_calls:
|
||||
if tool_call.function:
|
||||
try:
|
||||
args = json.loads(tool_call.function.arguments) if tool_call.function.arguments else {}
|
||||
except json.JSONDecodeError:
|
||||
args = {}
|
||||
tool_calls.append((tool_call.id or "", tool_call.function.name, args))
|
||||
return tool_calls
|
||||
|
||||
def _extract_text_from_message(self, message: PromptMessage) -> str:
|
||||
"""Extract text content from a prompt message."""
|
||||
# PromptMessage always has content attribute
|
||||
content = message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
elif isinstance(content, list):
|
||||
# Extract text from content list
|
||||
text_parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
text_parts.append(item.data)
|
||||
return " ".join(text_parts)
|
||||
return ""
|
||||
|
||||
def _get_tool_metadata(self, tool_instance: Tool) -> dict[AgentLog.LogMetadata, Any]:
|
||||
"""Get metadata for a tool including provider and icon info."""
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {}
|
||||
if tool_instance.entity and tool_instance.entity.identity:
|
||||
identity = tool_instance.entity.identity
|
||||
if identity.provider:
|
||||
metadata[AgentLog.LogMetadata.PROVIDER] = identity.provider
|
||||
|
||||
# Get icon using ToolManager for proper URL generation
|
||||
tenant_id = self.context.tenant_id
|
||||
if tenant_id and identity.provider:
|
||||
try:
|
||||
provider_type = tool_instance.tool_provider_type()
|
||||
icon = ToolManager.get_tool_icon(tenant_id, provider_type, identity.provider)
|
||||
if isinstance(icon, str):
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
elif isinstance(icon, dict):
|
||||
# Handle icon dict with background/content or light/dark variants
|
||||
metadata[AgentLog.LogMetadata.ICON] = icon
|
||||
except Exception:
|
||||
# Fallback to identity.icon if ToolManager fails
|
||||
if identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
elif identity.icon:
|
||||
metadata[AgentLog.LogMetadata.ICON] = identity.icon
|
||||
return metadata
|
||||
|
||||
def _create_log(
|
||||
self,
|
||||
label: str,
|
||||
log_type: AgentLog.LogType,
|
||||
status: AgentLog.LogStatus,
|
||||
data: dict[str, Any] | None = None,
|
||||
parent_id: str | None = None,
|
||||
extra_metadata: dict[AgentLog.LogMetadata, Any] | None = None,
|
||||
) -> AgentLog:
|
||||
"""Create a new AgentLog with standard metadata."""
|
||||
metadata: dict[AgentLog.LogMetadata, Any] = {
|
||||
AgentLog.LogMetadata.STARTED_AT: time.perf_counter(),
|
||||
}
|
||||
if extra_metadata:
|
||||
metadata.update(extra_metadata)
|
||||
|
||||
return AgentLog(
|
||||
label=label,
|
||||
log_type=log_type,
|
||||
status=status,
|
||||
data=data or {},
|
||||
parent_id=parent_id,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
def _finish_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
data: dict[str, Any] | None = None,
|
||||
usage: LLMUsage | None = None,
|
||||
) -> AgentLog:
|
||||
"""Finish an AgentLog by updating its status and metadata."""
|
||||
log.status = AgentLog.LogStatus.SUCCESS
|
||||
|
||||
if data is not None:
|
||||
log.data = data
|
||||
|
||||
# Calculate elapsed time
|
||||
started_at = log.metadata.get(AgentLog.LogMetadata.STARTED_AT, time.perf_counter())
|
||||
finished_at = time.perf_counter()
|
||||
|
||||
# Update metadata
|
||||
log.metadata = {
|
||||
**log.metadata,
|
||||
AgentLog.LogMetadata.FINISHED_AT: finished_at,
|
||||
# Calculate elapsed time in seconds
|
||||
AgentLog.LogMetadata.ELAPSED_TIME: round(finished_at - started_at, 4),
|
||||
}
|
||||
|
||||
# Add usage information if provided
|
||||
if usage:
|
||||
log.metadata.update(
|
||||
{
|
||||
AgentLog.LogMetadata.TOTAL_PRICE: usage.total_price,
|
||||
AgentLog.LogMetadata.CURRENCY: usage.currency,
|
||||
AgentLog.LogMetadata.TOTAL_TOKENS: usage.total_tokens,
|
||||
AgentLog.LogMetadata.LLM_USAGE: usage,
|
||||
}
|
||||
)
|
||||
|
||||
return log
|
||||
|
||||
def _replace_file_references(self, tool_args: dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Replace file references in tool arguments with actual File objects.
|
||||
|
||||
Args:
|
||||
tool_args: Dictionary of tool arguments
|
||||
|
||||
Returns:
|
||||
Updated tool arguments with file references replaced
|
||||
"""
|
||||
# Process each argument in the dictionary
|
||||
processed_args: dict[str, Any] = {}
|
||||
for key, value in tool_args.items():
|
||||
processed_args[key] = self._process_file_reference(value)
|
||||
return processed_args
|
||||
|
||||
def _process_file_reference(self, data: Any) -> Any:
|
||||
"""
|
||||
Recursively process data to replace file references.
|
||||
Supports both single file [File: file_id] and multiple files [Files: file_id1, file_id2, ...].
|
||||
|
||||
Args:
|
||||
data: The data to process (can be dict, list, str, or other types)
|
||||
|
||||
Returns:
|
||||
Processed data with file references replaced
|
||||
"""
|
||||
single_file_pattern = re.compile(r"^\[File:\s*([^\]]+)\]$")
|
||||
multiple_files_pattern = re.compile(r"^\[Files:\s*([^\]]+)\]$")
|
||||
|
||||
if isinstance(data, dict):
|
||||
# Process dictionary recursively
|
||||
return {key: self._process_file_reference(value) for key, value in data.items()}
|
||||
elif isinstance(data, list):
|
||||
# Process list recursively
|
||||
return [self._process_file_reference(item) for item in data]
|
||||
elif isinstance(data, str):
|
||||
# Check for single file pattern [File: file_id]
|
||||
single_match = single_file_pattern.match(data.strip())
|
||||
if single_match:
|
||||
file_id = single_match.group(1).strip()
|
||||
# Find the file in self.files
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
return file
|
||||
# If file not found, return original value
|
||||
return data
|
||||
|
||||
# Check for multiple files pattern [Files: file_id1, file_id2, ...]
|
||||
multiple_match = multiple_files_pattern.match(data.strip())
|
||||
if multiple_match:
|
||||
file_ids_str = multiple_match.group(1).strip()
|
||||
# Split by comma and strip whitespace
|
||||
file_ids = [fid.strip() for fid in file_ids_str.split(",")]
|
||||
|
||||
# Find all matching files
|
||||
matched_files: list[File] = []
|
||||
for file_id in file_ids:
|
||||
for file in self.files:
|
||||
if file.id and str(file.id) == file_id:
|
||||
matched_files.append(file)
|
||||
break
|
||||
|
||||
# Return list of files if any were found, otherwise return original
|
||||
return matched_files or data
|
||||
|
||||
return data
|
||||
else:
|
||||
# Return other types as-is
|
||||
return data
|
||||
|
||||
def _create_text_chunk(self, text: str, prompt_messages: list[PromptMessage]) -> LLMResultChunk:
|
||||
"""Create a text chunk for streaming."""
|
||||
return LLMResultChunk(
|
||||
model=self.model_instance.model_name,
|
||||
prompt_messages=prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=AssistantPromptMessage(content=text),
|
||||
usage=None,
|
||||
),
|
||||
system_fingerprint="",
|
||||
)
|
||||
|
||||
def _invoke_tool(
|
||||
self,
|
||||
tool_instance: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[File], ToolInvokeMeta | None]:
|
||||
"""
|
||||
Invoke a tool and collect its response.
|
||||
|
||||
Args:
|
||||
tool_instance: The tool instance to invoke
|
||||
tool_args: Tool arguments
|
||||
tool_name: Name of the tool
|
||||
|
||||
Returns:
|
||||
Tuple of (response_content, tool_files, tool_invoke_meta)
|
||||
"""
|
||||
# Process tool_args to replace file references with actual File objects
|
||||
tool_args = self._replace_file_references(tool_args)
|
||||
|
||||
# If a tool invoke hook is set, use it instead of generic_invoke
|
||||
if self.tool_invoke_hook:
|
||||
response_content, _, tool_invoke_meta = self.tool_invoke_hook(tool_instance, tool_args, tool_name)
|
||||
# Note: message_file_ids are stored in DB, we don't convert them to File objects here
|
||||
# The caller (AgentAppRunner) handles file publishing
|
||||
return response_content, [], tool_invoke_meta
|
||||
|
||||
# Default: use generic_invoke for workflow scenarios
|
||||
# Import here to avoid circular import
|
||||
from core.tools.tool_engine import DifyWorkflowCallbackHandler, ToolEngine
|
||||
|
||||
tool_response = ToolEngine.generic_invoke(
|
||||
tool=tool_instance,
|
||||
tool_parameters=tool_args,
|
||||
user_id=self.context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
app_id=self.context.app_id,
|
||||
conversation_id=self.context.conversation_id,
|
||||
message_id=self.context.message_id,
|
||||
)
|
||||
|
||||
# Collect response and files
|
||||
response_content = ""
|
||||
tool_files: list[File] = []
|
||||
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
# Handle link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE:
|
||||
# Handle image URL messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.IMAGE_LINK:
|
||||
# Handle image link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Image: {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BINARY_LINK:
|
||||
# Handle binary file link messages
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
filename = response.meta.get("filename", "file") if response.meta else "file"
|
||||
response_content += f"[File: {filename} - {response.message.text}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
# Handle JSON messages
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False, indent=2)
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# Handle blob messages - convert to text representation
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobMessage):
|
||||
mime_type = (
|
||||
response.meta.get("mime_type", "application/octet-stream")
|
||||
if response.meta
|
||||
else "application/octet-stream"
|
||||
)
|
||||
size = len(response.message.blob)
|
||||
response_content += f"[Binary data: {mime_type}, size: {size} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
# Handle variable messages
|
||||
if isinstance(response.message, ToolInvokeMessage.VariableMessage):
|
||||
var_name = response.message.variable_name
|
||||
var_value = response.message.variable_value
|
||||
if isinstance(var_value, str):
|
||||
response_content += var_value
|
||||
else:
|
||||
response_content += f"[Variable {var_name}: {json.dumps(var_value, ensure_ascii=False)}]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.BLOB_CHUNK:
|
||||
# Handle blob chunk messages - these are parts of a larger blob
|
||||
if isinstance(response.message, ToolInvokeMessage.BlobChunkMessage):
|
||||
response_content += f"[Blob chunk {response.message.sequence}: {len(response.message.blob)} bytes]"
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.RETRIEVER_RESOURCES:
|
||||
# Handle retriever resources messages
|
||||
if isinstance(response.message, ToolInvokeMessage.RetrieverResourceMessage):
|
||||
response_content += response.message.context
|
||||
|
||||
elif response.type == ToolInvokeMessage.MessageType.FILE:
|
||||
# Extract file from meta
|
||||
if response.meta and "file" in response.meta:
|
||||
file = response.meta["file"]
|
||||
if isinstance(file, File):
|
||||
# Check if file is for model or tool output
|
||||
if response.meta.get("target") == "self":
|
||||
# File is for model - add to files for next prompt
|
||||
self.files.append(file)
|
||||
response_content += f"File '{file.filename}' has been loaded into your context."
|
||||
else:
|
||||
# File is tool output
|
||||
tool_files.append(file)
|
||||
|
||||
return response_content, tool_files, None
|
||||
|
||||
def _validate_tool_args(self, tool_instance: Tool, tool_args: dict[str, Any]) -> str | None:
|
||||
"""Validate tool arguments against the tool's required parameters.
|
||||
|
||||
Checks that all required LLM-facing parameters are present and non-empty
|
||||
before actual execution, preventing wasted tool invocations when the model
|
||||
generates calls with missing arguments (e.g. empty ``{}``).
|
||||
|
||||
Returns:
|
||||
Error message if validation fails, None if all required parameters are satisfied.
|
||||
"""
|
||||
prompt_tool = tool_instance.to_prompt_message_tool()
|
||||
required_params: list[str] = prompt_tool.parameters.get("required", [])
|
||||
|
||||
if not required_params:
|
||||
return None
|
||||
|
||||
missing = [
|
||||
p
|
||||
for p in required_params
|
||||
if p not in tool_args
|
||||
or tool_args[p] is None
|
||||
or (isinstance(tool_args[p], str) and not tool_args[p].strip())
|
||||
]
|
||||
|
||||
if not missing:
|
||||
return None
|
||||
|
||||
return (
|
||||
f"Missing required parameter(s): {', '.join(missing)}. "
|
||||
f"Please provide all required parameters before calling this tool."
|
||||
)
|
||||
|
||||
def _find_tool_by_name(self, tool_name: str) -> Tool | None:
|
||||
"""Find a tool instance by its name."""
|
||||
for tool in self.tools:
|
||||
if tool.entity.identity.name == tool_name:
|
||||
return tool
|
||||
return None
|
||||
|
||||
def _convert_tools_to_prompt_format(self) -> list[PromptMessageTool]:
|
||||
"""Convert tools to prompt message format."""
|
||||
prompt_tools: list[PromptMessageTool] = []
|
||||
for tool in self.tools:
|
||||
prompt_tools.append(tool.to_prompt_message_tool())
|
||||
return prompt_tools
|
||||
|
||||
def _update_usage_with_empty(self, llm_usage: dict[str, Any]) -> None:
|
||||
"""Initialize usage tracking with empty usage if not set."""
|
||||
if "usage" not in llm_usage or llm_usage["usage"] is None:
|
||||
llm_usage["usage"] = LLMUsage.empty_usage()
|
||||
358
api/core/agent/patterns/function_call.py
Normal file
358
api/core/agent/patterns/function_call.py
Normal file
@@ -0,0 +1,358 @@
|
||||
"""Function Call strategy implementation.
|
||||
|
||||
Implements the Function Call agent pattern where the LLM uses native tool-calling
|
||||
capability to invoke tools. Includes pre-execution parameter validation that
|
||||
intercepts invalid calls (e.g. empty arguments) before they reach tool backends,
|
||||
and avoids counting purely-invalid rounds against the iteration budget.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from typing import Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
LLMUsage,
|
||||
PromptMessage,
|
||||
PromptMessageTool,
|
||||
ToolPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class FunctionCallStrategy(AgentPattern):
|
||||
"""Function Call strategy using model's native tool calling capability."""
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the function call agent strategy."""
|
||||
# Convert tools to prompt format
|
||||
prompt_tools: list[PromptMessageTool] = self._convert_tools_to_prompt_format()
|
||||
|
||||
# Initialize tracking
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
function_call_state: bool = True
|
||||
total_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
messages: list[PromptMessage] = list(prompt_messages) # Create mutable copy
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
# Consecutive rounds where ALL tool calls failed parameter validation.
|
||||
# When this happens the round is "free" (iteration_step not incremented)
|
||||
# up to a safety cap to prevent infinite loops.
|
||||
consecutive_validation_failures: int = 0
|
||||
max_validation_retries: int = 3
|
||||
|
||||
while function_call_state and iteration_step <= max_iterations:
|
||||
function_call_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
# On last iteration, remove tools to force final answer
|
||||
current_tools: list[PromptMessageTool] = [] if iteration_step == max_iterations else prompt_tools
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, LLMUsage | None] = {"usage": None}
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages,
|
||||
model_parameters=model_parameters,
|
||||
tools=current_tools,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
tool_calls, response_content, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log
|
||||
)
|
||||
messages.append(self._create_assistant_message(response_content, tool_calls))
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update final text if no tool calls (this is likely the final answer)
|
||||
if not tool_calls:
|
||||
final_text = response_content
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Process tool calls
|
||||
tool_outputs: dict[str, str] = {}
|
||||
all_validation_errors: bool = True
|
||||
if tool_calls:
|
||||
function_call_state = True
|
||||
# Execute tools (with pre-execution parameter validation)
|
||||
for tool_call_id, tool_name, tool_args in tool_calls:
|
||||
tool_response, tool_files, _, is_validation_error = yield from self._handle_tool_call(
|
||||
tool_name, tool_args, tool_call_id, messages, round_log
|
||||
)
|
||||
tool_outputs[tool_name] = tool_response
|
||||
output_files.extend(tool_files)
|
||||
if not is_validation_error:
|
||||
all_validation_errors = False
|
||||
else:
|
||||
all_validation_errors = False
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"llm_result": response_content,
|
||||
"tool_calls": [
|
||||
{"name": tc[1], "args": tc[2], "output": tool_outputs.get(tc[1], "")} for tc in tool_calls
|
||||
]
|
||||
if tool_calls
|
||||
else [],
|
||||
"final_answer": final_text if not function_call_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
|
||||
# Skip iteration counter when every tool call in this round failed validation,
|
||||
# giving the model a free retry — but cap retries to prevent infinite loops.
|
||||
if tool_calls and all_validation_errors:
|
||||
consecutive_validation_failures += 1
|
||||
if consecutive_validation_failures >= max_validation_retries:
|
||||
logger.warning(
|
||||
"Agent hit %d consecutive validation-only rounds, forcing iteration increment",
|
||||
consecutive_validation_failures,
|
||||
)
|
||||
iteration_step += 1
|
||||
consecutive_validation_failures = 0
|
||||
else:
|
||||
logger.info(
|
||||
"All tool calls failed validation (attempt %d/%d), not counting iteration",
|
||||
consecutive_validation_failures,
|
||||
max_validation_retries,
|
||||
)
|
||||
else:
|
||||
consecutive_validation_failures = 0
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text,
|
||||
files=output_files,
|
||||
usage=total_usage.get("usage") or LLMUsage.empty_usage(),
|
||||
finish_reason=finish_reason,
|
||||
)
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, LLMUsage | None],
|
||||
start_log: AgentLog,
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[list[tuple[str, str, dict[str, Any]]], str, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract tool calls and content.
|
||||
|
||||
Returns a tuple of (tool_calls, response_content, finish_reason).
|
||||
"""
|
||||
tool_calls: list[tuple[str, str, dict[str, Any]]] = []
|
||||
response_content: str = ""
|
||||
finish_reason: str | None = None
|
||||
if not isinstance(chunks, LLMResult):
|
||||
# Streaming response
|
||||
for chunk in chunks:
|
||||
# Extract tool calls
|
||||
if self._has_tool_calls(chunk):
|
||||
tool_calls.extend(self._extract_tool_calls(chunk))
|
||||
|
||||
# Extract content
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
response_content += self._extract_content(chunk.delta.message.content)
|
||||
|
||||
# Track usage
|
||||
if chunk.delta.usage:
|
||||
self._accumulate_usage(llm_usage, chunk.delta.usage)
|
||||
|
||||
# Capture finish reason
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
yield chunk
|
||||
else:
|
||||
# Non-streaming response
|
||||
result: LLMResult = chunks
|
||||
|
||||
if self._has_tool_calls_result(result):
|
||||
tool_calls.extend(self._extract_tool_calls_result(result))
|
||||
|
||||
if result.message and result.message.content:
|
||||
response_content += self._extract_content(result.message.content)
|
||||
|
||||
if result.usage:
|
||||
self._accumulate_usage(llm_usage, result.usage)
|
||||
|
||||
# Convert to streaming format
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(index=0, message=result.message, usage=result.usage),
|
||||
)
|
||||
yield self._finish_log(
|
||||
start_log,
|
||||
data={
|
||||
"result": response_content,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
return tool_calls, response_content, finish_reason
|
||||
|
||||
def _create_assistant_message(
|
||||
self, content: str, tool_calls: list[tuple[str, str, dict[str, Any]]] | None = None
|
||||
) -> AssistantPromptMessage:
|
||||
"""Create assistant message with tool calls."""
|
||||
if tool_calls is None:
|
||||
return AssistantPromptMessage(content=content)
|
||||
return AssistantPromptMessage(
|
||||
content=content or "",
|
||||
tool_calls=[
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tc[0],
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(name=tc[1], arguments=json.dumps(tc[2])),
|
||||
)
|
||||
for tc in tool_calls
|
||||
],
|
||||
)
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
tool_call_id: str,
|
||||
messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File], ToolInvokeMeta | None, bool]]:
|
||||
"""Handle a single tool call and return response with files, meta, and validation status.
|
||||
|
||||
Validates required parameters before execution. When validation fails the tool
|
||||
is never invoked — a synthetic error is fed back to the model so it can self-correct
|
||||
without consuming a real iteration.
|
||||
|
||||
Returns:
|
||||
(response_content, tool_files, tool_invoke_meta, is_validation_error).
|
||||
``is_validation_error`` is True when the call was rejected due to missing
|
||||
required parameters, allowing the caller to skip the iteration counter.
|
||||
"""
|
||||
# Find tool
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
if not tool_instance:
|
||||
raise ValueError(f"Tool {tool_name} not found")
|
||||
|
||||
# Get tool metadata (provider, icon, etc.)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance)
|
||||
|
||||
# Create tool call log
|
||||
tool_call_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_call_id": tool_call_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_call_log
|
||||
|
||||
# Validate required parameters before execution to avoid wasted invocations
|
||||
validation_error = self._validate_tool_args(tool_instance, tool_args)
|
||||
if validation_error:
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = validation_error
|
||||
tool_call_log.data = {**tool_call_log.data, "error": validation_error}
|
||||
yield tool_call_log
|
||||
|
||||
messages.append(ToolPromptMessage(content=validation_error, tool_call_id=tool_call_id, name=tool_name))
|
||||
return validation_error, [], None, True
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args, tool_name)
|
||||
|
||||
yield self._finish_log(
|
||||
tool_call_log,
|
||||
data={
|
||||
**tool_call_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
final_content = response_content or "Tool executed successfully"
|
||||
# Add tool response to messages
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=final_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return response_content, tool_files, tool_invoke_meta, False
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_call_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_call_log.error = error_message
|
||||
tool_call_log.data = {
|
||||
**tool_call_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_call_log
|
||||
|
||||
# Add error message to conversation
|
||||
error_content = f"Tool execution failed: {error_message}"
|
||||
messages.append(
|
||||
ToolPromptMessage(
|
||||
content=error_content,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
return error_content, [], None, False
|
||||
418
api/core/agent/patterns/react.py
Normal file
418
api/core/agent/patterns/react.py
Normal file
@@ -0,0 +1,418 @@
|
||||
"""ReAct strategy implementation."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, Any, Union
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult, AgentScratchpadUnit, ExecutionContext
|
||||
from core.agent.output_parser.cot_output_parser import CotAgentOutputParser
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file import File
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResult,
|
||||
LLMResultChunk,
|
||||
LLMResultChunkDelta,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
)
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class ReActStrategy(AgentPattern):
|
||||
"""ReAct strategy using reasoning and acting approach."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
tools: list[Tool],
|
||||
context: ExecutionContext,
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
files: list[File] = [],
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
):
|
||||
"""Initialize the ReAct strategy with instruction support."""
|
||||
super().__init__(
|
||||
model_instance=model_instance,
|
||||
tools=tools,
|
||||
context=context,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
files=files,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
self.instruction = instruction
|
||||
|
||||
def run(
|
||||
self,
|
||||
prompt_messages: list[PromptMessage],
|
||||
model_parameters: dict[str, Any],
|
||||
stop: list[str] = [],
|
||||
stream: bool = True,
|
||||
) -> Generator[LLMResultChunk | AgentLog, None, AgentResult]:
|
||||
"""Execute the ReAct agent strategy."""
|
||||
# Initialize tracking
|
||||
agent_scratchpad: list[AgentScratchpadUnit] = []
|
||||
iteration_step: int = 1
|
||||
max_iterations: int = self.max_iterations + 1
|
||||
react_state: bool = True
|
||||
total_usage: dict[str, Any] = {"usage": None}
|
||||
output_files: list[File] = [] # Track files produced by tools
|
||||
final_text: str = ""
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Add "Observation" to stop sequences
|
||||
if "Observation" not in stop:
|
||||
stop = stop.copy()
|
||||
stop.append("Observation")
|
||||
|
||||
while react_state and iteration_step <= max_iterations:
|
||||
react_state = False
|
||||
round_log = self._create_log(
|
||||
label=f"ROUND {iteration_step}",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
)
|
||||
yield round_log
|
||||
|
||||
# Build prompt with/without tools based on iteration
|
||||
include_tools = iteration_step < max_iterations
|
||||
current_messages = self._build_prompt_with_react_format(
|
||||
prompt_messages, agent_scratchpad, include_tools, self.instruction
|
||||
)
|
||||
|
||||
model_log = self._create_log(
|
||||
label=f"{self.model_instance.model_name} Thought",
|
||||
log_type=AgentLog.LogType.THOUGHT,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata={
|
||||
AgentLog.LogMetadata.PROVIDER: self.model_instance.provider,
|
||||
},
|
||||
)
|
||||
yield model_log
|
||||
|
||||
# Track usage for this round only
|
||||
round_usage: dict[str, Any] = {"usage": None}
|
||||
|
||||
# Use current messages directly (files are handled by base class if needed)
|
||||
messages_to_use = current_messages
|
||||
|
||||
# Invoke model
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult] = self.model_instance.invoke_llm(
|
||||
prompt_messages=messages_to_use,
|
||||
model_parameters=model_parameters,
|
||||
stop=stop,
|
||||
stream=stream,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
# Process response
|
||||
scratchpad, chunk_finish_reason = yield from self._handle_chunks(
|
||||
chunks, round_usage, model_log, current_messages
|
||||
)
|
||||
agent_scratchpad.append(scratchpad)
|
||||
|
||||
# Accumulate to total usage
|
||||
round_usage_value = round_usage.get("usage")
|
||||
if round_usage_value:
|
||||
self._accumulate_usage(total_usage, round_usage_value)
|
||||
|
||||
# Update finish reason
|
||||
if chunk_finish_reason:
|
||||
finish_reason = chunk_finish_reason
|
||||
|
||||
# Check if we have an action to execute
|
||||
if scratchpad.action and scratchpad.action.action_name.lower() != "final answer":
|
||||
react_state = True
|
||||
# Execute tool
|
||||
observation, tool_files = yield from self._handle_tool_call(
|
||||
scratchpad.action, current_messages, round_log
|
||||
)
|
||||
scratchpad.observation = observation
|
||||
# Track files produced by tools
|
||||
output_files.extend(tool_files)
|
||||
|
||||
# Add observation to scratchpad for display
|
||||
yield self._create_text_chunk(f"\nObservation: {observation}\n", current_messages)
|
||||
else:
|
||||
# Extract final answer
|
||||
if scratchpad.action and scratchpad.action.action_input:
|
||||
final_answer = scratchpad.action.action_input
|
||||
if isinstance(final_answer, dict):
|
||||
final_answer = json.dumps(final_answer, ensure_ascii=False)
|
||||
final_text = str(final_answer)
|
||||
elif scratchpad.thought:
|
||||
# If no action but we have thought, use thought as final answer
|
||||
final_text = scratchpad.thought
|
||||
|
||||
yield self._finish_log(
|
||||
round_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
"observation": scratchpad.observation or None,
|
||||
"final_answer": final_text if not react_state else None,
|
||||
},
|
||||
usage=round_usage.get("usage"),
|
||||
)
|
||||
iteration_step += 1
|
||||
|
||||
# Return final result
|
||||
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
return AgentResult(
|
||||
text=final_text, files=output_files, usage=total_usage.get("usage"), finish_reason=finish_reason
|
||||
)
|
||||
|
||||
def _build_prompt_with_react_format(
|
||||
self,
|
||||
original_messages: list[PromptMessage],
|
||||
agent_scratchpad: list[AgentScratchpadUnit],
|
||||
include_tools: bool = True,
|
||||
instruction: str = "",
|
||||
) -> list[PromptMessage]:
|
||||
"""Build prompt messages with ReAct format."""
|
||||
# Copy messages to avoid modifying original
|
||||
messages = list(original_messages)
|
||||
|
||||
# Find and update the system prompt that should already exist
|
||||
system_prompt_found = False
|
||||
for i, msg in enumerate(messages):
|
||||
if isinstance(msg, SystemPromptMessage):
|
||||
system_prompt_found = True
|
||||
# The system prompt from frontend already has the template, just replace placeholders
|
||||
|
||||
# Format tools
|
||||
tools_str = ""
|
||||
tool_names = []
|
||||
if include_tools and self.tools:
|
||||
# Convert tools to prompt message tools format
|
||||
prompt_tools = [tool.to_prompt_message_tool() for tool in self.tools]
|
||||
tool_names = [tool.name for tool in prompt_tools]
|
||||
|
||||
# Format tools as JSON for comprehensive information
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
|
||||
tools_str = json.dumps(jsonable_encoder(prompt_tools), indent=2)
|
||||
tool_names_str = ", ".join(f'"{name}"' for name in tool_names)
|
||||
else:
|
||||
tools_str = "No tools available"
|
||||
tool_names_str = ""
|
||||
|
||||
# Replace placeholders in the existing system prompt
|
||||
updated_content = msg.content
|
||||
assert isinstance(updated_content, str)
|
||||
updated_content = updated_content.replace("{{instruction}}", instruction)
|
||||
updated_content = updated_content.replace("{{tools}}", tools_str)
|
||||
updated_content = updated_content.replace("{{tool_names}}", tool_names_str)
|
||||
|
||||
# Create new SystemPromptMessage with updated content
|
||||
messages[i] = SystemPromptMessage(content=updated_content)
|
||||
break
|
||||
|
||||
# If no system prompt found, that's unexpected but add scratchpad anyway
|
||||
if not system_prompt_found:
|
||||
# This shouldn't happen if frontend is working correctly
|
||||
pass
|
||||
|
||||
# Format agent scratchpad
|
||||
scratchpad_str = ""
|
||||
if agent_scratchpad:
|
||||
scratchpad_parts: list[str] = []
|
||||
for unit in agent_scratchpad:
|
||||
if unit.thought:
|
||||
scratchpad_parts.append(f"Thought: {unit.thought}")
|
||||
if unit.action_str:
|
||||
scratchpad_parts.append(f"Action:\n```\n{unit.action_str}\n```")
|
||||
if unit.observation:
|
||||
scratchpad_parts.append(f"Observation: {unit.observation}")
|
||||
scratchpad_str = "\n".join(scratchpad_parts)
|
||||
|
||||
# If there's a scratchpad, append it to the last message
|
||||
if scratchpad_str:
|
||||
messages.append(AssistantPromptMessage(content=scratchpad_str))
|
||||
|
||||
return messages
|
||||
|
||||
def _handle_chunks(
|
||||
self,
|
||||
chunks: Union[Generator[LLMResultChunk, None, None], LLMResult],
|
||||
llm_usage: dict[str, Any],
|
||||
model_log: AgentLog,
|
||||
current_messages: list[PromptMessage],
|
||||
) -> Generator[
|
||||
LLMResultChunk | AgentLog,
|
||||
None,
|
||||
tuple[AgentScratchpadUnit, str | None],
|
||||
]:
|
||||
"""Handle LLM response chunks and extract action/thought.
|
||||
|
||||
Returns a tuple of (scratchpad_unit, finish_reason).
|
||||
"""
|
||||
usage_dict: dict[str, Any] = {}
|
||||
|
||||
# Convert non-streaming to streaming format if needed
|
||||
if isinstance(chunks, LLMResult):
|
||||
result = chunks
|
||||
|
||||
def result_to_chunks() -> Generator[LLMResultChunk, None, None]:
|
||||
yield LLMResultChunk(
|
||||
model=result.model,
|
||||
prompt_messages=result.prompt_messages,
|
||||
delta=LLMResultChunkDelta(
|
||||
index=0,
|
||||
message=result.message,
|
||||
usage=result.usage,
|
||||
finish_reason=None,
|
||||
),
|
||||
system_fingerprint=result.system_fingerprint or "",
|
||||
)
|
||||
|
||||
streaming_chunks = result_to_chunks()
|
||||
else:
|
||||
streaming_chunks = chunks
|
||||
|
||||
react_chunks = CotAgentOutputParser.handle_react_stream_output(streaming_chunks, usage_dict)
|
||||
|
||||
# Initialize scratchpad unit
|
||||
scratchpad = AgentScratchpadUnit(
|
||||
agent_response="",
|
||||
thought="",
|
||||
action_str="",
|
||||
observation="",
|
||||
action=None,
|
||||
)
|
||||
|
||||
finish_reason: str | None = None
|
||||
|
||||
# Process chunks
|
||||
for chunk in react_chunks:
|
||||
if isinstance(chunk, AgentScratchpadUnit.Action):
|
||||
# Action detected
|
||||
action_str = json.dumps(chunk.model_dump())
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + action_str
|
||||
scratchpad.action_str = action_str
|
||||
scratchpad.action = chunk
|
||||
|
||||
yield self._create_text_chunk(json.dumps(chunk.model_dump()), current_messages)
|
||||
else:
|
||||
# Text chunk
|
||||
chunk_text = str(chunk)
|
||||
scratchpad.agent_response = (scratchpad.agent_response or "") + chunk_text
|
||||
scratchpad.thought = (scratchpad.thought or "") + chunk_text
|
||||
|
||||
yield self._create_text_chunk(chunk_text, current_messages)
|
||||
|
||||
# Update usage
|
||||
if usage_dict.get("usage"):
|
||||
if llm_usage.get("usage"):
|
||||
self._accumulate_usage(llm_usage, usage_dict["usage"])
|
||||
else:
|
||||
llm_usage["usage"] = usage_dict["usage"]
|
||||
|
||||
# Clean up thought
|
||||
scratchpad.thought = (scratchpad.thought or "").strip() or "I am thinking about how to help you"
|
||||
|
||||
# Finish model log
|
||||
yield self._finish_log(
|
||||
model_log,
|
||||
data={
|
||||
"thought": scratchpad.thought,
|
||||
"action": scratchpad.action_str if scratchpad.action else None,
|
||||
},
|
||||
usage=llm_usage.get("usage"),
|
||||
)
|
||||
|
||||
return scratchpad, finish_reason
|
||||
|
||||
def _handle_tool_call(
|
||||
self,
|
||||
action: AgentScratchpadUnit.Action,
|
||||
prompt_messages: list[PromptMessage],
|
||||
round_log: AgentLog,
|
||||
) -> Generator[AgentLog, None, tuple[str, list[File]]]:
|
||||
"""Handle tool call and return observation with files."""
|
||||
tool_name = action.action_name
|
||||
tool_args: dict[str, Any] | str = action.action_input
|
||||
|
||||
# Find tool instance first to get metadata
|
||||
tool_instance = self._find_tool_by_name(tool_name)
|
||||
tool_metadata = self._get_tool_metadata(tool_instance) if tool_instance else {}
|
||||
|
||||
# Start tool log with tool metadata
|
||||
tool_log = self._create_log(
|
||||
label=f"CALL {tool_name}",
|
||||
log_type=AgentLog.LogType.TOOL_CALL,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={
|
||||
"tool_name": tool_name,
|
||||
"tool_args": tool_args,
|
||||
},
|
||||
parent_id=round_log.id,
|
||||
extra_metadata=tool_metadata,
|
||||
)
|
||||
yield tool_log
|
||||
|
||||
if not tool_instance:
|
||||
# Finish tool log with error
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"error": f"Tool {tool_name} not found",
|
||||
},
|
||||
)
|
||||
return f"Tool {tool_name} not found", []
|
||||
|
||||
# Ensure tool_args is a dict
|
||||
tool_args_dict: dict[str, Any]
|
||||
if isinstance(tool_args, str):
|
||||
try:
|
||||
tool_args_dict = json.loads(tool_args)
|
||||
except json.JSONDecodeError:
|
||||
tool_args_dict = {"input": tool_args}
|
||||
elif not isinstance(tool_args, dict):
|
||||
tool_args_dict = {"input": str(tool_args)}
|
||||
else:
|
||||
tool_args_dict = tool_args
|
||||
|
||||
# Invoke tool using base class method with error handling
|
||||
try:
|
||||
response_content, tool_files, tool_invoke_meta = self._invoke_tool(tool_instance, tool_args_dict, tool_name)
|
||||
|
||||
# Finish tool log
|
||||
yield self._finish_log(
|
||||
tool_log,
|
||||
data={
|
||||
**tool_log.data,
|
||||
"output": response_content,
|
||||
"files": len(tool_files),
|
||||
"meta": tool_invoke_meta.to_dict() if tool_invoke_meta else None,
|
||||
},
|
||||
)
|
||||
|
||||
return response_content or "Tool executed successfully", tool_files
|
||||
except Exception as e:
|
||||
# Tool invocation failed, yield error log
|
||||
error_message = str(e)
|
||||
tool_log.status = AgentLog.LogStatus.ERROR
|
||||
tool_log.error = error_message
|
||||
tool_log.data = {
|
||||
**tool_log.data,
|
||||
"error": error_message,
|
||||
}
|
||||
yield tool_log
|
||||
|
||||
return f"Tool execution failed: {error_message}", []
|
||||
108
api/core/agent/patterns/strategy_factory.py
Normal file
108
api/core/agent/patterns/strategy_factory.py
Normal file
@@ -0,0 +1,108 @@
|
||||
"""Strategy factory for creating agent strategies."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.model_manager import ModelInstance
|
||||
from graphon.file.models import File
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from .base import AgentPattern, ToolInvokeHook
|
||||
from .function_call import FunctionCallStrategy
|
||||
from .react import ReActStrategy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.tools.__base.tool import Tool
|
||||
|
||||
|
||||
class StrategyFactory:
|
||||
"""Factory for creating agent strategies based on model features."""
|
||||
|
||||
# Tool calling related features
|
||||
TOOL_CALL_FEATURES = {ModelFeature.TOOL_CALL, ModelFeature.MULTI_TOOL_CALL, ModelFeature.STREAM_TOOL_CALL}
|
||||
|
||||
@staticmethod
|
||||
def create_strategy(
|
||||
model_features: list[ModelFeature],
|
||||
model_instance: ModelInstance,
|
||||
context: ExecutionContext,
|
||||
tools: list[Tool],
|
||||
files: list[File],
|
||||
max_iterations: int = 10,
|
||||
workflow_call_depth: int = 0,
|
||||
agent_strategy: AgentEntity.Strategy | None = None,
|
||||
tool_invoke_hook: ToolInvokeHook | None = None,
|
||||
instruction: str = "",
|
||||
) -> AgentPattern:
|
||||
"""
|
||||
Create an appropriate strategy based on model features.
|
||||
|
||||
Args:
|
||||
model_features: List of model features/capabilities
|
||||
model_instance: Model instance to use
|
||||
context: Execution context containing trace/audit information
|
||||
tools: Available tools
|
||||
files: Available files
|
||||
max_iterations: Maximum iterations for the strategy
|
||||
workflow_call_depth: Depth of workflow calls
|
||||
agent_strategy: Optional explicit strategy override
|
||||
tool_invoke_hook: Optional hook for custom tool invocation (e.g., agent_invoke)
|
||||
instruction: Optional instruction for ReAct strategy
|
||||
|
||||
Returns:
|
||||
AgentStrategy instance
|
||||
"""
|
||||
|
||||
# If explicit strategy is provided and it's Function Calling, try to use it if supported
|
||||
if agent_strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
# Fallback to ReAct if FC is requested but not supported
|
||||
|
||||
# If explicit strategy is Chain of Thought (ReAct)
|
||||
if agent_strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
|
||||
# Default auto-selection logic
|
||||
if set(model_features) & StrategyFactory.TOOL_CALL_FEATURES:
|
||||
# Model supports native function calling
|
||||
return FunctionCallStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
)
|
||||
else:
|
||||
# Use ReAct strategy for models without function calling
|
||||
return ReActStrategy(
|
||||
model_instance=model_instance,
|
||||
context=context,
|
||||
tools=tools,
|
||||
files=files,
|
||||
max_iterations=max_iterations,
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
tool_invoke_hook=tool_invoke_hook,
|
||||
instruction=instruction,
|
||||
)
|
||||
@@ -177,6 +177,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
# Resolve parent_message_id for thread continuity
|
||||
if invoke_from == InvokeFrom.SERVICE_API:
|
||||
parent_message_id: str | None = UUID_NIL
|
||||
else:
|
||||
parent_message_id = args.get("parent_message_id")
|
||||
if not parent_message_id and conversation:
|
||||
parent_message_id = self._resolve_latest_message_id(conversation.id)
|
||||
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -188,7 +196,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
),
|
||||
query=query,
|
||||
files=list(file_objs),
|
||||
parent_message_id=args.get("parent_message_id") if invoke_from != InvokeFrom.SERVICE_API else UUID_NIL,
|
||||
parent_message_id=parent_message_id,
|
||||
user_id=user.id,
|
||||
stream=streaming,
|
||||
invoke_from=invoke_from,
|
||||
@@ -689,3 +697,17 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
@staticmethod
|
||||
def _resolve_latest_message_id(conversation_id: str) -> str | None:
|
||||
"""Auto-resolve parent_message_id to the latest message when client doesn't provide one."""
|
||||
from sqlalchemy import select
|
||||
|
||||
stmt = (
|
||||
select(Message.id)
|
||||
.where(Message.conversation_id == conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(1)
|
||||
)
|
||||
latest_id = db.session.scalar(stmt)
|
||||
return str(latest_id) if latest_id else None
|
||||
|
||||
@@ -1,15 +1,12 @@
|
||||
import logging
|
||||
from typing import cast
|
||||
|
||||
from graphon.model_runtime.entities.llm_entities import LLMMode
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelPropertyKey
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
from graphon.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from sqlalchemy import select
|
||||
|
||||
from core.agent.cot_chat_agent_runner import CotChatAgentRunner
|
||||
from core.agent.cot_completion_agent_runner import CotCompletionAgentRunner
|
||||
from core.agent.agent_app_runner import AgentAppRunner
|
||||
from core.agent.entities import AgentEntity
|
||||
from core.agent.fc_agent_runner import FunctionCallAgentRunner
|
||||
from core.app.apps.agent_chat.app_config_manager import AgentChatAppConfig
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager, PublishFrom
|
||||
from core.app.apps.base_app_runner import AppRunner
|
||||
@@ -192,24 +189,8 @@ class AgentChatAppRunner(AppRunner):
|
||||
message_result = db.session.scalar(msg_stmt)
|
||||
if message_result is None:
|
||||
raise ValueError("Message not found")
|
||||
db.session.close()
|
||||
|
||||
runner_cls: type[FunctionCallAgentRunner] | type[CotChatAgentRunner] | type[CotCompletionAgentRunner]
|
||||
# start agent runner
|
||||
if agent_entity.strategy == AgentEntity.Strategy.CHAIN_OF_THOUGHT:
|
||||
# check LLM mode
|
||||
if model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.CHAT:
|
||||
runner_cls = CotChatAgentRunner
|
||||
elif model_schema.model_properties.get(ModelPropertyKey.MODE) == LLMMode.COMPLETION:
|
||||
runner_cls = CotCompletionAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid LLM mode: {model_schema.model_properties.get(ModelPropertyKey.MODE)}")
|
||||
elif agent_entity.strategy == AgentEntity.Strategy.FUNCTION_CALLING:
|
||||
runner_cls = FunctionCallAgentRunner
|
||||
else:
|
||||
raise ValueError(f"Invalid agent strategy: {agent_entity.strategy}")
|
||||
|
||||
runner = runner_cls(
|
||||
runner = AgentAppRunner(
|
||||
tenant_id=app_config.tenant_id,
|
||||
application_generate_entity=application_generate_entity,
|
||||
conversation=conversation_result,
|
||||
|
||||
53
api/core/app/apps/common/legacy_response_adapter.py
Normal file
53
api/core/app/apps/common/legacy_response_adapter.py
Normal file
@@ -0,0 +1,53 @@
|
||||
"""Legacy Response Adapter for transparent upgrade.
|
||||
|
||||
When old apps (chat/completion/agent-chat) run through the Agent V2
|
||||
workflow engine via transparent upgrade, the SSE events are in workflow
|
||||
format (workflow_started, node_started, etc.). This adapter filters out
|
||||
workflow-specific events and passes through only the events that old
|
||||
clients expect (message, message_end, etc.).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
WORKFLOW_ONLY_EVENTS = frozenset({
|
||||
"workflow_started",
|
||||
"workflow_finished",
|
||||
"node_started",
|
||||
"node_finished",
|
||||
"iteration_started",
|
||||
"iteration_next",
|
||||
"iteration_completed",
|
||||
})
|
||||
|
||||
|
||||
def adapt_workflow_stream_for_legacy(
|
||||
stream: Generator[str, None, None],
|
||||
) -> Generator[str, None, None]:
|
||||
"""Filter workflow-specific SSE events from a streaming response.
|
||||
|
||||
Passes through message, message_end, agent_log, error, ping events.
|
||||
Suppresses workflow_started, workflow_finished, node_started, node_finished.
|
||||
|
||||
This makes the SSE stream look more like what old easy-UI apps produce,
|
||||
while still carrying the actual LLM response content.
|
||||
"""
|
||||
for chunk in stream:
|
||||
if not chunk or not chunk.strip():
|
||||
yield chunk
|
||||
continue
|
||||
|
||||
try:
|
||||
if chunk.startswith("data: "):
|
||||
data = json.loads(chunk[6:])
|
||||
event = data.get("event", "")
|
||||
if event in WORKFLOW_ONLY_EVENTS:
|
||||
continue
|
||||
yield chunk
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
yield chunk
|
||||
@@ -146,8 +146,6 @@ class WorkflowBasedAppRunner:
|
||||
call_depth=0,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
|
||||
node_factory = DifyNodeFactory.from_graph_init_context(
|
||||
graph_init_context=graph_init_context,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
72
api/core/app/entities/llm_generation_entities.py
Normal file
72
api/core/app/entities/llm_generation_entities.py
Normal file
@@ -0,0 +1,72 @@
|
||||
"""
|
||||
LLM Generation Detail entities.
|
||||
|
||||
Defines the structure for storing and transmitting LLM generation details
|
||||
including reasoning content, tool calls, and their sequence.
|
||||
"""
|
||||
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class ContentSegment(BaseModel):
|
||||
"""Represents a content segment in the generation sequence."""
|
||||
|
||||
type: Literal["content"] = "content"
|
||||
start: int = Field(..., description="Start position in the text")
|
||||
end: int = Field(..., description="End position in the text")
|
||||
|
||||
|
||||
class ReasoningSegment(BaseModel):
|
||||
"""Represents a reasoning segment in the generation sequence."""
|
||||
|
||||
type: Literal["reasoning"] = "reasoning"
|
||||
index: int = Field(..., description="Index into reasoning_content array")
|
||||
|
||||
|
||||
class ToolCallSegment(BaseModel):
|
||||
"""Represents a tool call segment in the generation sequence."""
|
||||
|
||||
type: Literal["tool_call"] = "tool_call"
|
||||
index: int = Field(..., description="Index into tool_calls array")
|
||||
|
||||
|
||||
SequenceSegment = ContentSegment | ReasoningSegment | ToolCallSegment
|
||||
|
||||
|
||||
class ToolCallDetail(BaseModel):
|
||||
"""Represents a tool call with its arguments and result."""
|
||||
|
||||
id: str = Field(default="", description="Unique identifier for the tool call")
|
||||
name: str = Field(..., description="Name of the tool")
|
||||
arguments: str = Field(default="", description="JSON string of tool arguments")
|
||||
result: str = Field(default="", description="Result from the tool execution")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed time in seconds")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class LLMGenerationDetailData(BaseModel):
|
||||
"""
|
||||
Domain model for LLM generation detail.
|
||||
|
||||
Contains the structured data for reasoning content, tool calls,
|
||||
and their display sequence.
|
||||
"""
|
||||
|
||||
reasoning_content: list[str] = Field(default_factory=list, description="List of reasoning segments")
|
||||
tool_calls: list[ToolCallDetail] = Field(default_factory=list, description="List of tool call details")
|
||||
sequence: list[SequenceSegment] = Field(default_factory=list, description="Display order of segments")
|
||||
|
||||
def is_empty(self) -> bool:
|
||||
"""Check if there's any meaningful generation detail."""
|
||||
return not self.reasoning_content and not self.tool_calls
|
||||
|
||||
def to_response_dict(self) -> dict:
|
||||
"""Convert to dictionary for API response."""
|
||||
return {
|
||||
"reasoning_content": self.reasoning_content,
|
||||
"tool_calls": [tc.model_dump() for tc in self.tool_calls],
|
||||
"sequence": [seg.model_dump() for seg in self.sequence],
|
||||
}
|
||||
75
api/core/helper/creators.py
Normal file
75
api/core/helper/creators.py
Normal file
@@ -0,0 +1,75 @@
|
||||
"""
|
||||
Helper module for Creators Platform integration.
|
||||
|
||||
Provides functionality to upload DSL files to the Creators Platform
|
||||
and generate redirect URLs with OAuth authorization codes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from urllib.parse import urlencode
|
||||
|
||||
import httpx
|
||||
from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
creators_platform_api_url = URL(str(dify_config.CREATORS_PLATFORM_API_URL))
|
||||
|
||||
|
||||
def upload_dsl(dsl_file_bytes: bytes, filename: str = "template.yaml") -> str:
|
||||
"""Upload a DSL file to the Creators Platform anonymous upload endpoint.
|
||||
|
||||
Args:
|
||||
dsl_file_bytes: Raw bytes of the DSL file (YAML or ZIP).
|
||||
filename: Original filename for the upload.
|
||||
|
||||
Returns:
|
||||
The claim_code string used to retrieve the DSL later.
|
||||
|
||||
Raises:
|
||||
httpx.HTTPStatusError: If the upload request fails.
|
||||
ValueError: If the response does not contain a valid claim_code.
|
||||
"""
|
||||
url = str(creators_platform_api_url / "api/v1/templates/anonymous-upload")
|
||||
response = httpx.post(url, files={"file": (filename, dsl_file_bytes)}, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
data = response.json()
|
||||
claim_code = data.get("data", {}).get("claim_code")
|
||||
if not claim_code:
|
||||
raise ValueError("Creators Platform did not return a valid claim_code")
|
||||
|
||||
return claim_code
|
||||
|
||||
|
||||
def get_redirect_url(user_account_id: str, claim_code: str) -> str:
|
||||
"""Generate the redirect URL to the Creators Platform frontend.
|
||||
|
||||
Redirects to the Creators Platform root page with the dsl_claim_code.
|
||||
If CREATORS_PLATFORM_OAUTH_CLIENT_ID is configured (Dify Cloud),
|
||||
also signs an OAuth authorization code so the frontend can
|
||||
automatically authenticate the user via the OAuth callback.
|
||||
|
||||
For self-hosted Dify without OAuth client_id configured, only the
|
||||
dsl_claim_code is passed and the user must log in manually.
|
||||
|
||||
Args:
|
||||
user_account_id: The Dify user account ID.
|
||||
claim_code: The claim_code obtained from upload_dsl().
|
||||
|
||||
Returns:
|
||||
The full redirect URL string.
|
||||
"""
|
||||
base_url = str(dify_config.CREATORS_PLATFORM_API_URL).rstrip("/")
|
||||
params: dict[str, str] = {"dsl_claim_code": claim_code}
|
||||
|
||||
client_id = str(dify_config.CREATORS_PLATFORM_OAUTH_CLIENT_ID or "")
|
||||
if client_id:
|
||||
from services.oauth_server import OAuthServerService
|
||||
|
||||
oauth_code = OAuthServerService.sign_oauth_authorization_code(client_id, user_account_id)
|
||||
params["oauth_code"] = oauth_code
|
||||
|
||||
return f"{base_url}?{urlencode(params)}"
|
||||
62
api/core/llm_generator/context_models.py
Normal file
62
api/core/llm_generator/context_models.py
Normal file
@@ -0,0 +1,62 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class VariableSelectorPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(..., description="Variable name used in generated code")
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeOutputPayload(BaseModel):
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
type: str = Field(..., description="Output variable type")
|
||||
|
||||
|
||||
class CodeContextPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/components/context-generate-modal/index.tsx (code node snapshot).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
code: str = Field(..., description="Existing code in the Code node")
|
||||
outputs: dict[str, CodeOutputPayload] | None = Field(
|
||||
default=None, description="Existing output definitions for the Code node"
|
||||
)
|
||||
variables: list[VariableSelectorPayload] | None = Field(
|
||||
default=None, description="Existing variable selectors used by the Code node"
|
||||
)
|
||||
|
||||
|
||||
class AvailableVarPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/_base/hooks/use-available-var-list.ts (available variables).
|
||||
model_config = ConfigDict(extra="forbid", populate_by_name=True)
|
||||
|
||||
value_selector: list[str] = Field(..., description="Path to upstream node output")
|
||||
type: str = Field(..., description="Variable type, e.g. string, number, array[object]")
|
||||
description: str | None = Field(default=None, description="Optional variable description")
|
||||
node_id: str | None = Field(default=None, description="Source node ID")
|
||||
node_title: str | None = Field(default=None, description="Source node title")
|
||||
node_type: str | None = Field(default=None, description="Source node type")
|
||||
json_schema: dict[str, Any] | None = Field(
|
||||
default=None,
|
||||
alias="schema",
|
||||
description="Optional JSON schema for object variables",
|
||||
)
|
||||
|
||||
|
||||
class ParameterInfoPayload(BaseModel):
|
||||
# From web/app/components/workflow/nodes/tool/use-config.ts (ToolParameter metadata).
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(..., description="Target parameter name")
|
||||
type: str = Field(default="string", description="Target parameter type")
|
||||
description: str = Field(default="", description="Parameter description")
|
||||
required: bool | None = Field(default=None, description="Whether the parameter is required")
|
||||
options: list[str] | None = Field(default=None, description="Allowed option values")
|
||||
min: float | None = Field(default=None, description="Minimum numeric value")
|
||||
max: float | None = Field(default=None, description="Maximum numeric value")
|
||||
default: str | int | float | bool | None = Field(default=None, description="Default value")
|
||||
multiple: bool | None = Field(default=None, description="Whether the parameter accepts multiple values")
|
||||
label: str | None = Field(default=None, description="Optional display label")
|
||||
67
api/core/llm_generator/output_models.py
Normal file
67
api/core/llm_generator/output_models.py
Normal file
@@ -0,0 +1,67 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from graphon.variables.types import SegmentType
|
||||
|
||||
|
||||
class SuggestedQuestionsOutput(BaseModel):
|
||||
"""Output model for suggested questions generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
questions: list[str] = Field(
|
||||
min_length=3,
|
||||
max_length=3,
|
||||
description="Exactly 3 suggested follow-up questions for the user",
|
||||
)
|
||||
|
||||
|
||||
class VariableSelectorOutput(BaseModel):
|
||||
"""Variable selector mapping code variable to upstream node output.
|
||||
|
||||
Note: Separate from VariableSelector to ensure 'additionalProperties: false'
|
||||
in JSON schema for OpenAI/Azure strict mode.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variable: str = Field(description="Variable name used in the generated code")
|
||||
value_selector: list[str] = Field(description="Path to upstream node output, format: [node_id, output_name]")
|
||||
|
||||
|
||||
class CodeNodeOutputItem(BaseModel):
|
||||
"""Single output variable definition.
|
||||
|
||||
Note: OpenAI/Azure strict mode requires 'additionalProperties: false' and
|
||||
does not support dynamic object keys, so outputs use array format.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
name: str = Field(description="Output variable name returned by the main function")
|
||||
type: SegmentType = Field(description="Data type of the output variable")
|
||||
|
||||
|
||||
class CodeNodeStructuredOutput(BaseModel):
|
||||
"""Structured output for code node generation."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
variables: list[VariableSelectorOutput] = Field(
|
||||
description="Input variables mapping code variables to upstream node outputs"
|
||||
)
|
||||
code: str = Field(description="Generated code with a main function that processes inputs and returns outputs")
|
||||
outputs: list[CodeNodeOutputItem] = Field(
|
||||
description="Output variable definitions specifying name and type for each return value"
|
||||
)
|
||||
message: str = Field(description="Brief explanation of what the generated code does")
|
||||
|
||||
|
||||
class InstructionModifyOutput(BaseModel):
|
||||
"""Output model for instruction-based prompt modification."""
|
||||
|
||||
model_config = ConfigDict(extra="forbid")
|
||||
|
||||
modified: str = Field(description="The modified prompt content after applying the instruction")
|
||||
message: str = Field(description="Brief explanation of what changes were made")
|
||||
203
api/core/llm_generator/output_parser/file_ref.py
Normal file
203
api/core/llm_generator/output_parser/file_ref.py
Normal file
@@ -0,0 +1,203 @@
|
||||
"""
|
||||
File path detection and conversion for structured output.
|
||||
|
||||
This module provides utilities to:
|
||||
1. Detect sandbox file path fields in JSON Schema (format: "file-path")
|
||||
2. Adapt schemas to add file-path descriptions before model invocation
|
||||
3. Convert sandbox file path strings into File objects via a resolver
|
||||
"""
|
||||
|
||||
from collections.abc import Callable, Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from graphon.file import File
|
||||
from graphon.variables.segments import ArrayFileSegment, FileSegment
|
||||
|
||||
FILE_PATH_FORMAT = "file-path"
|
||||
FILE_PATH_DESCRIPTION_SUFFIX = "this field contains a file path from the Dify sandbox"
|
||||
|
||||
|
||||
def is_file_path_property(schema: Mapping[str, Any]) -> bool:
|
||||
"""Check if a schema property represents a sandbox file path."""
|
||||
if schema.get("type") != "string":
|
||||
return False
|
||||
format_value = schema.get("format")
|
||||
if not isinstance(format_value, str):
|
||||
return False
|
||||
normalized_format = format_value.lower().replace("_", "-")
|
||||
return normalized_format == FILE_PATH_FORMAT
|
||||
|
||||
|
||||
def detect_file_path_fields(schema: Mapping[str, Any], path: str = "") -> list[str]:
|
||||
"""Recursively detect file path fields in a JSON schema."""
|
||||
file_path_fields: list[str] = []
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, Mapping):
|
||||
continue
|
||||
prop_schema_mapping = cast(Mapping[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_mapping):
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(prop_schema_mapping, current_path))
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, Mapping):
|
||||
return file_path_fields
|
||||
items_schema_mapping = cast(Mapping[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_mapping):
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
file_path_fields.extend(detect_file_path_fields(items_schema_mapping, array_path))
|
||||
|
||||
return file_path_fields
|
||||
|
||||
|
||||
def adapt_schema_for_sandbox_file_paths(schema: Mapping[str, Any]) -> tuple[dict[str, Any], list[str]]:
|
||||
"""Normalize sandbox file path fields and collect their JSON paths."""
|
||||
result = _deep_copy_value(schema)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("structured_output_schema must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
file_path_fields: list[str] = []
|
||||
_adapt_schema_in_place(result_dict, path="", file_path_fields=file_path_fields)
|
||||
return result_dict, file_path_fields
|
||||
|
||||
|
||||
def convert_sandbox_file_paths_in_output(
|
||||
output: Mapping[str, Any],
|
||||
file_path_fields: Sequence[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
) -> tuple[dict[str, Any], list[File]]:
|
||||
"""Convert sandbox file paths into File objects using the resolver."""
|
||||
if not file_path_fields:
|
||||
return dict(output), []
|
||||
|
||||
result = _deep_copy_value(output)
|
||||
if not isinstance(result, dict):
|
||||
raise ValueError("Structured output must be a JSON object")
|
||||
result_dict = cast(dict[str, Any], result)
|
||||
|
||||
files: list[File] = []
|
||||
for path in file_path_fields:
|
||||
_convert_path_in_place(result_dict, path.split("."), file_resolver, files)
|
||||
|
||||
return result_dict, files
|
||||
|
||||
|
||||
def _adapt_schema_in_place(schema: dict[str, Any], path: str, file_path_fields: list[str]) -> None:
|
||||
schema_type = schema.get("type")
|
||||
|
||||
if schema_type == "object":
|
||||
properties = schema.get("properties")
|
||||
if isinstance(properties, Mapping):
|
||||
properties_mapping = cast(Mapping[str, Any], properties)
|
||||
for prop_name, prop_schema in properties_mapping.items():
|
||||
if not isinstance(prop_schema, dict):
|
||||
continue
|
||||
prop_schema_dict = cast(dict[str, Any], prop_schema)
|
||||
current_path = f"{path}.{prop_name}" if path else prop_name
|
||||
|
||||
if is_file_path_property(prop_schema_dict):
|
||||
_normalize_file_path_schema(prop_schema_dict)
|
||||
file_path_fields.append(current_path)
|
||||
else:
|
||||
_adapt_schema_in_place(prop_schema_dict, current_path, file_path_fields)
|
||||
|
||||
elif schema_type == "array":
|
||||
items_schema = schema.get("items")
|
||||
if not isinstance(items_schema, dict):
|
||||
return
|
||||
items_schema_dict = cast(dict[str, Any], items_schema)
|
||||
array_path = f"{path}[*]" if path else "[*]"
|
||||
|
||||
if is_file_path_property(items_schema_dict):
|
||||
_normalize_file_path_schema(items_schema_dict)
|
||||
file_path_fields.append(array_path)
|
||||
else:
|
||||
_adapt_schema_in_place(items_schema_dict, array_path, file_path_fields)
|
||||
|
||||
|
||||
def _normalize_file_path_schema(schema: dict[str, Any]) -> None:
|
||||
schema["type"] = "string"
|
||||
schema["format"] = FILE_PATH_FORMAT
|
||||
description = schema.get("description", "")
|
||||
if description:
|
||||
if FILE_PATH_DESCRIPTION_SUFFIX not in description:
|
||||
schema["description"] = f"{description}\n{FILE_PATH_DESCRIPTION_SUFFIX}"
|
||||
else:
|
||||
schema["description"] = FILE_PATH_DESCRIPTION_SUFFIX
|
||||
|
||||
|
||||
def _deep_copy_value(value: Any) -> Any:
|
||||
if isinstance(value, Mapping):
|
||||
mapping = cast(Mapping[str, Any], value)
|
||||
return {key: _deep_copy_value(item) for key, item in mapping.items()}
|
||||
if isinstance(value, list):
|
||||
list_value = cast(list[Any], value)
|
||||
return [_deep_copy_value(item) for item in list_value]
|
||||
return value
|
||||
|
||||
|
||||
def _convert_path_in_place(
|
||||
obj: dict[str, Any],
|
||||
path_parts: list[str],
|
||||
file_resolver: Callable[[str], File],
|
||||
files: list[File],
|
||||
) -> None:
|
||||
if not path_parts:
|
||||
return
|
||||
|
||||
current = path_parts[0]
|
||||
remaining = path_parts[1:]
|
||||
|
||||
if current.endswith("[*]"):
|
||||
key = current[:-3] if current != "[*]" else ""
|
||||
target_value = obj.get(key) if key else obj
|
||||
|
||||
if isinstance(target_value, list):
|
||||
target_list = cast(list[Any], target_value)
|
||||
if remaining:
|
||||
for item in target_list:
|
||||
if isinstance(item, dict):
|
||||
item_dict = cast(dict[str, Any], item)
|
||||
_convert_path_in_place(item_dict, remaining, file_resolver, files)
|
||||
else:
|
||||
resolved_files: list[File] = []
|
||||
for item in target_list:
|
||||
if not isinstance(item, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(item)
|
||||
files.append(file)
|
||||
resolved_files.append(file)
|
||||
if key:
|
||||
obj[key] = ArrayFileSegment(value=resolved_files)
|
||||
return
|
||||
|
||||
if not remaining:
|
||||
if current not in obj:
|
||||
return
|
||||
value = obj[current]
|
||||
if value is None:
|
||||
obj[current] = None
|
||||
return
|
||||
if not isinstance(value, str):
|
||||
raise ValueError("File path must be a string")
|
||||
file = file_resolver(value)
|
||||
files.append(file)
|
||||
obj[current] = FileSegment(value=file)
|
||||
return
|
||||
|
||||
if current in obj and isinstance(obj[current], dict):
|
||||
_convert_path_in_place(obj[current], remaining, file_resolver, files)
|
||||
45
api/core/llm_generator/utils.py
Normal file
45
api/core/llm_generator/utils.py
Normal file
@@ -0,0 +1,45 @@
|
||||
"""Utility functions for LLM generator."""
|
||||
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
|
||||
|
||||
def deserialize_prompt_messages(messages: list[dict]) -> list[PromptMessage]:
|
||||
"""
|
||||
Deserialize list of dicts to list[PromptMessage].
|
||||
|
||||
Expected format:
|
||||
[
|
||||
{"role": "user", "content": "..."},
|
||||
{"role": "assistant", "content": "..."},
|
||||
]
|
||||
"""
|
||||
result: list[PromptMessage] = []
|
||||
for msg in messages:
|
||||
role = PromptMessageRole.value_of(msg["role"])
|
||||
content = msg.get("content", "")
|
||||
|
||||
match role:
|
||||
case PromptMessageRole.USER:
|
||||
result.append(UserPromptMessage(content=content))
|
||||
case PromptMessageRole.ASSISTANT:
|
||||
result.append(AssistantPromptMessage(content=content))
|
||||
case PromptMessageRole.SYSTEM:
|
||||
result.append(SystemPromptMessage(content=content))
|
||||
case PromptMessageRole.TOOL:
|
||||
result.append(ToolPromptMessage(content=content, tool_call_id=msg.get("tool_call_id", "")))
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def serialize_prompt_messages(messages: list[PromptMessage]) -> list[dict]:
|
||||
"""
|
||||
Serialize list[PromptMessage] to list of dicts.
|
||||
"""
|
||||
return [{"role": msg.role.value, "content": msg.content} for msg in messages]
|
||||
11
api/core/memory/__init__.py
Normal file
11
api/core/memory/__init__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import (
|
||||
NodeTokenBufferMemory,
|
||||
)
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
|
||||
__all__ = [
|
||||
"BaseMemory",
|
||||
"NodeTokenBufferMemory",
|
||||
"TokenBufferMemory",
|
||||
]
|
||||
82
api/core/memory/base.py
Normal file
82
api/core/memory/base.py
Normal file
@@ -0,0 +1,82 @@
|
||||
"""
|
||||
Base memory interfaces and types.
|
||||
|
||||
This module defines the common protocol for memory implementations.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import Sequence
|
||||
|
||||
from graphon.model_runtime.entities import ImagePromptMessageContent, PromptMessage
|
||||
|
||||
|
||||
class BaseMemory(ABC):
|
||||
"""
|
||||
Abstract base class for memory implementations.
|
||||
|
||||
Provides a common interface for both conversation-level and node-level memory.
|
||||
"""
|
||||
|
||||
@abstractmethod
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Get history prompt messages.
|
||||
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Sequence of PromptMessage for LLM context
|
||||
"""
|
||||
pass
|
||||
|
||||
def get_history_prompt_text(
|
||||
self,
|
||||
human_prefix: str = "Human",
|
||||
ai_prefix: str = "Assistant",
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> str:
|
||||
"""
|
||||
Get history prompt as formatted text.
|
||||
|
||||
:param human_prefix: Prefix for human messages
|
||||
:param ai_prefix: Prefix for assistant messages
|
||||
:param max_token_limit: Maximum tokens for history
|
||||
:param message_limit: Maximum number of messages
|
||||
:return: Formatted history text
|
||||
"""
|
||||
from graphon.model_runtime.entities import (
|
||||
PromptMessageRole,
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
|
||||
prompt_messages = self.get_history_prompt_messages(
|
||||
max_token_limit=max_token_limit,
|
||||
message_limit=message_limit,
|
||||
)
|
||||
|
||||
string_messages = []
|
||||
for m in prompt_messages:
|
||||
if m.role == PromptMessageRole.USER:
|
||||
role = human_prefix
|
||||
elif m.role == PromptMessageRole.ASSISTANT:
|
||||
role = ai_prefix
|
||||
else:
|
||||
continue
|
||||
|
||||
if isinstance(m.content, list):
|
||||
inner_msg = ""
|
||||
for content in m.content:
|
||||
if isinstance(content, TextPromptMessageContent):
|
||||
inner_msg += f"{content.data}\n"
|
||||
elif isinstance(content, ImagePromptMessageContent):
|
||||
inner_msg += "[image]\n"
|
||||
string_messages.append(f"{role}: {inner_msg.strip()}")
|
||||
else:
|
||||
message = f"{role}: {m.content}"
|
||||
string_messages.append(message)
|
||||
|
||||
return "\n".join(string_messages)
|
||||
196
api/core/memory/node_token_buffer_memory.py
Normal file
196
api/core/memory/node_token_buffer_memory.py
Normal file
@@ -0,0 +1,196 @@
|
||||
"""
|
||||
Node-level Token Buffer Memory for Chatflow.
|
||||
|
||||
This module provides node-scoped memory within a conversation.
|
||||
Each LLM node in a workflow can maintain its own independent conversation history.
|
||||
|
||||
Note: This is only available in Chatflow (advanced-chat mode) because it requires
|
||||
both conversation_id and node_id.
|
||||
|
||||
Design:
|
||||
- History is read directly from WorkflowNodeExecutionModel.outputs["context"]
|
||||
- No separate storage needed - the context is already saved during node execution
|
||||
- Thread tracking leverages Message table's parent_message_id structure
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
from typing import cast
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.memory.base import BaseMemory
|
||||
from core.model_manager import ModelInstance
|
||||
from core.prompt.utils.extract_thread_messages import extract_thread_messages
|
||||
from graphon.file import file_manager
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
MultiModalPromptMessageContent,
|
||||
PromptMessage,
|
||||
PromptMessageRole,
|
||||
SystemPromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.message_entities import PromptMessageContentUnionTypes
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class NodeTokenBufferMemory(BaseMemory):
|
||||
"""
|
||||
Node-level Token Buffer Memory.
|
||||
|
||||
Provides node-scoped memory within a conversation. Each LLM node can maintain
|
||||
its own independent conversation history.
|
||||
|
||||
Key design: History is read directly from WorkflowNodeExecutionModel.outputs["context"],
|
||||
which is already saved during node execution. No separate storage needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
app_id: str,
|
||||
conversation_id: str,
|
||||
node_id: str,
|
||||
tenant_id: str,
|
||||
model_instance: ModelInstance,
|
||||
):
|
||||
self.app_id = app_id
|
||||
self.conversation_id = conversation_id
|
||||
self.node_id = node_id
|
||||
self.tenant_id = tenant_id
|
||||
self.model_instance = model_instance
|
||||
|
||||
def _get_thread_workflow_run_ids(self) -> list[str]:
|
||||
"""
|
||||
Get workflow_run_ids for the current thread by querying Message table.
|
||||
Returns workflow_run_ids in chronological order (oldest first).
|
||||
"""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = (
|
||||
select(Message)
|
||||
.where(Message.conversation_id == self.conversation_id)
|
||||
.order_by(Message.created_at.desc())
|
||||
.limit(500)
|
||||
)
|
||||
messages = list(session.scalars(stmt).all())
|
||||
|
||||
if not messages:
|
||||
return []
|
||||
|
||||
# Extract thread messages using existing logic
|
||||
thread_messages = extract_thread_messages(messages)
|
||||
|
||||
# For newly created message, its answer is temporarily empty, skip it
|
||||
if thread_messages and not thread_messages[0].answer and thread_messages[0].answer_tokens == 0:
|
||||
thread_messages.pop(0)
|
||||
|
||||
# Reverse to get chronological order, extract workflow_run_ids
|
||||
return [msg.workflow_run_id for msg in reversed(thread_messages) if msg.workflow_run_id]
|
||||
|
||||
def _deserialize_prompt_message(self, msg_dict: dict) -> PromptMessage:
|
||||
"""Deserialize a dict to PromptMessage based on role."""
|
||||
role = msg_dict.get("role")
|
||||
if role in (PromptMessageRole.USER, "user"):
|
||||
return UserPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.ASSISTANT, "assistant"):
|
||||
return AssistantPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.SYSTEM, "system"):
|
||||
return SystemPromptMessage.model_validate(msg_dict)
|
||||
elif role in (PromptMessageRole.TOOL, "tool"):
|
||||
return ToolPromptMessage.model_validate(msg_dict)
|
||||
else:
|
||||
return PromptMessage.model_validate(msg_dict)
|
||||
|
||||
def _deserialize_context(self, context_data: list[dict]) -> list[PromptMessage]:
|
||||
"""Deserialize context data from outputs to list of PromptMessage."""
|
||||
messages = []
|
||||
for msg_dict in context_data:
|
||||
try:
|
||||
msg = self._deserialize_prompt_message(msg_dict)
|
||||
msg = self._restore_multimodal_content(msg)
|
||||
messages.append(msg)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to deserialize prompt message: %s", e)
|
||||
return messages
|
||||
|
||||
def _restore_multimodal_content(self, message: PromptMessage) -> PromptMessage:
|
||||
"""
|
||||
Restore multimodal content (base64 or url) from file_ref.
|
||||
|
||||
When context is saved, base64_data is cleared to save storage space.
|
||||
This method restores the content by parsing file_ref (format: "method:id_or_url").
|
||||
"""
|
||||
content = message.content
|
||||
if content is None or isinstance(content, str):
|
||||
return message
|
||||
|
||||
# Process list content, restoring multimodal data from file references
|
||||
restored_content: list[PromptMessageContentUnionTypes] = []
|
||||
for item in content:
|
||||
if isinstance(item, MultiModalPromptMessageContent):
|
||||
# restore_multimodal_content preserves the concrete subclass type
|
||||
restored_item = file_manager.restore_multimodal_content(item)
|
||||
restored_content.append(cast(PromptMessageContentUnionTypes, restored_item))
|
||||
else:
|
||||
restored_content.append(item)
|
||||
|
||||
return message.model_copy(update={"content": restored_content})
|
||||
|
||||
def get_history_prompt_messages(
|
||||
self,
|
||||
max_token_limit: int = 2000,
|
||||
message_limit: int | None = None,
|
||||
) -> Sequence[PromptMessage]:
|
||||
"""
|
||||
Retrieve history as PromptMessage sequence.
|
||||
History is read directly from the last completed node execution's outputs["context"].
|
||||
"""
|
||||
_ = message_limit # unused, kept for interface compatibility
|
||||
|
||||
thread_workflow_run_ids = self._get_thread_workflow_run_ids()
|
||||
if not thread_workflow_run_ids:
|
||||
return []
|
||||
|
||||
# Get the last completed workflow_run_id (contains accumulated context)
|
||||
last_run_id = thread_workflow_run_ids[-1]
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.workflow_run_id == last_run_id,
|
||||
WorkflowNodeExecutionModel.node_id == self.node_id,
|
||||
WorkflowNodeExecutionModel.status == "succeeded",
|
||||
)
|
||||
execution = session.scalars(stmt).first()
|
||||
|
||||
if not execution:
|
||||
return []
|
||||
|
||||
outputs = execution.outputs_dict
|
||||
if not outputs:
|
||||
return []
|
||||
|
||||
context_data = outputs.get("context")
|
||||
|
||||
if not context_data or not isinstance(context_data, list):
|
||||
return []
|
||||
|
||||
prompt_messages = self._deserialize_context(context_data)
|
||||
if not prompt_messages:
|
||||
return []
|
||||
|
||||
# Truncate by token limit
|
||||
try:
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
while current_tokens > max_token_limit and len(prompt_messages) > 1:
|
||||
prompt_messages.pop(0)
|
||||
current_tokens = self.model_instance.get_llm_num_tokens(prompt_messages)
|
||||
except Exception as e:
|
||||
logger.warning("Failed to count tokens for truncation: %s", e)
|
||||
|
||||
return prompt_messages
|
||||
@@ -64,7 +64,7 @@ class TokenBufferMemory:
|
||||
match self.conversation.mode:
|
||||
case AppMode.AGENT_CHAT | AppMode.COMPLETION | AppMode.CHAT:
|
||||
file_extra_config = FileUploadConfigManager.convert(self.conversation.model_config)
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW | AppMode.AGENT:
|
||||
app = self.conversation.app
|
||||
if not app:
|
||||
raise ValueError("App not found for conversation")
|
||||
|
||||
@@ -8,6 +8,8 @@ from typing import TYPE_CHECKING, Any
|
||||
if TYPE_CHECKING: # pragma: no cover
|
||||
from models.model import File
|
||||
|
||||
from graphon.model_runtime.entities import PromptMessageTool
|
||||
|
||||
from core.tools.__base.tool_runtime import ToolRuntime
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolEntity,
|
||||
@@ -154,6 +156,61 @@ class Tool(ABC):
|
||||
|
||||
return parameters
|
||||
|
||||
def to_prompt_message_tool(self) -> PromptMessageTool:
|
||||
"""Convert this tool to a PromptMessageTool for LLM consumption."""
|
||||
message_tool = PromptMessageTool(
|
||||
name=self.entity.identity.name,
|
||||
description=self.entity.description.llm if self.entity.description else "",
|
||||
parameters={
|
||||
"type": "object",
|
||||
"properties": {},
|
||||
"required": [],
|
||||
},
|
||||
)
|
||||
|
||||
parameters = self.get_merged_runtime_parameters()
|
||||
for parameter in parameters:
|
||||
if parameter.form != ToolParameter.ToolParameterForm.LLM:
|
||||
continue
|
||||
|
||||
parameter_type = parameter.type.as_normal_type()
|
||||
if parameter.type in {
|
||||
ToolParameter.ToolParameterType.SYSTEM_FILES,
|
||||
ToolParameter.ToolParameterType.FILE,
|
||||
ToolParameter.ToolParameterType.FILES,
|
||||
}:
|
||||
if parameter.type == ToolParameter.ToolParameterType.FILE:
|
||||
file_format_desc = " Input the file id with format: [File: file_id]."
|
||||
else:
|
||||
file_format_desc = "Input the file id with format: [Files: file_id1, file_id2, ...]. "
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = {
|
||||
"type": "string",
|
||||
"description": (parameter.llm_description or "") + file_format_desc,
|
||||
}
|
||||
continue
|
||||
|
||||
enum = []
|
||||
if parameter.type == ToolParameter.ToolParameterType.SELECT:
|
||||
enum = [option.value for option in parameter.options] if parameter.options else []
|
||||
|
||||
message_tool.parameters["properties"][parameter.name] = (
|
||||
{
|
||||
"type": parameter_type,
|
||||
"description": parameter.llm_description or "",
|
||||
}
|
||||
if parameter.input_schema is None
|
||||
else parameter.input_schema
|
||||
)
|
||||
|
||||
if len(enum) > 0:
|
||||
message_tool.parameters["properties"][parameter.name]["enum"] = enum
|
||||
|
||||
if parameter.required:
|
||||
message_tool.parameters["required"].append(parameter.name)
|
||||
|
||||
return message_tool
|
||||
|
||||
def create_image_message(
|
||||
self,
|
||||
image: str,
|
||||
|
||||
187
api/core/tools/utils/system_encryption.py
Normal file
187
api/core/tools/utils/system_encryption.py
Normal file
@@ -0,0 +1,187 @@
|
||||
import base64
|
||||
import hashlib
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from Crypto.Cipher import AES
|
||||
from Crypto.Random import get_random_bytes
|
||||
from Crypto.Util.Padding import pad, unpad
|
||||
from pydantic import TypeAdapter
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class EncryptionError(Exception):
|
||||
"""Encryption/decryption specific error"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class SystemEncrypter:
|
||||
"""
|
||||
A simple parameters encrypter using AES-CBC encryption.
|
||||
|
||||
This class provides methods to encrypt and decrypt parameters
|
||||
using AES-CBC mode with a key derived from the application's SECRET_KEY.
|
||||
"""
|
||||
|
||||
def __init__(self, secret_key: str | None = None):
|
||||
"""
|
||||
Initialize the encrypter.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Raises:
|
||||
ValueError: If SECRET_KEY is not configured or empty
|
||||
"""
|
||||
secret_key = secret_key or dify_config.SECRET_KEY or ""
|
||||
|
||||
# Generate a fixed 256-bit key using SHA-256
|
||||
self.key = hashlib.sha256(secret_key.encode()).digest()
|
||||
|
||||
def encrypt_params(self, params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt parameters.
|
||||
|
||||
Args:
|
||||
params: parameters dictionary, e.g., {"client_id": "xxx", "client_secret": "xxx"}
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
|
||||
Raises:
|
||||
EncryptionError: If encryption fails
|
||||
ValueError: If params is invalid
|
||||
"""
|
||||
|
||||
try:
|
||||
# Generate random IV (16 bytes)
|
||||
iv = get_random_bytes(16)
|
||||
|
||||
# Create AES cipher (CBC mode)
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Encrypt data
|
||||
padded_data = pad(TypeAdapter(dict).dump_json(dict(params)), AES.block_size)
|
||||
encrypted_data = cipher.encrypt(padded_data)
|
||||
|
||||
# Combine IV and encrypted data
|
||||
combined = iv + encrypted_data
|
||||
|
||||
# Return base64 encoded string
|
||||
return base64.b64encode(combined).decode()
|
||||
|
||||
except Exception as e:
|
||||
raise EncryptionError(f"Encryption failed: {str(e)}") from e
|
||||
|
||||
def decrypt_params(self, encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt parameters.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted parameters dictionary
|
||||
|
||||
Raises:
|
||||
EncryptionError: If decryption fails
|
||||
ValueError: If encrypted_data is invalid
|
||||
"""
|
||||
if not isinstance(encrypted_data, str):
|
||||
raise ValueError("encrypted_data must be a string")
|
||||
|
||||
if not encrypted_data:
|
||||
raise ValueError("encrypted_data cannot be empty")
|
||||
|
||||
try:
|
||||
# Base64 decode
|
||||
combined = base64.b64decode(encrypted_data)
|
||||
|
||||
# Check minimum length (IV + at least one AES block)
|
||||
if len(combined) < 32: # 16 bytes IV + 16 bytes minimum encrypted data
|
||||
raise ValueError("Invalid encrypted data format")
|
||||
|
||||
# Separate IV and encrypted data
|
||||
iv = combined[:16]
|
||||
encrypted_data_bytes = combined[16:]
|
||||
|
||||
# Create AES cipher
|
||||
cipher = AES.new(self.key, AES.MODE_CBC, iv)
|
||||
|
||||
# Decrypt data
|
||||
decrypted_data = cipher.decrypt(encrypted_data_bytes)
|
||||
unpadded_data = unpad(decrypted_data, AES.block_size)
|
||||
|
||||
# Parse JSON
|
||||
params: Mapping[str, Any] = TypeAdapter(Mapping[str, Any]).validate_json(unpadded_data)
|
||||
|
||||
if not isinstance(params, dict):
|
||||
raise ValueError("Decrypted data is not a valid dictionary")
|
||||
|
||||
return params
|
||||
|
||||
except Exception as e:
|
||||
raise EncryptionError(f"Decryption failed: {str(e)}") from e
|
||||
|
||||
|
||||
# Factory function for creating encrypter instances
|
||||
def create_system_encrypter(secret_key: str | None = None) -> SystemEncrypter:
|
||||
"""
|
||||
Create an encrypter instance.
|
||||
|
||||
Args:
|
||||
secret_key: Optional secret key. If not provided, uses dify_config.SECRET_KEY
|
||||
|
||||
Returns:
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
return SystemEncrypter(secret_key=secret_key)
|
||||
|
||||
|
||||
# Global encrypter instance (for backward compatibility)
|
||||
_encrypter: SystemEncrypter | None = None
|
||||
|
||||
|
||||
def get_system_encrypter() -> SystemEncrypter:
|
||||
"""
|
||||
Get the global encrypter instance.
|
||||
|
||||
Returns:
|
||||
SystemEncrypter instance
|
||||
"""
|
||||
global _encrypter
|
||||
if _encrypter is None:
|
||||
_encrypter = SystemEncrypter()
|
||||
return _encrypter
|
||||
|
||||
|
||||
# Convenience functions for backward compatibility
|
||||
def encrypt_system_params(params: Mapping[str, Any]) -> str:
|
||||
"""
|
||||
Encrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
params: parameters dictionary
|
||||
|
||||
Returns:
|
||||
Base64-encoded encrypted string
|
||||
"""
|
||||
return get_system_encrypter().encrypt_params(params)
|
||||
|
||||
|
||||
def decrypt_system_params(encrypted_data: str) -> Mapping[str, Any]:
|
||||
"""
|
||||
Decrypt parameters using the global encrypter.
|
||||
|
||||
Args:
|
||||
encrypted_data: Base64-encoded encrypted string
|
||||
|
||||
Returns:
|
||||
Decrypted parameters dictionary
|
||||
"""
|
||||
return get_system_encrypter().decrypt_params(encrypted_data)
|
||||
@@ -53,6 +53,9 @@ from core.workflow.nodes.agent.plugin_strategy_adapter import (
|
||||
PluginAgentStrategyResolver,
|
||||
)
|
||||
from core.workflow.nodes.agent.runtime_support import AgentRuntimeSupport
|
||||
from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE
|
||||
from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter
|
||||
from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text, system_variable_selector
|
||||
from core.workflow.template_rendering import CodeExecutorJinja2TemplateRenderer
|
||||
from extensions.ext_database import db
|
||||
@@ -367,6 +370,11 @@ class DifyNodeFactory(NodeFactory):
|
||||
typed_node_config = NodeConfigDictAdapter.validate_python(normalize_node_config_for_graph(node_config))
|
||||
node_id = typed_node_config["id"]
|
||||
node_data = typed_node_config["data"]
|
||||
|
||||
if node_data.type == BuiltinNodeTypes.LLM and dify_config.AGENT_V2_REPLACES_LLM:
|
||||
node_data = self._remap_llm_to_agent_v2(node_data)
|
||||
typed_node_config["data"] = node_data
|
||||
|
||||
node_class = self._resolve_node_class(node_type=node_data.type, node_version=str(node_data.version))
|
||||
node_type = node_data.type
|
||||
node_init_kwargs_factories: Mapping[NodeType, Callable[[], dict[str, object]]] = {
|
||||
@@ -433,6 +441,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
"runtime_support": self._agent_runtime_support,
|
||||
"message_transformer": self._agent_message_transformer,
|
||||
},
|
||||
AGENT_V2_NODE_TYPE: lambda: self._build_agent_v2_kwargs(node_data),
|
||||
}
|
||||
node_init_kwargs = node_init_kwargs_factories.get(node_type, lambda: {})()
|
||||
return node_class(
|
||||
@@ -443,6 +452,71 @@ class DifyNodeFactory(NodeFactory):
|
||||
**node_init_kwargs,
|
||||
)
|
||||
|
||||
def _build_agent_v2_kwargs(self, node_data: BaseNodeData) -> dict[str, object]:
|
||||
"""Build initialization kwargs for Agent V2 node.
|
||||
|
||||
Injects memory (same mechanism as LLM Node) plus tool_manager
|
||||
and event_adapter.
|
||||
"""
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
validated = AgentV2NodeData.model_validate(node_data.model_dump())
|
||||
|
||||
import logging as _logging
|
||||
_log = _logging.getLogger(__name__)
|
||||
|
||||
memory = None
|
||||
if validated.memory is not None:
|
||||
conversation_id = get_system_text(
|
||||
self.graph_runtime_state.variable_pool, SystemVariableKey.CONVERSATION_ID
|
||||
)
|
||||
_log.info("[AGENT_V2_MEMORY] memory_config=%s, conversation_id=%s", validated.memory, conversation_id)
|
||||
if conversation_id:
|
||||
from graphon.model_runtime.entities.model_entities import ModelType as _ModelType
|
||||
|
||||
from core.model_manager import ModelManager as _ModelManager
|
||||
|
||||
model_instance = _ModelManager.for_tenant(
|
||||
tenant_id=self._dify_context.tenant_id
|
||||
).get_model_instance(
|
||||
tenant_id=self._dify_context.tenant_id,
|
||||
provider=validated.model.provider,
|
||||
model_type=_ModelType.LLM,
|
||||
model=validated.model.name,
|
||||
)
|
||||
memory = fetch_memory(
|
||||
conversation_id=conversation_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
node_data_memory=validated.memory,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
|
||||
return {
|
||||
"tool_manager": AgentV2ToolManager(
|
||||
tenant_id=self._dify_context.tenant_id,
|
||||
app_id=self._dify_context.app_id,
|
||||
),
|
||||
"event_adapter": AgentV2EventAdapter(),
|
||||
"memory": memory,
|
||||
}
|
||||
|
||||
@staticmethod
|
||||
def _remap_llm_to_agent_v2(node_data: BaseNodeData) -> BaseNodeData:
|
||||
"""Transparently remap LLMNodeData to AgentV2NodeData.
|
||||
|
||||
Since AgentV2NodeData is a strict superset of LLMNodeData
|
||||
(same LLM fields + tools/iterations/strategy), the mapping is lossless.
|
||||
With tools=[], Agent V2 behaves identically to LLM Node.
|
||||
"""
|
||||
from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE, AgentV2NodeData
|
||||
|
||||
data_dict = node_data.model_dump()
|
||||
data_dict["type"] = AGENT_V2_NODE_TYPE
|
||||
data_dict.setdefault("tools", [])
|
||||
data_dict.setdefault("max_iterations", 10)
|
||||
data_dict.setdefault("agent_strategy", "auto")
|
||||
return AgentV2NodeData.model_validate(data_dict)
|
||||
|
||||
@staticmethod
|
||||
def _validate_resolved_node_data(node_class: type[Node], node_data: BaseNodeData) -> BaseNodeData:
|
||||
"""
|
||||
|
||||
4
api/core/workflow/nodes/agent_v2/__init__.py
Normal file
4
api/core/workflow/nodes/agent_v2/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .entities import AgentV2NodeData
|
||||
from .node import AgentV2Node
|
||||
|
||||
__all__ = ["AgentV2Node", "AgentV2NodeData"]
|
||||
86
api/core/workflow/nodes/agent_v2/entities.py
Normal file
86
api/core/workflow/nodes/agent_v2/entities.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Agent V2 Node data model.
|
||||
|
||||
Merges LLM Node capabilities (prompt, memory, vision, context, structured output)
|
||||
with Agent capabilities (tool calling loop, strategy selection).
|
||||
When no tools are configured, behaves identically to an LLM Node.
|
||||
"""
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, Literal
|
||||
|
||||
from graphon.entities.base_node_data import BaseNodeData
|
||||
from graphon.model_runtime.entities import ImagePromptMessageContent
|
||||
from graphon.nodes.llm.entities import (
|
||||
LLMNodeChatModelMessage,
|
||||
LLMNodeCompletionModelPromptTemplate,
|
||||
ModelConfig,
|
||||
PromptConfig,
|
||||
)
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryConfig
|
||||
from core.tools.entities.tool_entities import ToolProviderType
|
||||
|
||||
AGENT_V2_NODE_TYPE = "agent-v2"
|
||||
|
||||
|
||||
class ContextConfig(BaseModel):
|
||||
enabled: bool
|
||||
variable_selector: list[str] | None = None
|
||||
|
||||
|
||||
class VisionConfigOptions(BaseModel):
|
||||
variable_selector: Sequence[str] = Field(default_factory=lambda: ["sys", "files"])
|
||||
detail: ImagePromptMessageContent.DETAIL = ImagePromptMessageContent.DETAIL.HIGH
|
||||
|
||||
|
||||
class VisionConfig(BaseModel):
|
||||
enabled: bool = False
|
||||
configs: VisionConfigOptions = Field(default_factory=VisionConfigOptions)
|
||||
|
||||
@field_validator("configs", mode="before")
|
||||
@classmethod
|
||||
def convert_none_configs(cls, v: Any):
|
||||
if v is None:
|
||||
return VisionConfigOptions()
|
||||
return v
|
||||
|
||||
|
||||
class ToolMetadata(BaseModel):
|
||||
"""Tool configuration for Agent V2 node."""
|
||||
|
||||
enabled: bool = True
|
||||
type: ToolProviderType = Field(..., description="Tool provider type: builtin, api, mcp, workflow")
|
||||
provider_name: str = Field(..., description="Tool provider name/identifier")
|
||||
tool_name: str = Field(..., description="Tool name")
|
||||
plugin_unique_identifier: str | None = Field(None)
|
||||
credential_id: str | None = Field(None)
|
||||
parameters: dict[str, Any] = Field(default_factory=dict)
|
||||
settings: dict[str, Any] = Field(default_factory=dict)
|
||||
extra: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
|
||||
class AgentV2NodeData(BaseNodeData):
|
||||
"""Agent V2 Node — LLM + Agent capabilities in a single workflow node."""
|
||||
|
||||
type: str = AGENT_V2_NODE_TYPE
|
||||
|
||||
# --- LLM capabilities (superset of LLMNodeData) ---
|
||||
model: ModelConfig
|
||||
prompt_template: Sequence[LLMNodeChatModelMessage] | LLMNodeCompletionModelPromptTemplate
|
||||
prompt_config: PromptConfig = Field(default_factory=PromptConfig)
|
||||
memory: MemoryConfig | None = None
|
||||
context: ContextConfig = Field(default_factory=lambda: ContextConfig(enabled=False))
|
||||
vision: VisionConfig = Field(default_factory=VisionConfig)
|
||||
structured_output: Mapping[str, Any] | None = None
|
||||
structured_output_switch_on: bool = Field(False, alias="structured_output_enabled")
|
||||
reasoning_format: Literal["separated", "tagged"] = "tagged"
|
||||
|
||||
# --- Agent capabilities ---
|
||||
tools: Sequence[ToolMetadata] = Field(default_factory=list)
|
||||
max_iterations: int = Field(default=10, ge=1, le=99)
|
||||
agent_strategy: Literal["auto", "function-calling", "chain-of-thought"] = "auto"
|
||||
|
||||
@property
|
||||
def tool_call_enabled(self) -> bool:
|
||||
return bool(self.tools) and any(t.enabled for t in self.tools)
|
||||
86
api/core/workflow/nodes/agent_v2/event_adapter.py
Normal file
86
api/core/workflow/nodes/agent_v2/event_adapter.py
Normal file
@@ -0,0 +1,86 @@
|
||||
"""Event adapter for Agent V2 Node.
|
||||
|
||||
Converts AgentPattern outputs (LLMResultChunk | AgentLog) into
|
||||
graphon NodeEventBase events consumable by the workflow engine.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from graphon.model_runtime.entities import LLMResultChunk
|
||||
from graphon.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
StreamChunkEvent,
|
||||
)
|
||||
|
||||
from core.agent.entities import AgentLog, AgentResult
|
||||
|
||||
|
||||
class AgentV2EventAdapter:
|
||||
"""Converts agent strategy outputs into workflow node events."""
|
||||
|
||||
def process_strategy_outputs(
|
||||
self,
|
||||
outputs: Generator[LLMResultChunk | AgentLog, None, AgentResult],
|
||||
*,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> Generator[NodeEventBase, None, AgentResult]:
|
||||
"""Process strategy generator outputs, yielding node events.
|
||||
|
||||
Returns the final AgentResult from the strategy.
|
||||
"""
|
||||
try:
|
||||
while True:
|
||||
item = next(outputs)
|
||||
if isinstance(item, AgentLog):
|
||||
yield self._convert_agent_log(item, node_id=node_id, node_execution_id=node_execution_id)
|
||||
elif isinstance(item, LLMResultChunk):
|
||||
pass
|
||||
except StopIteration as e:
|
||||
result: AgentResult = e.value
|
||||
return result
|
||||
|
||||
def _convert_agent_log(
|
||||
self,
|
||||
log: AgentLog,
|
||||
*,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
) -> AgentLogEvent:
|
||||
return AgentLogEvent(
|
||||
message_id=log.id,
|
||||
label=log.label,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=log.parent_id,
|
||||
error=log.error,
|
||||
status=log.status.value,
|
||||
data=dict(log.data),
|
||||
metadata={k.value if hasattr(k, "value") else str(k): v for k, v in log.metadata.items()},
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
def _convert_llm_chunk(
|
||||
self,
|
||||
chunk: LLMResultChunk,
|
||||
*,
|
||||
node_id: str,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
content = ""
|
||||
if chunk.delta.message and chunk.delta.message.content:
|
||||
if isinstance(chunk.delta.message.content, str):
|
||||
content = chunk.delta.message.content
|
||||
elif isinstance(chunk.delta.message.content, list):
|
||||
from graphon.model_runtime.entities.message_entities import TextPromptMessageContent
|
||||
|
||||
for item in chunk.delta.message.content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
content += item.data
|
||||
|
||||
if content:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=content,
|
||||
)
|
||||
549
api/core/workflow/nodes/agent_v2/node.py
Normal file
549
api/core/workflow/nodes/agent_v2/node.py
Normal file
@@ -0,0 +1,549 @@
|
||||
"""Agent V2 Workflow Node.
|
||||
|
||||
A unified workflow node that combines LLM capabilities with agent tool-calling.
|
||||
When tools are configured, runs an FC/ReAct loop via StrategyFactory.
|
||||
When no tools are present, behaves as a single-shot LLM invocation.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal
|
||||
|
||||
from graphon.enums import WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from graphon.model_runtime.entities import (
|
||||
AssistantPromptMessage,
|
||||
LLMResultChunk,
|
||||
PromptMessage,
|
||||
SystemPromptMessage,
|
||||
TextPromptMessageContent,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from graphon.node_events import (
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
|
||||
from core.agent.entities import AgentEntity, ExecutionContext
|
||||
from core.agent.patterns import StrategyFactory
|
||||
from core.app.entities.app_invoke_entities import DIFY_RUN_CONTEXT_KEY, DifyRunContext
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.workflow.system_variables import SystemVariableKey, get_system_text
|
||||
|
||||
from .entities import AGENT_V2_NODE_TYPE, AgentV2NodeData
|
||||
from .event_adapter import AgentV2EventAdapter
|
||||
from .tool_manager import AgentV2ToolManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from graphon.entities import GraphInitParams
|
||||
from graphon.entities.graph_config import NodeConfigDict
|
||||
from graphon.runtime import GraphRuntimeState
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_THINK_PATTERN = re.compile(r"<think[^>]*>(.*?)</think>", re.IGNORECASE | re.DOTALL)
|
||||
|
||||
|
||||
class AgentV2Node(Node[AgentV2NodeData]):
|
||||
node_type = AGENT_V2_NODE_TYPE
|
||||
|
||||
_tool_manager: AgentV2ToolManager
|
||||
_event_adapter: AgentV2EventAdapter
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: NodeConfigDict,
|
||||
graph_init_params: GraphInitParams,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
*,
|
||||
tool_manager: AgentV2ToolManager,
|
||||
event_adapter: AgentV2EventAdapter,
|
||||
memory: Any | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
self._tool_manager = tool_manager
|
||||
self._event_adapter = event_adapter
|
||||
self._memory = memory
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
return {
|
||||
"type": AGENT_V2_NODE_TYPE,
|
||||
"config": {
|
||||
"prompt_templates": {
|
||||
"chat_model": {
|
||||
"prompts": [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "You are a helpful AI assistant.",
|
||||
"edition_type": "basic",
|
||||
}
|
||||
]
|
||||
},
|
||||
"completion_model": {
|
||||
"conversation_histories_role": {
|
||||
"user_prefix": "Human",
|
||||
"assistant_prefix": "Assistant",
|
||||
},
|
||||
"prompt": {
|
||||
"text": "{{#sys.query#}}",
|
||||
"edition_type": "basic",
|
||||
},
|
||||
},
|
||||
},
|
||||
"agent_strategy": "auto",
|
||||
"max_iterations": 10,
|
||||
},
|
||||
}
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
dify_ctx = DifyRunContext.model_validate(self.require_run_context_value(DIFY_RUN_CONTEXT_KEY))
|
||||
|
||||
try:
|
||||
model_instance = self._fetch_model_instance(dify_ctx)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to load model: {e}",
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
prompt_messages = self._build_prompt_messages(dify_ctx)
|
||||
|
||||
if self.node_data.tool_call_enabled:
|
||||
yield from self._run_with_tools(model_instance, prompt_messages, dify_ctx)
|
||||
else:
|
||||
yield from self._run_without_tools(model_instance, prompt_messages, dify_ctx)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# No-tools path: single LLM invocation (LLM Node equivalent)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_without_tools(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
dify_ctx: DifyRunContext,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
try:
|
||||
result_chunks: Generator[LLMResultChunk, None, None] = model_instance.invoke_llm(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=self.node_data.model.completion_params,
|
||||
tools=[],
|
||||
stop=[],
|
||||
stream=True,
|
||||
callbacks=[],
|
||||
)
|
||||
|
||||
full_text = ""
|
||||
reasoning_content = ""
|
||||
usage: LLMUsage | None = None
|
||||
finish_reason: str | None = None
|
||||
|
||||
for chunk in result_chunks:
|
||||
chunk_text = self._extract_chunk_text(chunk)
|
||||
if chunk_text:
|
||||
full_text += chunk_text
|
||||
|
||||
if chunk.delta.usage:
|
||||
usage = chunk.delta.usage
|
||||
if chunk.delta.finish_reason:
|
||||
finish_reason = chunk.delta.finish_reason
|
||||
|
||||
if self.node_data.reasoning_format == "separated":
|
||||
full_text, reasoning_content = self._separate_reasoning(full_text)
|
||||
|
||||
metadata = {}
|
||||
if usage:
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency
|
||||
self.graph_runtime_state.add_tokens(usage.total_tokens)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
outputs={
|
||||
"text": full_text,
|
||||
"reasoning_content": reasoning_content,
|
||||
"finish_reason": finish_reason or "stop",
|
||||
},
|
||||
metadata=metadata,
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Agent V2 LLM invocation failed")
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Tools path: agent loop via StrategyFactory
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _run_with_tools(
|
||||
self,
|
||||
model_instance: ModelInstance,
|
||||
prompt_messages: list[PromptMessage],
|
||||
dify_ctx: DifyRunContext,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
try:
|
||||
tool_instances = self._tool_manager.prepare_tool_instances(
|
||||
list(self.node_data.tools),
|
||||
)
|
||||
|
||||
model_features = self._get_model_features(model_instance)
|
||||
|
||||
context = ExecutionContext(
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
conversation_id=get_system_text(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
SystemVariableKey.CONVERSATION_ID,
|
||||
),
|
||||
)
|
||||
|
||||
agent_strategy_enum = self._map_strategy_config(self.node_data.agent_strategy)
|
||||
|
||||
strategy = StrategyFactory.create_strategy(
|
||||
model_features=model_features,
|
||||
model_instance=model_instance,
|
||||
tools=tool_instances,
|
||||
files=[],
|
||||
max_iterations=self.node_data.max_iterations,
|
||||
context=context,
|
||||
agent_strategy=agent_strategy_enum,
|
||||
tool_invoke_hook=self._tool_manager.create_workflow_tool_invoke_hook(context),
|
||||
)
|
||||
|
||||
outputs_gen = strategy.run(
|
||||
prompt_messages=prompt_messages,
|
||||
model_parameters=self.node_data.model.completion_params,
|
||||
stop=[],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
result = yield from self._event_adapter.process_strategy_outputs(
|
||||
outputs_gen,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
)
|
||||
|
||||
if result.usage and hasattr(result.usage, "total_tokens"):
|
||||
self.graph_runtime_state.add_tokens(result.usage.total_tokens)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
inputs={},
|
||||
outputs={
|
||||
"text": result.text,
|
||||
"finish_reason": result.finish_reason or "stop",
|
||||
},
|
||||
metadata=self._build_usage_metadata(result.usage),
|
||||
)
|
||||
)
|
||||
except Exception as e:
|
||||
logger.exception("Agent V2 tool execution failed")
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=str(e),
|
||||
)
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _fetch_model_instance(self, dify_ctx: DifyRunContext) -> ModelInstance:
|
||||
model_config = self.node_data.model
|
||||
model_manager = ModelManager.for_tenant(tenant_id=dify_ctx.tenant_id)
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=model_config.provider,
|
||||
model_type=ModelType.LLM,
|
||||
model=model_config.name,
|
||||
)
|
||||
return model_instance
|
||||
|
||||
def _build_prompt_messages(self, dify_ctx: DifyRunContext) -> list[PromptMessage]:
|
||||
"""Build prompt messages from the node's prompt_template, resolving variables.
|
||||
|
||||
Handles: variable references ({{#node.var#}}), context injection ({{#context#}}),
|
||||
Jinja2 templates, and memory (conversation history).
|
||||
"""
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
messages: list[PromptMessage] = []
|
||||
|
||||
context_str = self._build_context_string(variable_pool)
|
||||
|
||||
template = self.node_data.prompt_template
|
||||
if isinstance(template, Sequence) and not isinstance(template, str):
|
||||
for msg_template in template:
|
||||
role = msg_template.role.value if hasattr(msg_template.role, "value") else str(msg_template.role)
|
||||
text = msg_template.text or ""
|
||||
jinja2_text = getattr(msg_template, "jinja2_text", None)
|
||||
|
||||
if jinja2_text:
|
||||
content = self._render_jinja2(jinja2_text, variable_pool, context_str)
|
||||
else:
|
||||
content = self._resolve_variable_template(text, variable_pool)
|
||||
if context_str:
|
||||
content = content.replace("{{#context#}}", context_str)
|
||||
|
||||
if role == "system":
|
||||
messages.append(SystemPromptMessage(content=content))
|
||||
elif role == "user":
|
||||
messages.append(UserPromptMessage(content=content))
|
||||
elif role == "assistant":
|
||||
messages.append(AssistantPromptMessage(content=content))
|
||||
else:
|
||||
text_content = getattr(template, "text", "") or ""
|
||||
resolved = self._resolve_variable_template(text_content, variable_pool)
|
||||
if context_str:
|
||||
resolved = resolved.replace("{{#context#}}", context_str)
|
||||
messages.append(UserPromptMessage(content=resolved))
|
||||
|
||||
if self._memory is not None:
|
||||
try:
|
||||
window_size = None
|
||||
if self.node_data.memory and hasattr(self.node_data.memory, "window"):
|
||||
w = self.node_data.memory.window
|
||||
if w and w.enabled:
|
||||
window_size = w.size
|
||||
|
||||
history = self._memory.get_history_prompt_messages(
|
||||
max_token_limit=2000,
|
||||
message_limit=window_size or 50,
|
||||
)
|
||||
history_list = list(history)
|
||||
logger.info("[AGENT_V2_MEMORY] Loaded %d history messages from memory", len(history_list))
|
||||
if history_list:
|
||||
system_msgs = [m for m in messages if isinstance(m, SystemPromptMessage)]
|
||||
other_msgs = [m for m in messages if not isinstance(m, SystemPromptMessage)]
|
||||
messages = system_msgs + history_list + other_msgs
|
||||
logger.info("[AGENT_V2_MEMORY] Total prompt messages after memory injection: %d", len(messages))
|
||||
except Exception:
|
||||
logger.warning("Failed to load memory for agent-v2 node", exc_info=True)
|
||||
else:
|
||||
logger.info("[AGENT_V2_MEMORY] No memory injected (self._memory is None)")
|
||||
|
||||
return messages
|
||||
|
||||
def _load_memory_messages(self, dify_ctx: DifyRunContext) -> list[PromptMessage]:
|
||||
"""Load conversation history from memory."""
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from models.model import Conversation
|
||||
|
||||
conversation_id = get_system_text(
|
||||
self.graph_runtime_state.variable_pool,
|
||||
SystemVariableKey.CONVERSATION_ID,
|
||||
)
|
||||
if not conversation_id:
|
||||
return []
|
||||
|
||||
try:
|
||||
from sqlalchemy import select
|
||||
from extensions.ext_database import db
|
||||
|
||||
stmt = select(Conversation).where(Conversation.id == conversation_id)
|
||||
conversation = db.session.scalar(stmt)
|
||||
if not conversation:
|
||||
return []
|
||||
|
||||
model_instance = self._fetch_model_instance(dify_ctx)
|
||||
memory = TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
window_size = None
|
||||
if self.node_data.memory and hasattr(self.node_data.memory, "window"):
|
||||
window = self.node_data.memory.window
|
||||
if window and window.enabled:
|
||||
window_size = window.size
|
||||
|
||||
history = memory.get_history_prompt_messages(
|
||||
max_token_limit=2000,
|
||||
message_limit=window_size or 50,
|
||||
)
|
||||
return list(history)
|
||||
except Exception:
|
||||
logger.warning("Failed to load memory for agent-v2 node", exc_info=True)
|
||||
return []
|
||||
|
||||
def _build_context_string(self, variable_pool: Any) -> str:
|
||||
"""Build context string from knowledge retrieval node output."""
|
||||
ctx_config = self.node_data.context
|
||||
if not ctx_config or not ctx_config.enabled:
|
||||
return ""
|
||||
selector = getattr(ctx_config, "variable_selector", None)
|
||||
if not selector:
|
||||
return ""
|
||||
try:
|
||||
value = variable_pool.get(selector)
|
||||
if value is None:
|
||||
return ""
|
||||
raw = value.value if hasattr(value, "value") else value
|
||||
if isinstance(raw, str):
|
||||
return raw
|
||||
if isinstance(raw, list):
|
||||
parts = []
|
||||
for item in raw:
|
||||
if isinstance(item, str):
|
||||
parts.append(item)
|
||||
elif isinstance(item, dict):
|
||||
if "content" in item:
|
||||
parts.append(item["content"])
|
||||
elif "text" in item:
|
||||
parts.append(item["text"])
|
||||
return "\n".join(parts)
|
||||
return str(raw)
|
||||
except Exception:
|
||||
logger.warning("Failed to build context string", exc_info=True)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _render_jinja2(template: str, variable_pool: Any, context_str: str = "") -> str:
|
||||
"""Render a Jinja2 template with variables from the pool."""
|
||||
try:
|
||||
from jinja2 import Environment, BaseLoader
|
||||
env = Environment(loader=BaseLoader(), autoescape=False)
|
||||
tpl = env.from_string(template)
|
||||
|
||||
parser = VariableTemplateParser(template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
variables: dict[str, Any] = {}
|
||||
for selector in selectors:
|
||||
value = variable_pool.get(selector.value_selector)
|
||||
if value is not None:
|
||||
variables[selector.variable] = value.text if hasattr(value, "text") else str(value)
|
||||
else:
|
||||
variables[selector.variable] = ""
|
||||
variables["context"] = context_str
|
||||
return tpl.render(**variables)
|
||||
except Exception:
|
||||
logger.warning("Jinja2 rendering failed, falling back to plain text", exc_info=True)
|
||||
return template
|
||||
|
||||
@staticmethod
|
||||
def _resolve_variable_template(template: str, variable_pool: Any) -> str:
|
||||
"""Resolve {{#node.var#}} references in a template string using the variable pool."""
|
||||
parser = VariableTemplateParser(template)
|
||||
selectors = parser.extract_variable_selectors()
|
||||
if not selectors:
|
||||
return template
|
||||
|
||||
inputs: dict[str, Any] = {}
|
||||
for selector in selectors:
|
||||
value = variable_pool.get(selector.value_selector)
|
||||
if value is not None:
|
||||
inputs[selector.variable] = value.text if hasattr(value, "text") else str(value)
|
||||
else:
|
||||
inputs[selector.variable] = ""
|
||||
|
||||
return parser.format(inputs)
|
||||
|
||||
def _get_model_features(self, model_instance: ModelInstance) -> list[ModelFeature]:
|
||||
try:
|
||||
model_schema = model_instance.model_type_instance.get_model_schema(
|
||||
model_instance.model_name,
|
||||
model_instance.credentials,
|
||||
)
|
||||
return list(model_schema.features) if model_schema and model_schema.features else []
|
||||
except Exception:
|
||||
logger.warning("Failed to get model features, assuming none")
|
||||
return []
|
||||
|
||||
@staticmethod
|
||||
def _build_usage_metadata(usage: Any) -> dict:
|
||||
metadata: dict = {}
|
||||
if usage and hasattr(usage, "total_tokens"):
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
|
||||
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
|
||||
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = getattr(usage, "currency", "USD")
|
||||
return metadata
|
||||
|
||||
@staticmethod
|
||||
def _map_strategy_config(
|
||||
config_value: Literal["auto", "function-calling", "chain-of-thought"],
|
||||
) -> AgentEntity.Strategy | None:
|
||||
mapping = {
|
||||
"function-calling": AgentEntity.Strategy.FUNCTION_CALLING,
|
||||
"chain-of-thought": AgentEntity.Strategy.CHAIN_OF_THOUGHT,
|
||||
}
|
||||
return mapping.get(config_value)
|
||||
|
||||
@staticmethod
|
||||
def _extract_chunk_text(chunk: LLMResultChunk) -> str:
|
||||
if not chunk.delta.message or not chunk.delta.message.content:
|
||||
return ""
|
||||
content = chunk.delta.message.content
|
||||
if isinstance(content, str):
|
||||
return content
|
||||
if isinstance(content, list):
|
||||
parts = []
|
||||
for item in content:
|
||||
if isinstance(item, TextPromptMessageContent):
|
||||
parts.append(item.data)
|
||||
return "".join(parts)
|
||||
return ""
|
||||
|
||||
@staticmethod
|
||||
def _separate_reasoning(text: str) -> tuple[str, str]:
|
||||
"""Extract <think> blocks from text, return (clean_text, reasoning_content)."""
|
||||
reasoning_parts = _THINK_PATTERN.findall(text)
|
||||
reasoning_content = "\n".join(reasoning_parts)
|
||||
clean_text = _THINK_PATTERN.sub("", text).strip()
|
||||
return clean_text, reasoning_content
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentV2NodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
result: dict[str, list[str]] = {}
|
||||
|
||||
if isinstance(node_data.prompt_template, Sequence) and not isinstance(node_data.prompt_template, str):
|
||||
for msg in node_data.prompt_template:
|
||||
text = msg.text or ""
|
||||
jinja2_text = getattr(msg, "jinja2_text", None)
|
||||
content = jinja2_text or text
|
||||
selectors = VariableTemplateParser(content).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = list(selector.value_selector)
|
||||
else:
|
||||
text_content = getattr(node_data.prompt_template, "text", "") or ""
|
||||
selectors = VariableTemplateParser(text_content).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = list(selector.value_selector)
|
||||
|
||||
return {f"{node_id}.{key}": value for key, value in result.items()}
|
||||
129
api/core/workflow/nodes/agent_v2/tool_manager.py
Normal file
129
api/core/workflow/nodes/agent_v2/tool_manager.py
Normal file
@@ -0,0 +1,129 @@
|
||||
"""Tool management for Agent V2 Node.
|
||||
|
||||
Handles tool instance preparation, conversion to LLM-consumable format,
|
||||
and creation of workflow-compatible tool invoke hooks.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.agent.entities import AgentToolEntity, ExecutionContext
|
||||
from core.agent.patterns.base import ToolInvokeHook
|
||||
from core.callback_handler.workflow_tool_callback_handler import DifyWorkflowCallbackHandler
|
||||
from core.tools.__base.tool import Tool
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta, ToolInvokeMessage
|
||||
from core.tools.tool_engine import ToolEngine
|
||||
from core.tools.tool_manager import ToolManager
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .entities import ToolMetadata
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AgentV2ToolManager:
|
||||
"""Manages tool lifecycle for Agent V2 node execution."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
) -> None:
|
||||
self._tenant_id = tenant_id
|
||||
self._app_id = app_id
|
||||
|
||||
def prepare_tool_instances(
|
||||
self,
|
||||
tools_config: list[ToolMetadata],
|
||||
) -> list[Tool]:
|
||||
"""Convert tool metadata configs into runtime Tool instances."""
|
||||
tool_instances: list[Tool] = []
|
||||
for tool_meta in tools_config:
|
||||
if not tool_meta.enabled:
|
||||
continue
|
||||
try:
|
||||
processed_settings = {}
|
||||
for key, value in tool_meta.settings.items():
|
||||
if isinstance(value, dict) and "value" in value and isinstance(value["value"], dict):
|
||||
if "type" in value["value"] and "value" in value["value"]:
|
||||
processed_settings[key] = value["value"]
|
||||
else:
|
||||
processed_settings[key] = value
|
||||
else:
|
||||
processed_settings[key] = value
|
||||
|
||||
merged_parameters = {**tool_meta.parameters, **processed_settings}
|
||||
|
||||
agent_tool = AgentToolEntity(
|
||||
provider_id=tool_meta.provider_name,
|
||||
provider_type=tool_meta.type,
|
||||
tool_name=tool_meta.tool_name,
|
||||
tool_parameters=merged_parameters,
|
||||
plugin_unique_identifier=tool_meta.plugin_unique_identifier,
|
||||
credential_id=tool_meta.credential_id,
|
||||
)
|
||||
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=self._app_id,
|
||||
agent_tool=agent_tool,
|
||||
)
|
||||
tool_instances.append(tool_runtime)
|
||||
except Exception:
|
||||
logger.warning("Failed to prepare tool %s/%s, skipping", tool_meta.provider_name, tool_meta.tool_name, exc_info=True)
|
||||
continue
|
||||
|
||||
return tool_instances
|
||||
|
||||
def create_workflow_tool_invoke_hook(
|
||||
self,
|
||||
context: ExecutionContext,
|
||||
workflow_call_depth: int = 0,
|
||||
) -> ToolInvokeHook:
|
||||
"""Create a ToolInvokeHook for workflow context."""
|
||||
|
||||
def hook(
|
||||
tool: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
return self._invoke_tool_directly(tool, tool_args, tool_name, context, workflow_call_depth)
|
||||
|
||||
return hook
|
||||
|
||||
def _invoke_tool_directly(
|
||||
self,
|
||||
tool: Tool,
|
||||
tool_args: dict[str, Any],
|
||||
tool_name: str,
|
||||
context: ExecutionContext,
|
||||
workflow_call_depth: int,
|
||||
) -> tuple[str, list[str], ToolInvokeMeta]:
|
||||
"""Invoke tool directly via ToolEngine."""
|
||||
tool_response = ToolEngine.generic_invoke(
|
||||
tool=tool,
|
||||
tool_parameters=tool_args,
|
||||
user_id=context.user_id or "",
|
||||
workflow_tool_callback=DifyWorkflowCallbackHandler(),
|
||||
workflow_call_depth=workflow_call_depth,
|
||||
app_id=context.app_id,
|
||||
conversation_id=context.conversation_id,
|
||||
)
|
||||
|
||||
response_content = ""
|
||||
for response in tool_response:
|
||||
if response.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(response.message, ToolInvokeMessage.TextMessage)
|
||||
response_content += response.message.text
|
||||
elif response.type == ToolInvokeMessage.MessageType.JSON:
|
||||
if isinstance(response.message, ToolInvokeMessage.JsonMessage):
|
||||
response_content += json.dumps(response.message.json_object, ensure_ascii=False)
|
||||
elif response.type == ToolInvokeMessage.MessageType.LINK:
|
||||
if isinstance(response.message, ToolInvokeMessage.TextMessage):
|
||||
response_content += f"[Link: {response.message.text}]"
|
||||
|
||||
return response_content, [], ToolInvokeMeta.empty()
|
||||
41
api/dify_graph/entities/tool_entities.py
Normal file
41
api/dify_graph/entities/tool_entities.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from graphon.file import File
|
||||
|
||||
|
||||
class ToolResultStatus(StrEnum):
|
||||
SUCCESS = "success"
|
||||
ERROR = "error"
|
||||
|
||||
|
||||
class ToolCall(BaseModel):
|
||||
id: str | None = Field(default=None, description="Unique identifier for this tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool being called")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
icon: str | dict | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
|
||||
|
||||
class ToolResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier of the tool call this result belongs to")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[str] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus | None = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
icon: str | dict[str, Any] | None = Field(default=None, description="Icon of the tool")
|
||||
icon_dark: str | dict[str, Any] | None = Field(default=None, description="Dark theme icon of the tool")
|
||||
provider: str | None = Field(default=None, description="Tool provider identifier")
|
||||
|
||||
|
||||
class ToolCallResult(BaseModel):
|
||||
id: str | None = Field(default=None, description="Identifier for the tool call")
|
||||
name: str | None = Field(default=None, description="Name of the tool")
|
||||
arguments: str | None = Field(default=None, description="Accumulated tool arguments JSON")
|
||||
output: str | None = Field(default=None, description="Tool output text, error or success message")
|
||||
files: list[File] = Field(default_factory=list, description="File produced by tool")
|
||||
status: ToolResultStatus = Field(default=ToolResultStatus.SUCCESS, description="Tool execution status")
|
||||
elapsed_time: float | None = Field(default=None, description="Elapsed seconds spent executing the tool")
|
||||
935
api/dify_graph/nodes/agent/agent_node.py
Normal file
935
api/dify_graph/nodes/agent/agent_node.py
Normal file
@@ -0,0 +1,935 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.memory.base import BaseMemory
|
||||
from core.memory.node_token_buffer_memory import NodeTokenBufferMemory
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
from core.prompt.entities.advanced_prompt_entities import MemoryMode
|
||||
from core.plugin.impl.model_runtime_factory import create_plugin_provider_manager
|
||||
from core.tools.entities.tool_entities import (
|
||||
ToolIdentity,
|
||||
ToolInvokeMessage,
|
||||
ToolParameter,
|
||||
ToolProviderType,
|
||||
)
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.tools.utils.message_transformer import ToolFileMessageTransformer
|
||||
from core.workflow.nodes.agent.entities import AgentNodeData, AgentOldVersionModelFeatures, ParamsAutoGenerated
|
||||
from core.workflow.nodes.agent.exceptions import (
|
||||
AgentInputTypeError,
|
||||
AgentInvocationError,
|
||||
AgentMessageTransformError,
|
||||
AgentNodeError,
|
||||
AgentVariableNotFoundError,
|
||||
AgentVariableTypeError,
|
||||
ToolFileNotFoundError,
|
||||
)
|
||||
from graphon.enums import (
|
||||
BuiltinNodeTypes,
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
WorkflowNodeExecutionMetadataKey,
|
||||
WorkflowNodeExecutionStatus,
|
||||
)
|
||||
from graphon.file import File, FileTransferMethod
|
||||
from graphon.model_runtime.entities.llm_entities import LLMUsage, LLMUsageMetadata
|
||||
from graphon.model_runtime.entities.message_entities import (
|
||||
AssistantPromptMessage,
|
||||
PromptMessage,
|
||||
ToolPromptMessage,
|
||||
UserPromptMessage,
|
||||
)
|
||||
from graphon.model_runtime.entities.model_entities import AIModelEntity, ModelType
|
||||
from graphon.model_runtime.utils.encoders import jsonable_encoder
|
||||
from graphon.node_events import (
|
||||
AgentLogEvent,
|
||||
NodeEventBase,
|
||||
NodeRunResult,
|
||||
StreamChunkEvent,
|
||||
StreamCompletedEvent,
|
||||
)
|
||||
from graphon.nodes.base.node import Node
|
||||
from graphon.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from graphon.runtime import VariablePool
|
||||
from graphon.variables.segments import ArrayFileSegment, StringSegment
|
||||
from core.app.file_access.controller import DatabaseFileAccessController
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
|
||||
_file_access_controller = DatabaseFileAccessController()
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.agent.strategy.plugin import PluginAgentStrategy
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
|
||||
class AgentNode(Node[AgentNodeData]):
|
||||
"""
|
||||
Agent Node
|
||||
"""
|
||||
|
||||
node_type = BuiltinNodeTypes.AGENT
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
|
||||
try:
|
||||
strategy = get_plugin_agent_strategy(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
agent_strategy_provider_name=self.node_data.agent_strategy_provider_name,
|
||||
agent_strategy_name=self.node_data.agent_strategy_name,
|
||||
)
|
||||
except Exception as e:
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs={},
|
||||
error=f"Failed to get agent strategy: {str(e)}",
|
||||
),
|
||||
)
|
||||
return
|
||||
|
||||
agent_parameters = strategy.get_parameters()
|
||||
|
||||
# get parameters
|
||||
parameters = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
strategy=strategy,
|
||||
)
|
||||
parameters_for_log = self._generate_agent_parameters(
|
||||
agent_parameters=agent_parameters,
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
node_data=self.node_data,
|
||||
for_log=True,
|
||||
strategy=strategy,
|
||||
)
|
||||
credentials = self._generate_credentials(parameters=parameters)
|
||||
|
||||
# get conversation id
|
||||
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
|
||||
try:
|
||||
message_stream = strategy.invoke(
|
||||
params=parameters,
|
||||
user_id=dify_ctx.user_id,
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id.text if conversation_id else None,
|
||||
credentials=credentials,
|
||||
)
|
||||
except Exception as e:
|
||||
error = AgentInvocationError(f"Failed to invoke agent: {str(e)}", original_error=e)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(error),
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
# Fetch memory for node memory saving
|
||||
memory = self._fetch_memory_for_save()
|
||||
|
||||
try:
|
||||
yield from self._transform_message(
|
||||
messages=message_stream,
|
||||
tool_info={
|
||||
"icon": self.agent_strategy_icon,
|
||||
"agent_strategy": self.node_data.agent_strategy_name,
|
||||
},
|
||||
parameters_for_log=parameters_for_log,
|
||||
user_id=dify_ctx.user_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
node_type=self.node_type,
|
||||
node_id=self._node_id,
|
||||
node_execution_id=self.id,
|
||||
memory=memory,
|
||||
)
|
||||
except PluginDaemonClientSideError as e:
|
||||
transform_error = AgentMessageTransformError(
|
||||
f"Failed to transform agent message: {str(e)}", original_error=e
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
inputs=parameters_for_log,
|
||||
error=str(transform_error),
|
||||
)
|
||||
)
|
||||
|
||||
def _generate_agent_parameters(
|
||||
self,
|
||||
*,
|
||||
agent_parameters: Sequence[AgentStrategyParameter],
|
||||
variable_pool: VariablePool,
|
||||
node_data: AgentNodeData,
|
||||
for_log: bool = False,
|
||||
strategy: PluginAgentStrategy,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Generate parameters based on the given tool parameters, variable pool, and node data.
|
||||
|
||||
Args:
|
||||
agent_parameters (Sequence[AgentParameter]): The list of agent parameters.
|
||||
variable_pool (VariablePool): The variable pool containing the variables.
|
||||
node_data (AgentNodeData): The data associated with the agent node.
|
||||
|
||||
Returns:
|
||||
Mapping[str, Any]: A dictionary containing the generated parameters.
|
||||
|
||||
"""
|
||||
agent_parameters_dictionary = {parameter.name: parameter for parameter in agent_parameters}
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in node_data.agent_parameters:
|
||||
parameter = agent_parameters_dictionary.get(parameter_name)
|
||||
if not parameter:
|
||||
result[parameter_name] = None
|
||||
continue
|
||||
agent_input = node_data.agent_parameters[parameter_name]
|
||||
match agent_input.type:
|
||||
case "variable":
|
||||
variable = variable_pool.get(agent_input.value) # type: ignore
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(agent_input.value))
|
||||
parameter_value = variable.value
|
||||
case "mixed" | "constant":
|
||||
# variable_pool.convert_template expects a string template,
|
||||
# but if passing a dict, convert to JSON string first before rendering
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.dumps(agent_input.value, ensure_ascii=False)
|
||||
else:
|
||||
parameter_value = str(agent_input.value)
|
||||
except TypeError:
|
||||
parameter_value = str(agent_input.value)
|
||||
segment_group = variable_pool.convert_template(parameter_value)
|
||||
parameter_value = segment_group.log if for_log else segment_group.text
|
||||
# variable_pool.convert_template returns a string,
|
||||
# so we need to convert it back to a dictionary
|
||||
try:
|
||||
if not isinstance(agent_input.value, str):
|
||||
parameter_value = json.loads(parameter_value)
|
||||
except json.JSONDecodeError:
|
||||
parameter_value = parameter_value
|
||||
case _:
|
||||
raise AgentInputTypeError(agent_input.type)
|
||||
value = parameter_value
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
value = [tool for tool in value if tool.get("enabled", False)]
|
||||
value = self._filter_mcp_type_tool(strategy, value)
|
||||
for tool in value:
|
||||
if "schemas" in tool:
|
||||
tool.pop("schemas")
|
||||
parameters = tool.get("parameters", {})
|
||||
if all(isinstance(v, dict) for _, v in parameters.items()):
|
||||
params = {}
|
||||
for key, param in parameters.items():
|
||||
if param.get("auto", ParamsAutoGenerated.OPEN) in (
|
||||
ParamsAutoGenerated.CLOSE,
|
||||
0,
|
||||
):
|
||||
value_param = param.get("value", {})
|
||||
if value_param and value_param.get("type", "") == "variable":
|
||||
variable_selector = value_param.get("value")
|
||||
if not variable_selector:
|
||||
raise ValueError("Variable selector is missing for a variable-type parameter.")
|
||||
|
||||
variable = variable_pool.get(variable_selector)
|
||||
if variable is None:
|
||||
raise AgentVariableNotFoundError(str(variable_selector))
|
||||
|
||||
params[key] = variable.value
|
||||
else:
|
||||
params[key] = value_param.get("value", "") if value_param is not None else None
|
||||
else:
|
||||
params[key] = None
|
||||
parameters = params
|
||||
tool["settings"] = {k: v.get("value", None) for k, v in tool.get("settings", {}).items()}
|
||||
tool["parameters"] = parameters
|
||||
|
||||
if not for_log:
|
||||
if parameter.type == "array[tools]":
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
|
||||
parameters = {**parameters, **setting_params}
|
||||
entity = AgentToolEntity(
|
||||
provider_id=tool.get("provider_name", ""),
|
||||
provider_type=provider_type,
|
||||
tool_name=tool.get("tool_name", ""),
|
||||
tool_parameters=parameters,
|
||||
plugin_unique_identifier=tool.get("plugin_unique_identifier", None),
|
||||
credential_id=tool.get("credential_id", None),
|
||||
)
|
||||
|
||||
extra = tool.get("extra", {})
|
||||
|
||||
# This is an issue that caused problems before.
|
||||
# Logically, we shouldn't use the node_data.version field for judgment
|
||||
# But for backward compatibility with historical data
|
||||
# this version field judgment is still preserved here.
|
||||
runtime_variable_pool: VariablePool | None = None
|
||||
if node_data.version != "1" or node_data.tool_node_version is not None:
|
||||
runtime_variable_pool = variable_pool
|
||||
dify_ctx = self.require_dify_context()
|
||||
tool_runtime = ToolManager.get_agent_tool_runtime(
|
||||
dify_ctx.tenant_id,
|
||||
dify_ctx.app_id,
|
||||
entity,
|
||||
dify_ctx.user_id,
|
||||
dify_ctx.invoke_from,
|
||||
runtime_variable_pool,
|
||||
)
|
||||
if tool_runtime.entity.description:
|
||||
tool_runtime.entity.description.llm = (
|
||||
extra.get("description", "") or tool_runtime.entity.description.llm
|
||||
)
|
||||
for tool_runtime_params in tool_runtime.entity.parameters:
|
||||
tool_runtime_params.form = (
|
||||
ToolParameter.ToolParameterForm.FORM
|
||||
if tool_runtime_params.name in manual_input_params
|
||||
else tool_runtime_params.form
|
||||
)
|
||||
manual_input_value = {}
|
||||
if tool_runtime.entity.parameters:
|
||||
manual_input_value = {
|
||||
key: value for key, value in parameters.items() if key in manual_input_params
|
||||
}
|
||||
runtime_parameters = {
|
||||
**tool_runtime.runtime.runtime_parameters,
|
||||
**manual_input_value,
|
||||
}
|
||||
tool_value.append(
|
||||
{
|
||||
**tool_runtime.entity.model_dump(mode="json"),
|
||||
"runtime_parameters": runtime_parameters,
|
||||
"credential_id": tool.get("credential_id", None),
|
||||
"provider_type": provider_type.value,
|
||||
}
|
||||
)
|
||||
value = tool_value
|
||||
if parameter.type == AgentStrategyParameter.AgentStrategyParameterType.MODEL_SELECTOR:
|
||||
value = cast(dict[str, Any], value)
|
||||
model_instance, model_schema = self._fetch_model(value)
|
||||
# memory config
|
||||
history_prompt_messages = []
|
||||
if node_data.memory:
|
||||
memory = self._fetch_memory(model_instance)
|
||||
if memory:
|
||||
prompt_messages = memory.get_history_prompt_messages(
|
||||
message_limit=node_data.memory.window.size or None
|
||||
)
|
||||
history_prompt_messages = [
|
||||
prompt_message.model_dump(mode="json") for prompt_message in prompt_messages
|
||||
]
|
||||
value["history_prompt_messages"] = history_prompt_messages
|
||||
if model_schema:
|
||||
# remove structured output feature to support old version agent plugin
|
||||
model_schema = self._remove_unsupported_model_features_for_old_version(model_schema)
|
||||
value["entity"] = model_schema.model_dump(mode="json")
|
||||
else:
|
||||
value["entity"] = None
|
||||
result[parameter_name] = value
|
||||
|
||||
return result
|
||||
|
||||
def _generate_credentials(
|
||||
self,
|
||||
parameters: dict[str, Any],
|
||||
) -> InvokeCredentials:
|
||||
"""
|
||||
Generate credentials based on the given agent parameters.
|
||||
"""
|
||||
from core.plugin.entities.request import InvokeCredentials
|
||||
|
||||
credentials = InvokeCredentials()
|
||||
|
||||
# generate credentials for tools selector
|
||||
credentials.tool_credentials = {}
|
||||
for tool in parameters.get("tools", []):
|
||||
if tool.get("credential_id"):
|
||||
try:
|
||||
identity = ToolIdentity.model_validate(tool.get("identity", {}))
|
||||
credentials.tool_credentials[identity.provider] = tool.get("credential_id", None)
|
||||
except ValidationError:
|
||||
continue
|
||||
return credentials
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: AgentNodeData,
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
typed_node_data = node_data
|
||||
|
||||
result: dict[str, Any] = {}
|
||||
for parameter_name in typed_node_data.agent_parameters:
|
||||
input = typed_node_data.agent_parameters[parameter_name]
|
||||
match input.type:
|
||||
case "mixed" | "constant":
|
||||
selectors = VariableTemplateParser(str(input.value)).extract_variable_selectors()
|
||||
for selector in selectors:
|
||||
result[selector.variable] = selector.value_selector
|
||||
case "variable":
|
||||
result[parameter_name] = input.value
|
||||
|
||||
result = {node_id + "." + key: value for key, value in result.items()}
|
||||
|
||||
return result
|
||||
|
||||
@property
|
||||
def agent_strategy_icon(self) -> str | None:
|
||||
"""
|
||||
Get agent strategy icon
|
||||
:return:
|
||||
"""
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
manager = PluginInstaller()
|
||||
dify_ctx = self.require_dify_context()
|
||||
plugins = manager.list_plugins(dify_ctx.tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == self.node_data.agent_strategy_provider_name
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
icon = None
|
||||
return icon
|
||||
|
||||
def _fetch_memory(self, model_instance: ModelInstance) -> BaseMemory | TokenBufferMemory | None:
|
||||
"""
|
||||
Fetch memory based on configuration mode.
|
||||
|
||||
Returns TokenBufferMemory for conversation mode (default),
|
||||
or NodeTokenBufferMemory for node mode (Chatflow only).
|
||||
"""
|
||||
node_data = self.node_data
|
||||
memory_config = node_data.memory
|
||||
|
||||
if not memory_config:
|
||||
return None
|
||||
|
||||
# get conversation id (required for both modes in Chatflow)
|
||||
conversation_id_variable = self.graph_runtime_state.variable_pool.get(
|
||||
["sys", SystemVariableKey.CONVERSATION_ID]
|
||||
)
|
||||
if not isinstance(conversation_id_variable, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_variable.value
|
||||
|
||||
dify_ctx = self.require_dify_context()
|
||||
if memory_config.mode == MemoryMode.NODE:
|
||||
return NodeTokenBufferMemory(
|
||||
app_id=dify_ctx.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=self._node_id,
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
stmt = select(Conversation).where(
|
||||
Conversation.app_id == dify_ctx.app_id, Conversation.id == conversation_id
|
||||
)
|
||||
conversation = session.scalar(stmt)
|
||||
if not conversation:
|
||||
return None
|
||||
return TokenBufferMemory(conversation=conversation, model_instance=model_instance)
|
||||
|
||||
def _fetch_model(self, value: dict[str, Any]) -> tuple[ModelInstance, AIModelEntity | None]:
|
||||
dify_ctx = self.require_dify_context()
|
||||
provider_manager = create_plugin_provider_manager(
|
||||
tenant_id=dify_ctx.tenant_id, user_id=dify_ctx.user_id
|
||||
)
|
||||
provider_model_bundle = provider_manager.get_provider_model_bundle(
|
||||
tenant_id=dify_ctx.tenant_id, provider=value.get("provider", ""), model_type=ModelType.LLM
|
||||
)
|
||||
model_name = value.get("model", "")
|
||||
model_credentials = provider_model_bundle.configuration.get_current_credentials(
|
||||
model_type=ModelType.LLM, model=model_name
|
||||
)
|
||||
provider_name = provider_model_bundle.configuration.provider.provider
|
||||
model_type_instance = provider_model_bundle.model_type_instance
|
||||
model_instance = ModelManager(provider_manager).get_model_instance(
|
||||
tenant_id=dify_ctx.tenant_id,
|
||||
provider=provider_name,
|
||||
model_type=ModelType(value.get("model_type", "")),
|
||||
model=model_name,
|
||||
)
|
||||
model_schema = model_type_instance.get_model_schema(model_name, model_credentials)
|
||||
return model_instance, model_schema
|
||||
|
||||
def _remove_unsupported_model_features_for_old_version(self, model_schema: AIModelEntity) -> AIModelEntity:
|
||||
if model_schema.features:
|
||||
for feature in model_schema.features[:]: # Create a copy to safely modify during iteration
|
||||
try:
|
||||
AgentOldVersionModelFeatures(feature.value) # Try to create enum member from value
|
||||
except ValueError:
|
||||
model_schema.features.remove(feature)
|
||||
return model_schema
|
||||
|
||||
def _filter_mcp_type_tool(self, strategy: PluginAgentStrategy, tools: list[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Filter MCP type tool
|
||||
:param strategy: plugin agent strategy
|
||||
:param tool: tool
|
||||
:return: filtered tool dict
|
||||
"""
|
||||
meta_version = strategy.meta_version
|
||||
if meta_version and Version(meta_version) > Version("0.0.1"):
|
||||
return tools
|
||||
else:
|
||||
return [tool for tool in tools if tool.get("type") != ToolProviderType.MCP]
|
||||
|
||||
def _fetch_memory_for_save(self) -> BaseMemory | None:
|
||||
"""
|
||||
Fetch memory instance for saving node memory.
|
||||
This is a simplified version that doesn't require model_instance.
|
||||
"""
|
||||
from graphon.model_runtime.entities.model_entities import ModelType
|
||||
|
||||
node_data = self.node_data
|
||||
if not node_data.memory:
|
||||
return None
|
||||
|
||||
# Get conversation_id
|
||||
conversation_id_var = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
|
||||
if not isinstance(conversation_id_var, StringSegment):
|
||||
return None
|
||||
conversation_id = conversation_id_var.value
|
||||
|
||||
# Return appropriate memory type based on mode
|
||||
if node_data.memory.mode == MemoryMode.NODE:
|
||||
try:
|
||||
provider_manager = create_plugin_provider_manager(tenant_id=self.tenant_id)
|
||||
model_instance = ModelManager(provider_manager).get_default_model_instance(
|
||||
tenant_id=self.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
return NodeTokenBufferMemory(
|
||||
app_id=self.app_id,
|
||||
conversation_id=conversation_id,
|
||||
node_id=self._node_id,
|
||||
tenant_id=self.tenant_id,
|
||||
model_instance=model_instance,
|
||||
)
|
||||
else:
|
||||
# Conversation-level memory doesn't need saving here
|
||||
return None
|
||||
|
||||
def _build_context(
|
||||
self,
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_query: str,
|
||||
assistant_response: str,
|
||||
agent_logs: list[AgentLogEvent],
|
||||
) -> list[PromptMessage]:
|
||||
"""
|
||||
Build context from user query, tool calls, and assistant response.
|
||||
Format: user -> assistant(with tool_calls) -> tool -> assistant
|
||||
|
||||
The context includes:
|
||||
- Current user query (always present, may be empty)
|
||||
- Assistant message with tool_calls (if tools were called)
|
||||
- Tool results
|
||||
- Assistant's final response
|
||||
"""
|
||||
context_messages: list[PromptMessage] = []
|
||||
|
||||
# Always add user query (even if empty, to maintain conversation structure)
|
||||
context_messages.append(UserPromptMessage(content=user_query or ""))
|
||||
|
||||
# Extract actual tool calls from agent logs
|
||||
# Only include logs with label starting with "CALL " - these are real tool invocations
|
||||
tool_calls: list[AssistantPromptMessage.ToolCall] = []
|
||||
tool_results: list[tuple[str, str, str]] = [] # (tool_call_id, tool_name, result)
|
||||
|
||||
for log in agent_logs:
|
||||
if log.status == "success" and log.label and log.label.startswith("CALL "):
|
||||
# Extract tool name from label (format: "CALL tool_name")
|
||||
tool_name = log.label[5:] # Remove "CALL " prefix
|
||||
tool_call_id = log.message_id
|
||||
|
||||
# Parse tool response from data
|
||||
data = log.data or {}
|
||||
tool_response = ""
|
||||
|
||||
# Try to extract the actual tool response
|
||||
if "tool_response" in data:
|
||||
tool_response = data["tool_response"]
|
||||
elif "output" in data:
|
||||
tool_response = data["output"]
|
||||
elif "result" in data:
|
||||
tool_response = data["result"]
|
||||
|
||||
if isinstance(tool_response, dict):
|
||||
tool_response = str(tool_response)
|
||||
|
||||
# Get tool input for arguments
|
||||
tool_input = data.get("tool_call_input", {}) or data.get("input", {})
|
||||
if isinstance(tool_input, dict):
|
||||
import json
|
||||
|
||||
tool_input_str = json.dumps(tool_input, ensure_ascii=False)
|
||||
else:
|
||||
tool_input_str = str(tool_input) if tool_input else ""
|
||||
|
||||
if tool_response:
|
||||
tool_calls.append(
|
||||
AssistantPromptMessage.ToolCall(
|
||||
id=tool_call_id,
|
||||
type="function",
|
||||
function=AssistantPromptMessage.ToolCall.ToolCallFunction(
|
||||
name=tool_name,
|
||||
arguments=tool_input_str,
|
||||
),
|
||||
)
|
||||
)
|
||||
tool_results.append((tool_call_id, tool_name, str(tool_response)))
|
||||
|
||||
# Add assistant message with tool_calls if there were tool calls
|
||||
if tool_calls:
|
||||
context_messages.append(AssistantPromptMessage(content="", tool_calls=tool_calls))
|
||||
|
||||
# Add tool result messages
|
||||
for tool_call_id, tool_name, result in tool_results:
|
||||
context_messages.append(
|
||||
ToolPromptMessage(
|
||||
content=result,
|
||||
tool_call_id=tool_call_id,
|
||||
name=tool_name,
|
||||
)
|
||||
)
|
||||
|
||||
# Add final assistant response
|
||||
context_messages.append(AssistantPromptMessage(content=assistant_response))
|
||||
|
||||
return context_messages
|
||||
|
||||
def _transform_message(
|
||||
self,
|
||||
messages: Generator[ToolInvokeMessage, None, None],
|
||||
tool_info: Mapping[str, Any],
|
||||
parameters_for_log: dict[str, Any],
|
||||
user_id: str,
|
||||
tenant_id: str,
|
||||
node_type: NodeType,
|
||||
node_id: str,
|
||||
node_execution_id: str,
|
||||
memory: BaseMemory | None = None,
|
||||
) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Convert ToolInvokeMessages into tuple[plain_text, files]
|
||||
"""
|
||||
# transform message and handle file storage
|
||||
from core.plugin.impl.plugin import PluginInstaller
|
||||
|
||||
message_stream = ToolFileMessageTransformer.transform_tool_invoke_messages(
|
||||
messages=messages,
|
||||
user_id=user_id,
|
||||
tenant_id=tenant_id,
|
||||
conversation_id=None,
|
||||
)
|
||||
|
||||
text = ""
|
||||
files: list[File] = []
|
||||
json_list: list[dict | list] = []
|
||||
|
||||
agent_logs: list[AgentLogEvent] = []
|
||||
agent_execution_metadata: Mapping[WorkflowNodeExecutionMetadataKey, Any] = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
variables: dict[str, Any] = {}
|
||||
|
||||
for message in message_stream:
|
||||
if message.type in {
|
||||
ToolInvokeMessage.MessageType.IMAGE_LINK,
|
||||
ToolInvokeMessage.MessageType.BINARY_LINK,
|
||||
ToolInvokeMessage.MessageType.IMAGE,
|
||||
}:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
|
||||
url = message.message.text
|
||||
if message.meta:
|
||||
transfer_method = message.meta.get("transfer_method", FileTransferMethod.TOOL_FILE)
|
||||
else:
|
||||
transfer_method = FileTransferMethod.TOOL_FILE
|
||||
|
||||
tool_file_id = str(url).split("/")[-1].split(".")[0]
|
||||
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"type": file_factory.get_file_type_by_mime_type(tool_file.mimetype),
|
||||
"transfer_method": transfer_method,
|
||||
"url": url,
|
||||
}
|
||||
file = file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
files.append(file)
|
||||
elif message.type == ToolInvokeMessage.MessageType.BLOB:
|
||||
# get tool file id
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
assert message.meta
|
||||
|
||||
tool_file_id = message.message.text.split("/")[-1].split(".")[0]
|
||||
with Session(db.engine) as session:
|
||||
stmt = select(ToolFile).where(ToolFile.id == tool_file_id)
|
||||
tool_file = session.scalar(stmt)
|
||||
if tool_file is None:
|
||||
raise ToolFileNotFoundError(tool_file_id)
|
||||
|
||||
mapping = {
|
||||
"tool_file_id": tool_file_id,
|
||||
"transfer_method": FileTransferMethod.TOOL_FILE,
|
||||
}
|
||||
|
||||
files.append(
|
||||
file_factory.build_from_mapping(
|
||||
mapping=mapping,
|
||||
tenant_id=tenant_id,
|
||||
access_controller=_file_access_controller,
|
||||
)
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.TEXT:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
text += message.message.text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=message.message.text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.JSON:
|
||||
assert isinstance(message.message, ToolInvokeMessage.JsonMessage)
|
||||
if node_type == BuiltinNodeTypes.AGENT:
|
||||
if isinstance(message.message.json_object, dict):
|
||||
msg_metadata: dict[str, Any] = message.message.json_object.pop("execution_metadata", {})
|
||||
llm_usage = LLMUsage.from_metadata(cast(LLMUsageMetadata, msg_metadata))
|
||||
agent_execution_metadata = {
|
||||
WorkflowNodeExecutionMetadataKey(key): value
|
||||
for key, value in msg_metadata.items()
|
||||
if key in WorkflowNodeExecutionMetadataKey.__members__.values()
|
||||
}
|
||||
else:
|
||||
msg_metadata = {}
|
||||
llm_usage = LLMUsage.empty_usage()
|
||||
agent_execution_metadata = {}
|
||||
if message.message.json_object:
|
||||
json_list.append(message.message.json_object)
|
||||
elif message.type == ToolInvokeMessage.MessageType.LINK:
|
||||
assert isinstance(message.message, ToolInvokeMessage.TextMessage)
|
||||
stream_text = f"Link: {message.message.text}\n"
|
||||
text += stream_text
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk=stream_text,
|
||||
is_final=False,
|
||||
)
|
||||
elif message.type == ToolInvokeMessage.MessageType.VARIABLE:
|
||||
assert isinstance(message.message, ToolInvokeMessage.VariableMessage)
|
||||
variable_name = message.message.variable_name
|
||||
variable_value = message.message.variable_value
|
||||
if message.message.stream:
|
||||
if not isinstance(variable_value, str):
|
||||
raise AgentVariableTypeError(
|
||||
"When 'stream' is True, 'variable_value' must be a string.",
|
||||
variable_name=variable_name,
|
||||
expected_type="str",
|
||||
actual_type=type(variable_value).__name__,
|
||||
)
|
||||
if variable_name not in variables:
|
||||
variables[variable_name] = ""
|
||||
variables[variable_name] += variable_value
|
||||
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, variable_name],
|
||||
chunk=variable_value,
|
||||
is_final=False,
|
||||
)
|
||||
else:
|
||||
variables[variable_name] = variable_value
|
||||
elif message.type == ToolInvokeMessage.MessageType.FILE:
|
||||
assert message.meta is not None
|
||||
assert isinstance(message.meta, dict)
|
||||
# Validate that meta contains a 'file' key
|
||||
if "file" not in message.meta:
|
||||
raise AgentNodeError("File message is missing 'file' key in meta")
|
||||
|
||||
# Validate that the file is an instance of File
|
||||
if not isinstance(message.meta["file"], File):
|
||||
raise AgentNodeError(f"Expected File object but got {type(message.meta['file']).__name__}")
|
||||
files.append(message.meta["file"])
|
||||
elif message.type == ToolInvokeMessage.MessageType.LOG:
|
||||
assert isinstance(message.message, ToolInvokeMessage.LogMessage)
|
||||
if message.message.metadata:
|
||||
icon = tool_info.get("icon", "")
|
||||
dict_metadata = dict(message.message.metadata)
|
||||
if dict_metadata.get("provider"):
|
||||
manager = PluginInstaller()
|
||||
plugins = manager.list_plugins(tenant_id)
|
||||
try:
|
||||
current_plugin = next(
|
||||
plugin
|
||||
for plugin in plugins
|
||||
if f"{plugin.plugin_id}/{plugin.name}" == dict_metadata["provider"]
|
||||
)
|
||||
icon = current_plugin.declaration.icon
|
||||
except StopIteration:
|
||||
pass
|
||||
icon_dark = None
|
||||
try:
|
||||
builtin_tool = next(
|
||||
provider
|
||||
for provider in BuiltinToolManageService.list_builtin_tools(
|
||||
user_id,
|
||||
tenant_id,
|
||||
)
|
||||
if provider.name == dict_metadata["provider"]
|
||||
)
|
||||
icon = builtin_tool.icon
|
||||
icon_dark = builtin_tool.icon_dark
|
||||
except StopIteration:
|
||||
pass
|
||||
|
||||
dict_metadata["icon"] = icon
|
||||
dict_metadata["icon_dark"] = icon_dark
|
||||
message.message.metadata = dict_metadata
|
||||
agent_log = AgentLogEvent(
|
||||
message_id=message.message.id,
|
||||
node_execution_id=node_execution_id,
|
||||
parent_id=message.message.parent_id,
|
||||
error=message.message.error,
|
||||
status=message.message.status.value,
|
||||
data=message.message.data,
|
||||
label=message.message.label,
|
||||
metadata=message.message.metadata,
|
||||
node_id=node_id,
|
||||
)
|
||||
|
||||
# check if the agent log is already in the list
|
||||
for log in agent_logs:
|
||||
if log.message_id == agent_log.message_id:
|
||||
# update the log
|
||||
log.data = agent_log.data
|
||||
log.status = agent_log.status
|
||||
log.error = agent_log.error
|
||||
log.label = agent_log.label
|
||||
log.metadata = agent_log.metadata
|
||||
break
|
||||
else:
|
||||
agent_logs.append(agent_log)
|
||||
|
||||
yield agent_log
|
||||
|
||||
# Add agent_logs to outputs['json'] to ensure frontend can access thinking process
|
||||
json_output: list[dict[str, Any] | list[Any]] = []
|
||||
|
||||
# Step 1: append each agent log as its own dict.
|
||||
if agent_logs:
|
||||
for log in agent_logs:
|
||||
json_output.append(
|
||||
{
|
||||
"id": log.message_id,
|
||||
"parent_id": log.parent_id,
|
||||
"error": log.error,
|
||||
"status": log.status,
|
||||
"data": log.data,
|
||||
"label": log.label,
|
||||
"metadata": log.metadata,
|
||||
"node_id": log.node_id,
|
||||
}
|
||||
)
|
||||
# Step 2: normalize JSON into {"data": [...]}.change json to list[dict]
|
||||
if json_list:
|
||||
json_output.extend(json_list)
|
||||
else:
|
||||
json_output.append({"data": []})
|
||||
|
||||
# Send final chunk events for all streamed outputs
|
||||
# Final chunk for text stream
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, "text"],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Final chunks for any streamed variables
|
||||
for var_name in variables:
|
||||
yield StreamChunkEvent(
|
||||
selector=[node_id, var_name],
|
||||
chunk="",
|
||||
is_final=True,
|
||||
)
|
||||
|
||||
# Get user query from parameters for building context
|
||||
user_query = parameters_for_log.get("query", "")
|
||||
|
||||
# Build context from history, user query, tool calls and assistant response
|
||||
context = self._build_context(parameters_for_log, user_query, text, agent_logs)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={
|
||||
"text": text,
|
||||
"usage": jsonable_encoder(llm_usage),
|
||||
"files": ArrayFileSegment(value=files),
|
||||
"json": json_output,
|
||||
"context": context,
|
||||
**variables,
|
||||
},
|
||||
metadata={
|
||||
**agent_execution_metadata,
|
||||
WorkflowNodeExecutionMetadataKey.TOOL_INFO: tool_info,
|
||||
WorkflowNodeExecutionMetadataKey.AGENT_LOG: agent_logs,
|
||||
},
|
||||
inputs=parameters_for_log,
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
5
api/extensions/ext_socketio.py
Normal file
5
api/extensions/ext_socketio.py
Normal file
@@ -0,0 +1,5 @@
|
||||
import socketio # type: ignore[reportMissingTypeStubs]
|
||||
|
||||
from configs import dify_config
|
||||
|
||||
sio = socketio.Server(async_mode="gevent", cors_allowed_origins=dify_config.CONSOLE_CORS_ALLOW_ORIGINS)
|
||||
73
api/extensions/storage/file_presign_storage.py
Normal file
73
api/extensions/storage/file_presign_storage.py
Normal file
@@ -0,0 +1,73 @@
|
||||
"""Storage wrapper that provides presigned URL support with fallback to ticket-based URLs.
|
||||
|
||||
This is the unified presign wrapper for all storage operations. When the underlying
|
||||
storage backend doesn't support presigned URLs (raises NotImplementedError), it falls
|
||||
back to generating ticket-based URLs that route through Dify's file proxy endpoints.
|
||||
|
||||
Usage:
|
||||
from extensions.storage.file_presign_storage import FilePresignStorage
|
||||
|
||||
# Wrap any BaseStorage to add presign support
|
||||
presign_storage = FilePresignStorage(base_storage)
|
||||
download_url = presign_storage.get_download_url("path/to/file.txt", expires_in=3600)
|
||||
upload_url = presign_storage.get_upload_url("path/to/file.txt", expires_in=3600)
|
||||
|
||||
When the underlying storage doesn't support presigned URLs, the fallback URLs follow the format:
|
||||
{FILES_API_URL}/files/storage-files/{token} (falls back to FILES_URL)
|
||||
|
||||
The token is a UUID that maps to the real storage key in Redis.
|
||||
"""
|
||||
|
||||
from extensions.storage.storage_wrapper import StorageWrapper
|
||||
|
||||
|
||||
class FilePresignStorage(StorageWrapper):
|
||||
"""Storage wrapper that provides presigned URL support with ticket fallback.
|
||||
|
||||
If the wrapped storage supports presigned URLs, delegates to it.
|
||||
Otherwise, generates ticket-based URLs for both download and upload operations.
|
||||
"""
|
||||
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
"""Get a presigned download URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return StorageTicketService.create_download_url(filename, expires_in=expires_in, filename=download_filename)
|
||||
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
"""Get presigned download URLs for multiple files."""
|
||||
try:
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
if download_filenames is None:
|
||||
return [StorageTicketService.create_download_url(f, expires_in=expires_in) for f in filenames]
|
||||
return [
|
||||
StorageTicketService.create_download_url(f, expires_in=expires_in, filename=df)
|
||||
for f, df in zip(filenames, download_filenames, strict=True)
|
||||
]
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
"""Get a presigned upload URL, falling back to ticket URL if not supported."""
|
||||
try:
|
||||
return self._storage.get_upload_url(filename, expires_in)
|
||||
except NotImplementedError:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
return StorageTicketService.create_upload_url(filename, expires_in=expires_in)
|
||||
66
api/extensions/storage/storage_wrapper.py
Normal file
66
api/extensions/storage/storage_wrapper.py
Normal file
@@ -0,0 +1,66 @@
|
||||
"""Base class for storage wrappers that delegate to an inner storage."""
|
||||
|
||||
from collections.abc import Generator
|
||||
|
||||
from extensions.storage.base_storage import BaseStorage
|
||||
|
||||
|
||||
class StorageWrapper(BaseStorage):
|
||||
"""Base class for storage wrappers using the decorator pattern.
|
||||
|
||||
Forwards all BaseStorage methods to the wrapped storage by default.
|
||||
Subclasses can override specific methods to customize behavior.
|
||||
|
||||
Example:
|
||||
class MyCustomStorage(StorageWrapper):
|
||||
def save(self, filename: str, data: bytes):
|
||||
# Custom logic before save
|
||||
super().save(filename, data)
|
||||
# Custom logic after save
|
||||
"""
|
||||
|
||||
def __init__(self, storage: BaseStorage):
|
||||
super().__init__()
|
||||
self._storage = storage
|
||||
|
||||
def save(self, filename: str, data: bytes):
|
||||
self._storage.save(filename, data)
|
||||
|
||||
def load_once(self, filename: str) -> bytes:
|
||||
return self._storage.load_once(filename)
|
||||
|
||||
def load_stream(self, filename: str) -> Generator:
|
||||
return self._storage.load_stream(filename)
|
||||
|
||||
def download(self, filename: str, target_filepath: str):
|
||||
self._storage.download(filename, target_filepath)
|
||||
|
||||
def exists(self, filename: str) -> bool:
|
||||
return self._storage.exists(filename)
|
||||
|
||||
def delete(self, filename: str):
|
||||
self._storage.delete(filename)
|
||||
|
||||
def scan(self, path: str, files: bool = True, directories: bool = False) -> list[str]:
|
||||
return self._storage.scan(path, files=files, directories=directories)
|
||||
|
||||
def get_download_url(
|
||||
self,
|
||||
filename: str,
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filename: str | None = None,
|
||||
) -> str:
|
||||
return self._storage.get_download_url(filename, expires_in, download_filename=download_filename)
|
||||
|
||||
def get_download_urls(
|
||||
self,
|
||||
filenames: list[str],
|
||||
expires_in: int = 3600,
|
||||
*,
|
||||
download_filenames: list[str] | None = None,
|
||||
) -> list[str]:
|
||||
return self._storage.get_download_urls(filenames, expires_in, download_filenames=download_filenames)
|
||||
|
||||
def get_upload_url(self, filename: str, expires_in: int = 3600) -> str:
|
||||
return self._storage.get_upload_url(filename, expires_in)
|
||||
17
api/fields/online_user_fields.py
Normal file
17
api/fields/online_user_fields.py
Normal file
@@ -0,0 +1,17 @@
|
||||
from flask_restx import fields
|
||||
|
||||
online_user_partial_fields = {
|
||||
"user_id": fields.String,
|
||||
"username": fields.String,
|
||||
"avatar": fields.String,
|
||||
"sid": fields.String,
|
||||
}
|
||||
|
||||
workflow_online_users_fields = {
|
||||
"workflow_id": fields.String,
|
||||
"users": fields.List(fields.Nested(online_user_partial_fields)),
|
||||
}
|
||||
|
||||
online_user_list_fields = {
|
||||
"data": fields.List(fields.Nested(workflow_online_users_fields)),
|
||||
}
|
||||
96
api/fields/workflow_comment_fields.py
Normal file
96
api/fields/workflow_comment_fields.py
Normal file
@@ -0,0 +1,96 @@
|
||||
from flask_restx import fields
|
||||
|
||||
from libs.helper import AvatarUrlField, TimestampField
|
||||
|
||||
# basic account fields for comments
|
||||
account_fields = {
|
||||
"id": fields.String,
|
||||
"name": fields.String,
|
||||
"email": fields.String,
|
||||
"avatar_url": AvatarUrlField,
|
||||
}
|
||||
|
||||
# Comment mention fields
|
||||
workflow_comment_mention_fields = {
|
||||
"mentioned_user_id": fields.String,
|
||||
"mentioned_user_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_id": fields.String,
|
||||
}
|
||||
|
||||
# Comment reply fields
|
||||
workflow_comment_reply_fields = {
|
||||
"id": fields.String,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Basic comment fields (for list views)
|
||||
workflow_comment_basic_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"reply_count": fields.Integer,
|
||||
"mention_count": fields.Integer,
|
||||
"participants": fields.List(fields.Nested(account_fields)),
|
||||
}
|
||||
|
||||
# Detailed comment fields (for single comment view)
|
||||
workflow_comment_detail_fields = {
|
||||
"id": fields.String,
|
||||
"position_x": fields.Float,
|
||||
"position_y": fields.Float,
|
||||
"content": fields.String,
|
||||
"created_by": fields.String,
|
||||
"created_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"created_at": TimestampField,
|
||||
"updated_at": TimestampField,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
"resolved_by_account": fields.Nested(account_fields, allow_null=True),
|
||||
"replies": fields.List(fields.Nested(workflow_comment_reply_fields)),
|
||||
"mentions": fields.List(fields.Nested(workflow_comment_mention_fields)),
|
||||
}
|
||||
|
||||
# Comment creation response fields (simplified)
|
||||
workflow_comment_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment update response fields (simplified)
|
||||
workflow_comment_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
|
||||
# Comment resolve response fields
|
||||
workflow_comment_resolve_fields = {
|
||||
"id": fields.String,
|
||||
"resolved": fields.Boolean,
|
||||
"resolved_at": TimestampField,
|
||||
"resolved_by": fields.String,
|
||||
}
|
||||
|
||||
# Reply creation response fields (simplified)
|
||||
workflow_comment_reply_create_fields = {
|
||||
"id": fields.String,
|
||||
"created_at": TimestampField,
|
||||
}
|
||||
|
||||
# Reply update response fields
|
||||
workflow_comment_reply_update_fields = {
|
||||
"id": fields.String,
|
||||
"updated_at": TimestampField,
|
||||
}
|
||||
163
api/libs/attr_map.py
Normal file
163
api/libs/attr_map.py
Normal file
@@ -0,0 +1,163 @@
|
||||
"""
|
||||
Type-safe attribute storage inspired by Netty's AttributeKey/AttributeMap pattern.
|
||||
|
||||
Provides loosely-coupled typed attribute storage where only code with access
|
||||
to the same AttrKey instance can read/write the corresponding attribute.
|
||||
|
||||
SESSION_KEY: AttrKey[Session] = AttrKey("session", Session)
|
||||
attrs = AttrMap()
|
||||
attrs.set(SESSION_KEY, session)
|
||||
session = attrs.get(SESSION_KEY) # -> Session (raises if not set)
|
||||
session = attrs.get_or_none(SESSION_KEY) # -> Session | None
|
||||
|
||||
Note: AttrMap is NOT thread-safe. Each instance should be confined to a single
|
||||
thread/context (e.g., one AttrMap per Sandbox/VirtualEnvironment instance).
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Generic, TypeVar, cast, final, overload
|
||||
|
||||
T = TypeVar("T")
|
||||
D = TypeVar("D")
|
||||
|
||||
|
||||
@final
|
||||
class AttrKey(Generic[T]):
|
||||
"""
|
||||
A type-safe key for attribute storage.
|
||||
|
||||
Identity-based: different AttrKey instances with same name are distinct keys.
|
||||
This enables different modules to define keys independently without collision.
|
||||
"""
|
||||
|
||||
__slots__ = ("_name", "_type")
|
||||
|
||||
def __init__(self, name: str, type_: type[T]) -> None:
|
||||
self._name = name
|
||||
self._type = type_
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def type_(self) -> type[T]:
|
||||
return self._type
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"AttrKey({self._name!r}, {self._type.__name__})"
|
||||
|
||||
def __hash__(self) -> int:
|
||||
return id(self)
|
||||
|
||||
def __eq__(self, other: object) -> bool:
|
||||
return self is other
|
||||
|
||||
|
||||
class AttrMapKeyError(KeyError):
|
||||
"""Raised when a required attribute is not set."""
|
||||
|
||||
key: AttrKey[Any]
|
||||
|
||||
def __init__(self, key: AttrKey[Any]) -> None:
|
||||
self.key = key
|
||||
super().__init__(f"Required attribute '{key.name}' (type: {key.type_.__name__}) is not set")
|
||||
|
||||
|
||||
class AttrMapTypeError(TypeError):
|
||||
"""Raised when attribute value type doesn't match the key's declared type."""
|
||||
|
||||
key: AttrKey[Any]
|
||||
expected_type: type[Any]
|
||||
actual_type: type[Any]
|
||||
|
||||
def __init__(self, key: AttrKey[Any], expected_type: type[Any], actual_type: type[Any]) -> None:
|
||||
self.key = key
|
||||
self.expected_type = expected_type
|
||||
self.actual_type = actual_type
|
||||
super().__init__(
|
||||
f"Attribute '{key.name}' expects type '{expected_type.__name__}', got '{actual_type.__name__}'"
|
||||
)
|
||||
|
||||
|
||||
@final
|
||||
class AttrMap:
|
||||
"""
|
||||
Thread-confined container for storing typed attributes using AttrKey instances.
|
||||
|
||||
NOT thread-safe. Each instance should be owned by a single context
|
||||
(e.g., one AttrMap per Sandbox/VirtualEnvironment instance).
|
||||
"""
|
||||
|
||||
__slots__ = ("_data",)
|
||||
|
||||
def __init__(self) -> None:
|
||||
self._data: dict[AttrKey[Any], Any] = {}
|
||||
|
||||
def set(self, key: AttrKey[T], value: T, *, validate: bool = True) -> None:
|
||||
"""
|
||||
Store an attribute. Raises AttrMapTypeError if validate=True and type mismatches.
|
||||
|
||||
Note: Runtime validation only checks outer type (e.g., `list` not `list[str]`).
|
||||
"""
|
||||
if validate and not isinstance(value, key.type_):
|
||||
raise AttrMapTypeError(key, key.type_, type(value))
|
||||
self._data[key] = value
|
||||
|
||||
def get(self, key: AttrKey[T]) -> T:
|
||||
"""Retrieve an attribute. Raises AttrMapKeyError if not set."""
|
||||
if key not in self._data:
|
||||
raise AttrMapKeyError(key)
|
||||
return cast(T, self._data[key])
|
||||
|
||||
def get_or_none(self, key: AttrKey[T]) -> T | None:
|
||||
"""Retrieve an attribute, returning None if not set."""
|
||||
return cast(T | None, self._data.get(key))
|
||||
|
||||
@overload
|
||||
def get_or_default(self, key: AttrKey[T], default: T) -> T: ...
|
||||
|
||||
@overload
|
||||
def get_or_default(self, key: AttrKey[T], default: D) -> T | D: ...
|
||||
|
||||
def get_or_default(self, key: AttrKey[T], default: T | D) -> T | D:
|
||||
"""Retrieve an attribute, returning default if not set."""
|
||||
if key in self._data:
|
||||
return cast(T, self._data[key])
|
||||
return default
|
||||
|
||||
def has(self, key: AttrKey[Any]) -> bool:
|
||||
"""Check if an attribute is set."""
|
||||
return key in self._data
|
||||
|
||||
def remove(self, key: AttrKey[Any]) -> bool:
|
||||
"""Remove an attribute. Returns True if it was present."""
|
||||
if key in self._data:
|
||||
del self._data[key]
|
||||
return True
|
||||
return False
|
||||
|
||||
def set_if_absent(self, key: AttrKey[T], value: T, *, validate: bool = True) -> T:
|
||||
"""
|
||||
Set attribute only if not already set. Returns existing or newly set value.
|
||||
|
||||
Raises AttrMapTypeError if validate=True and type mismatches.
|
||||
"""
|
||||
if key in self._data:
|
||||
return cast(T, self._data[key])
|
||||
if validate and not isinstance(value, key.type_):
|
||||
raise AttrMapTypeError(key, key.type_, type(value))
|
||||
self._data[key] = value
|
||||
return value
|
||||
|
||||
def clear(self) -> None:
|
||||
"""Remove all attributes."""
|
||||
self._data.clear()
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._data)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
keys = [k.name for k in self._data]
|
||||
return f"AttrMap({keys})"
|
||||
@@ -0,0 +1,109 @@
|
||||
"""Add workflow comments table
|
||||
|
||||
Revision ID: 227822d22895
|
||||
Revises: 6b5f9f8b1a2c
|
||||
Create Date: 2026-02-09 17:26:15.255980
|
||||
|
||||
"""
|
||||
|
||||
from alembic import op
|
||||
import models as models
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = "227822d22895"
|
||||
down_revision = "6b5f9f8b1a2c"
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"workflow_comments",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("tenant_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("app_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("position_x", sa.Float(), nullable=False),
|
||||
sa.Column("position_y", sa.Float(), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("resolved", sa.Boolean(), server_default=sa.text("false"), nullable=False),
|
||||
sa.Column("resolved_at", sa.DateTime(), nullable=True),
|
||||
sa.Column("resolved_by", models.types.StringUUID(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
|
||||
batch_op.create_index("workflow_comments_app_idx", ["tenant_id", "app_id"], unique=False)
|
||||
batch_op.create_index("workflow_comments_created_at_idx", ["created_at"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"workflow_comment_replies",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("comment_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("content", sa.Text(), nullable=False),
|
||||
sa.Column("created_by", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.Column("updated_at", sa.DateTime(), server_default=sa.text("CURRENT_TIMESTAMP"), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["comment_id"],
|
||||
["workflow_comments.id"],
|
||||
name=op.f("workflow_comment_replies_comment_id_fkey"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
|
||||
batch_op.create_index("comment_replies_comment_idx", ["comment_id"], unique=False)
|
||||
batch_op.create_index("comment_replies_created_at_idx", ["created_at"], unique=False)
|
||||
|
||||
op.create_table(
|
||||
"workflow_comment_mentions",
|
||||
sa.Column("id", models.types.StringUUID(), server_default=sa.text("uuidv7()"), nullable=False),
|
||||
sa.Column("comment_id", models.types.StringUUID(), nullable=False),
|
||||
sa.Column("reply_id", models.types.StringUUID(), nullable=True),
|
||||
sa.Column("mentioned_user_id", models.types.StringUUID(), nullable=False),
|
||||
sa.ForeignKeyConstraint(
|
||||
["comment_id"],
|
||||
["workflow_comments.id"],
|
||||
name=op.f("workflow_comment_mentions_comment_id_fkey"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.ForeignKeyConstraint(
|
||||
["reply_id"],
|
||||
["workflow_comment_replies.id"],
|
||||
name=op.f("workflow_comment_mentions_reply_id_fkey"),
|
||||
ondelete="CASCADE",
|
||||
),
|
||||
sa.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
)
|
||||
with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
|
||||
batch_op.create_index("comment_mentions_comment_idx", ["comment_id"], unique=False)
|
||||
batch_op.create_index("comment_mentions_reply_idx", ["reply_id"], unique=False)
|
||||
batch_op.create_index("comment_mentions_user_idx", ["mentioned_user_id"], unique=False)
|
||||
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
with op.batch_alter_table("workflow_comment_mentions", schema=None) as batch_op:
|
||||
batch_op.drop_index("comment_mentions_user_idx")
|
||||
batch_op.drop_index("comment_mentions_reply_idx")
|
||||
batch_op.drop_index("comment_mentions_comment_idx")
|
||||
|
||||
op.drop_table("workflow_comment_mentions")
|
||||
with op.batch_alter_table("workflow_comment_replies", schema=None) as batch_op:
|
||||
batch_op.drop_index("comment_replies_created_at_idx")
|
||||
batch_op.drop_index("comment_replies_comment_idx")
|
||||
|
||||
op.drop_table("workflow_comment_replies")
|
||||
with op.batch_alter_table("workflow_comments", schema=None) as batch_op:
|
||||
batch_op.drop_index("workflow_comments_created_at_idx")
|
||||
batch_op.drop_index("workflow_comments_app_idx")
|
||||
|
||||
op.drop_table("workflow_comments")
|
||||
# ### end Alembic commands ###
|
||||
@@ -98,6 +98,7 @@ from .trigger import (
|
||||
TriggerSubscription,
|
||||
WorkflowSchedulePlan,
|
||||
)
|
||||
from .comment import WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
|
||||
from .web import PinnedConversation, SavedMessage
|
||||
from .workflow import (
|
||||
ConversationVariable,
|
||||
@@ -205,6 +206,9 @@ __all__ = [
|
||||
"UploadFile",
|
||||
"Whitelist",
|
||||
"Workflow",
|
||||
"WorkflowComment",
|
||||
"WorkflowCommentMention",
|
||||
"WorkflowCommentReply",
|
||||
"WorkflowAppLog",
|
||||
"WorkflowAppLogCreatedFrom",
|
||||
"WorkflowArchiveLog",
|
||||
|
||||
210
api/models/comment.py
Normal file
210
api/models/comment.py
Normal file
@@ -0,0 +1,210 @@
|
||||
"""Workflow comment models."""
|
||||
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from sqlalchemy import Index, func
|
||||
from sqlalchemy.orm import Mapped, mapped_column, relationship
|
||||
|
||||
from .account import Account
|
||||
from .base import Base
|
||||
from .engine import db
|
||||
from .types import StringUUID
|
||||
|
||||
|
||||
class WorkflowComment(Base):
|
||||
"""Workflow comment model for canvas commenting functionality.
|
||||
|
||||
Comments are associated with apps rather than specific workflow versions,
|
||||
since an app has only one draft workflow at a time and comments should persist
|
||||
across workflow version changes.
|
||||
|
||||
Attributes:
|
||||
id: Comment ID
|
||||
tenant_id: Workspace ID
|
||||
app_id: App ID (primary association, comments belong to apps)
|
||||
position_x: X coordinate on canvas
|
||||
position_y: Y coordinate on canvas
|
||||
content: Comment content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
updated_at: Last update time
|
||||
resolved: Whether comment is resolved
|
||||
resolved_at: Resolution time
|
||||
resolved_by: Resolver account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comments"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comments_pkey"),
|
||||
Index("workflow_comments_app_idx", "tenant_id", "app_id"),
|
||||
Index("workflow_comments_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
position_x: Mapped[float] = mapped_column(db.Float)
|
||||
position_y: Mapped[float] = mapped_column(db.Float)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
resolved: Mapped[bool] = mapped_column(db.Boolean, nullable=False, server_default=db.text("false"))
|
||||
resolved_at: Mapped[datetime | None] = mapped_column(db.DateTime)
|
||||
resolved_by: Mapped[str | None] = mapped_column(StringUUID)
|
||||
|
||||
# Relationships
|
||||
replies: Mapped[list["WorkflowCommentReply"]] = relationship(
|
||||
"WorkflowCommentReply", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
mentions: Mapped[list["WorkflowCommentMention"]] = relationship(
|
||||
"WorkflowCommentMention", back_populates="comment", cascade="all, delete-orphan"
|
||||
)
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
if hasattr(self, "_created_by_account_cache"):
|
||||
return self._created_by_account_cache
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
def cache_created_by_account(self, account: Account | None) -> None:
|
||||
"""Cache creator account to avoid extra queries."""
|
||||
self._created_by_account_cache = account
|
||||
|
||||
@property
|
||||
def resolved_by_account(self):
|
||||
"""Get resolver account."""
|
||||
if hasattr(self, "_resolved_by_account_cache"):
|
||||
return self._resolved_by_account_cache
|
||||
if self.resolved_by:
|
||||
return db.session.get(Account, self.resolved_by)
|
||||
return None
|
||||
|
||||
def cache_resolved_by_account(self, account: Account | None) -> None:
|
||||
"""Cache resolver account to avoid extra queries."""
|
||||
self._resolved_by_account_cache = account
|
||||
|
||||
@property
|
||||
def reply_count(self):
|
||||
"""Get reply count."""
|
||||
return len(self.replies)
|
||||
|
||||
@property
|
||||
def mention_count(self):
|
||||
"""Get mention count."""
|
||||
return len(self.mentions)
|
||||
|
||||
@property
|
||||
def participants(self):
|
||||
"""Get all participants (creator + repliers + mentioned users)."""
|
||||
participant_ids = set()
|
||||
|
||||
# Add comment creator
|
||||
participant_ids.add(self.created_by)
|
||||
|
||||
# Add reply creators
|
||||
participant_ids.update(reply.created_by for reply in self.replies)
|
||||
|
||||
# Add mentioned users
|
||||
participant_ids.update(mention.mentioned_user_id for mention in self.mentions)
|
||||
|
||||
# Get account objects
|
||||
participants = []
|
||||
for user_id in participant_ids:
|
||||
account = db.session.get(Account, user_id)
|
||||
if account:
|
||||
participants.append(account)
|
||||
|
||||
return participants
|
||||
|
||||
|
||||
class WorkflowCommentReply(Base):
|
||||
"""Workflow comment reply model.
|
||||
|
||||
Attributes:
|
||||
id: Reply ID
|
||||
comment_id: Parent comment ID
|
||||
content: Reply content
|
||||
created_by: Creator account ID
|
||||
created_at: Creation time
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_replies"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_replies_pkey"),
|
||||
Index("comment_replies_comment_idx", "comment_id"),
|
||||
Index("comment_replies_created_at_idx", "created_at"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
content: Mapped[str] = mapped_column(db.Text, nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
created_at: Mapped[datetime] = mapped_column(db.DateTime, nullable=False, server_default=func.current_timestamp())
|
||||
updated_at: Mapped[datetime] = mapped_column(
|
||||
db.DateTime, nullable=False, server_default=func.current_timestamp(), onupdate=func.current_timestamp()
|
||||
)
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="replies")
|
||||
|
||||
@property
|
||||
def created_by_account(self):
|
||||
"""Get creator account."""
|
||||
if hasattr(self, "_created_by_account_cache"):
|
||||
return self._created_by_account_cache
|
||||
return db.session.get(Account, self.created_by)
|
||||
|
||||
def cache_created_by_account(self, account: Account | None) -> None:
|
||||
"""Cache creator account to avoid extra queries."""
|
||||
self._created_by_account_cache = account
|
||||
|
||||
|
||||
class WorkflowCommentMention(Base):
|
||||
"""Workflow comment mention model.
|
||||
|
||||
Mentions are only for internal accounts since end users
|
||||
cannot access workflow canvas and commenting features.
|
||||
|
||||
Attributes:
|
||||
id: Mention ID
|
||||
comment_id: Parent comment ID
|
||||
mentioned_user_id: Mentioned account ID
|
||||
"""
|
||||
|
||||
__tablename__ = "workflow_comment_mentions"
|
||||
__table_args__ = (
|
||||
db.PrimaryKeyConstraint("id", name="workflow_comment_mentions_pkey"),
|
||||
Index("comment_mentions_comment_idx", "comment_id"),
|
||||
Index("comment_mentions_reply_idx", "reply_id"),
|
||||
Index("comment_mentions_user_idx", "mentioned_user_id"),
|
||||
)
|
||||
|
||||
id: Mapped[str] = mapped_column(StringUUID, server_default=db.text("uuidv7()"))
|
||||
comment_id: Mapped[str] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comments.id", ondelete="CASCADE"), nullable=False
|
||||
)
|
||||
reply_id: Mapped[str | None] = mapped_column(
|
||||
StringUUID, db.ForeignKey("workflow_comment_replies.id", ondelete="CASCADE"), nullable=True
|
||||
)
|
||||
mentioned_user_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
|
||||
# Relationships
|
||||
comment: Mapped["WorkflowComment"] = relationship("WorkflowComment", back_populates="mentions")
|
||||
reply: Mapped[Optional["WorkflowCommentReply"]] = relationship("WorkflowCommentReply")
|
||||
|
||||
@property
|
||||
def mentioned_user_account(self):
|
||||
"""Get mentioned account."""
|
||||
if hasattr(self, "_mentioned_user_account_cache"):
|
||||
return self._mentioned_user_account_cache
|
||||
return db.session.get(Account, self.mentioned_user_id)
|
||||
|
||||
def cache_mentioned_user_account(self, account: Account | None) -> None:
|
||||
"""Cache mentioned account to avoid extra queries."""
|
||||
self._mentioned_user_account_cache = account
|
||||
@@ -352,6 +352,7 @@ class AppMode(StrEnum):
|
||||
CHAT = "chat"
|
||||
ADVANCED_CHAT = "advanced-chat"
|
||||
AGENT_CHAT = "agent-chat"
|
||||
AGENT = "agent"
|
||||
CHANNEL = "channel"
|
||||
RAG_PIPELINE = "rag-pipeline"
|
||||
|
||||
|
||||
0
api/models/workflow_comment.py
Normal file
0
api/models/workflow_comment.py
Normal file
26
api/models/workflow_features.py
Normal file
26
api/models/workflow_features.py
Normal file
@@ -0,0 +1,26 @@
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from enum import StrEnum
|
||||
from typing import Any
|
||||
|
||||
|
||||
class WorkflowFeatures(StrEnum):
|
||||
SANDBOX = "sandbox"
|
||||
SPEECH_TO_TEXT = "speech_to_text"
|
||||
TEXT_TO_SPEECH = "text_to_speech"
|
||||
RETRIEVER_RESOURCE = "retriever_resource"
|
||||
SENSITIVE_WORD_AVOIDANCE = "sensitive_word_avoidance"
|
||||
FILE_UPLOAD = "file_upload"
|
||||
SUGGESTED_QUESTIONS_AFTER_ANSWER = "suggested_questions_after_answer"
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class WorkflowFeature:
|
||||
enabled: bool
|
||||
config: Mapping[str, Any]
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: Mapping[str, Any] | None) -> "WorkflowFeature":
|
||||
if data is None or not isinstance(data, dict):
|
||||
return cls(enabled=False, config={})
|
||||
return cls(enabled=bool(data.get("enabled", False)), config=data)
|
||||
226
api/repositories/workflow_collaboration_repository.py
Normal file
226
api/repositories/workflow_collaboration_repository.py
Normal file
@@ -0,0 +1,226 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TypedDict
|
||||
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
SESSION_STATE_TTL_SECONDS = 3600
|
||||
WORKFLOW_ONLINE_USERS_PREFIX = "workflow_online_users:"
|
||||
WORKFLOW_LEADER_PREFIX = "workflow_leader:"
|
||||
WORKFLOW_SKILL_LEADER_PREFIX = "workflow_skill_leader:"
|
||||
WS_SID_MAP_PREFIX = "ws_sid_map:"
|
||||
|
||||
|
||||
class WorkflowSessionInfo(TypedDict):
|
||||
user_id: str
|
||||
username: str
|
||||
avatar: str | None
|
||||
sid: str
|
||||
connected_at: int
|
||||
graph_active: bool
|
||||
active_skill_file_id: str | None
|
||||
|
||||
|
||||
class SidMapping(TypedDict):
|
||||
workflow_id: str
|
||||
user_id: str
|
||||
|
||||
|
||||
class WorkflowCollaborationRepository:
|
||||
def __init__(self) -> None:
|
||||
self._redis = redis_client
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(redis_client={self._redis})"
|
||||
|
||||
@staticmethod
|
||||
def workflow_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_ONLINE_USERS_PREFIX}{workflow_id}"
|
||||
|
||||
@staticmethod
|
||||
def leader_key(workflow_id: str) -> str:
|
||||
return f"{WORKFLOW_LEADER_PREFIX}{workflow_id}"
|
||||
|
||||
@staticmethod
|
||||
def skill_leader_key(workflow_id: str, file_id: str) -> str:
|
||||
return f"{WORKFLOW_SKILL_LEADER_PREFIX}{workflow_id}:{file_id}"
|
||||
|
||||
@staticmethod
|
||||
def sid_key(sid: str) -> str:
|
||||
return f"{WS_SID_MAP_PREFIX}{sid}"
|
||||
|
||||
@staticmethod
|
||||
def _decode(value: str | bytes | None) -> str | None:
|
||||
if value is None:
|
||||
return None
|
||||
if isinstance(value, bytes):
|
||||
return value.decode("utf-8")
|
||||
return value
|
||||
|
||||
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
|
||||
workflow_key = self.workflow_key(workflow_id)
|
||||
sid_key = self.sid_key(sid)
|
||||
if self._redis.exists(workflow_key):
|
||||
self._redis.expire(workflow_key, SESSION_STATE_TTL_SECONDS)
|
||||
if self._redis.exists(sid_key):
|
||||
self._redis.expire(sid_key, SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def set_session_info(self, workflow_id: str, session_info: WorkflowSessionInfo) -> None:
|
||||
workflow_key = self.workflow_key(workflow_id)
|
||||
self._redis.hset(workflow_key, session_info["sid"], json.dumps(session_info))
|
||||
self._redis.set(
|
||||
self.sid_key(session_info["sid"]),
|
||||
json.dumps({"workflow_id": workflow_id, "user_id": session_info["user_id"]}),
|
||||
ex=SESSION_STATE_TTL_SECONDS,
|
||||
)
|
||||
self.refresh_session_state(workflow_id, session_info["sid"])
|
||||
|
||||
def get_session_info(self, workflow_id: str, sid: str) -> WorkflowSessionInfo | None:
|
||||
raw = self._redis.hget(self.workflow_key(workflow_id), sid)
|
||||
value = self._decode(raw)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
session_info = json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
if not isinstance(session_info, dict):
|
||||
return None
|
||||
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
|
||||
return None
|
||||
|
||||
return {
|
||||
"user_id": str(session_info["user_id"]),
|
||||
"username": str(session_info["username"]),
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": str(session_info["sid"]),
|
||||
"connected_at": int(session_info.get("connected_at") or 0),
|
||||
"graph_active": bool(session_info.get("graph_active")),
|
||||
"active_skill_file_id": session_info.get("active_skill_file_id"),
|
||||
}
|
||||
|
||||
def set_graph_active(self, workflow_id: str, sid: str, active: bool) -> None:
|
||||
session_info = self.get_session_info(workflow_id, sid)
|
||||
if not session_info:
|
||||
return
|
||||
session_info["graph_active"] = bool(active)
|
||||
self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
def is_graph_active(self, workflow_id: str, sid: str) -> bool:
|
||||
session_info = self.get_session_info(workflow_id, sid)
|
||||
if not session_info:
|
||||
return False
|
||||
return bool(session_info.get("graph_active") or False)
|
||||
|
||||
def set_active_skill_file(self, workflow_id: str, sid: str, file_id: str | None) -> None:
|
||||
session_info = self.get_session_info(workflow_id, sid)
|
||||
if not session_info:
|
||||
return
|
||||
session_info["active_skill_file_id"] = file_id
|
||||
self._redis.hset(self.workflow_key(workflow_id), sid, json.dumps(session_info))
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
def get_active_skill_file_id(self, workflow_id: str, sid: str) -> str | None:
|
||||
session_info = self.get_session_info(workflow_id, sid)
|
||||
if not session_info:
|
||||
return None
|
||||
return session_info.get("active_skill_file_id")
|
||||
|
||||
def get_sid_mapping(self, sid: str) -> SidMapping | None:
|
||||
raw = self._redis.get(self.sid_key(sid))
|
||||
if not raw:
|
||||
return None
|
||||
value = self._decode(raw)
|
||||
if not value:
|
||||
return None
|
||||
try:
|
||||
return json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
def delete_session(self, workflow_id: str, sid: str) -> None:
|
||||
self._redis.hdel(self.workflow_key(workflow_id), sid)
|
||||
self._redis.delete(self.sid_key(sid))
|
||||
|
||||
def session_exists(self, workflow_id: str, sid: str) -> bool:
|
||||
return bool(self._redis.hexists(self.workflow_key(workflow_id), sid))
|
||||
|
||||
def sid_mapping_exists(self, sid: str) -> bool:
|
||||
return bool(self._redis.exists(self.sid_key(sid)))
|
||||
|
||||
def get_session_sids(self, workflow_id: str) -> list[str]:
|
||||
raw_sids = self._redis.hkeys(self.workflow_key(workflow_id))
|
||||
decoded_sids: list[str] = []
|
||||
for sid in raw_sids:
|
||||
decoded = self._decode(sid)
|
||||
if decoded:
|
||||
decoded_sids.append(decoded)
|
||||
return decoded_sids
|
||||
|
||||
def list_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
|
||||
sessions_json = self._redis.hgetall(self.workflow_key(workflow_id))
|
||||
users: list[WorkflowSessionInfo] = []
|
||||
|
||||
for session_info_json in sessions_json.values():
|
||||
value = self._decode(session_info_json)
|
||||
if not value:
|
||||
continue
|
||||
try:
|
||||
session_info = json.loads(value)
|
||||
except (TypeError, json.JSONDecodeError):
|
||||
continue
|
||||
|
||||
if not isinstance(session_info, dict):
|
||||
continue
|
||||
if "user_id" not in session_info or "username" not in session_info or "sid" not in session_info:
|
||||
continue
|
||||
|
||||
users.append(
|
||||
{
|
||||
"user_id": str(session_info["user_id"]),
|
||||
"username": str(session_info["username"]),
|
||||
"avatar": session_info.get("avatar"),
|
||||
"sid": str(session_info["sid"]),
|
||||
"connected_at": int(session_info.get("connected_at") or 0),
|
||||
"graph_active": bool(session_info.get("graph_active")),
|
||||
"active_skill_file_id": session_info.get("active_skill_file_id"),
|
||||
}
|
||||
)
|
||||
|
||||
return users
|
||||
|
||||
def get_current_leader(self, workflow_id: str) -> str | None:
|
||||
raw = self._redis.get(self.leader_key(workflow_id))
|
||||
return self._decode(raw)
|
||||
|
||||
def get_skill_leader(self, workflow_id: str, file_id: str) -> str | None:
|
||||
raw = self._redis.get(self.skill_leader_key(workflow_id, file_id))
|
||||
return self._decode(raw)
|
||||
|
||||
def set_leader_if_absent(self, workflow_id: str, sid: str) -> bool:
|
||||
return bool(self._redis.set(self.leader_key(workflow_id), sid, nx=True, ex=SESSION_STATE_TTL_SECONDS))
|
||||
|
||||
def set_leader(self, workflow_id: str, sid: str) -> None:
|
||||
self._redis.set(self.leader_key(workflow_id), sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def set_skill_leader(self, workflow_id: str, file_id: str, sid: str) -> None:
|
||||
self._redis.set(self.skill_leader_key(workflow_id, file_id), sid, ex=SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def delete_leader(self, workflow_id: str) -> None:
|
||||
self._redis.delete(self.leader_key(workflow_id))
|
||||
|
||||
def delete_skill_leader(self, workflow_id: str, file_id: str) -> None:
|
||||
self._redis.delete(self.skill_leader_key(workflow_id, file_id))
|
||||
|
||||
def expire_leader(self, workflow_id: str) -> None:
|
||||
self._redis.expire(self.leader_key(workflow_id), SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def expire_skill_leader(self, workflow_id: str, file_id: str) -> None:
|
||||
self._redis.expire(self.skill_leader_key(workflow_id, file_id), SESSION_STATE_TTL_SECONDS)
|
||||
|
||||
def get_active_skill_session_sids(self, workflow_id: str, file_id: str) -> list[str]:
|
||||
sessions = self.list_sessions(workflow_id)
|
||||
return [session["sid"] for session in sessions if session.get("active_skill_file_id") == file_id]
|
||||
@@ -455,7 +455,7 @@ class AppDslService:
|
||||
app.updated_by = account.id
|
||||
|
||||
self._session.add(app)
|
||||
self._session.commit()
|
||||
self._session.flush()
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
# save dependencies
|
||||
@@ -468,7 +468,7 @@ class AppDslService:
|
||||
|
||||
# Initialize app based on mode
|
||||
match app_mode:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.WORKFLOW | AppMode.AGENT:
|
||||
workflow_data = data.get("workflow")
|
||||
if not workflow_data or not isinstance(workflow_data, dict):
|
||||
raise ValueError("Missing workflow data for workflow/advanced chat app")
|
||||
@@ -556,7 +556,7 @@ class AppDslService:
|
||||
},
|
||||
}
|
||||
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW}:
|
||||
if app_mode in {AppMode.ADVANCED_CHAT, AppMode.WORKFLOW, AppMode.AGENT}:
|
||||
cls._append_workflow_export_data(
|
||||
export_data=export_data, app_model=app_model, include_secret=include_secret, workflow_id=workflow_id
|
||||
)
|
||||
|
||||
@@ -56,7 +56,6 @@ class AppGenerateService:
|
||||
try:
|
||||
start_task()
|
||||
except Exception:
|
||||
logger.exception("Failed to enqueue streaming task")
|
||||
return False
|
||||
started = True
|
||||
return True
|
||||
@@ -117,8 +116,84 @@ class AppGenerateService:
|
||||
try:
|
||||
request_id = rate_limit.enter(request_id)
|
||||
effective_mode = (
|
||||
AppMode.AGENT_CHAT if app_model.is_agent and app_model.mode != AppMode.AGENT_CHAT else app_model.mode
|
||||
AppMode.AGENT_CHAT
|
||||
if app_model.is_agent and app_model.mode not in {AppMode.AGENT_CHAT, AppMode.AGENT}
|
||||
else app_model.mode
|
||||
)
|
||||
|
||||
if (
|
||||
effective_mode in {AppMode.COMPLETION, AppMode.CHAT, AppMode.AGENT_CHAT}
|
||||
and dify_config.AGENT_V2_TRANSPARENT_UPGRADE
|
||||
):
|
||||
from services.workflow.virtual_workflow import VirtualWorkflowSynthesizer
|
||||
|
||||
try:
|
||||
workflow = VirtualWorkflowSynthesizer.ensure_workflow(app_model)
|
||||
logger.info(
|
||||
"[AGENT_V2_UPGRADE] Transparent upgrade for app %s (mode=%s), wf=%s",
|
||||
app_model.id,
|
||||
effective_mode,
|
||||
workflow.id,
|
||||
)
|
||||
|
||||
upgraded_args = dict(args)
|
||||
if "query" not in upgraded_args or not upgraded_args.get("query"):
|
||||
inputs = upgraded_args.get("inputs", {})
|
||||
upgraded_args["query"] = inputs.get("query", "") or inputs.get("input", "") or str(inputs)
|
||||
args = upgraded_args
|
||||
|
||||
if streaming:
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
subscribe_mode = AppMode.value_of(app_model.mode)
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
subscribe_mode,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
advanced_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
except Exception:
|
||||
logger.warning(
|
||||
"[AGENT_V2_UPGRADE] Transparent upgrade failed for app %s, falling back to legacy",
|
||||
app_model.id,
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
match effective_mode:
|
||||
case AppMode.COMPLETION:
|
||||
return rate_limit.generate(
|
||||
@@ -147,6 +222,54 @@ class AppGenerateService:
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
case AppMode.AGENT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
|
||||
if streaming:
|
||||
with rate_limit_context(rate_limit, request_id):
|
||||
payload = AppExecutionParams.new(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
streaming=True,
|
||||
call_depth=0,
|
||||
)
|
||||
payload_json = payload.model_dump_json()
|
||||
|
||||
def on_subscribe():
|
||||
workflow_based_app_execution_task.delay(payload_json)
|
||||
|
||||
on_subscribe = cls._build_streaming_task_on_subscribe(on_subscribe)
|
||||
generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
generator.convert_to_event_stream(
|
||||
generator.retrieve_events(
|
||||
AppMode.AGENT,
|
||||
payload.workflow_run_id,
|
||||
on_subscribe=on_subscribe,
|
||||
),
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
else:
|
||||
advanced_generator = AdvancedChatAppGenerator()
|
||||
return rate_limit.generate(
|
||||
advanced_generator.convert_to_event_stream(
|
||||
advanced_generator.generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
args=args,
|
||||
invoke_from=invoke_from,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=False,
|
||||
)
|
||||
),
|
||||
request_id=request_id,
|
||||
)
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
workflow_id = args.get("workflow_id")
|
||||
workflow = cls._get_workflow(app_model, invoke_from, workflow_id)
|
||||
|
||||
@@ -14,5 +14,5 @@ class AppModelConfigService:
|
||||
return AgentChatAppConfigManager.config_validate(tenant_id, config)
|
||||
case AppMode.COMPLETION:
|
||||
return CompletionAppConfigManager.config_validate(tenant_id, config)
|
||||
case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.CHANNEL | AppMode.RAG_PIPELINE:
|
||||
case AppMode.WORKFLOW | AppMode.ADVANCED_CHAT | AppMode.AGENT | AppMode.CHANNEL | AppMode.RAG_PIPELINE:
|
||||
raise ValueError(f"Invalid app mode: {app_mode}")
|
||||
|
||||
443
api/services/app_runtime_upgrade_service.py
Normal file
443
api/services/app_runtime_upgrade_service.py
Normal file
@@ -0,0 +1,443 @@
|
||||
"""Service for upgrading Classic runtime apps to Sandboxed runtime via clone-and-convert.
|
||||
|
||||
The upgrade flow:
|
||||
1. Clone the source app via DSL export/import
|
||||
2. On the cloned app's draft workflow, convert Agent nodes to LLM nodes
|
||||
3. Rewrite variable references for all LLM nodes (old output names → new generation-based names)
|
||||
4. Enable sandbox feature flag
|
||||
|
||||
The original app is never modified; the user gets a new sandboxed copy.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models import App, Workflow
|
||||
from models.workflow_features import WorkflowFeatures
|
||||
from services.app_dsl_service import AppDslService, ImportMode
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_VAR_REWRITES: dict[str, list[str]] = {
|
||||
"text": ["generation", "content"],
|
||||
"reasoning_content": ["generation", "reasoning_content"],
|
||||
}
|
||||
|
||||
_PASSTHROUGH_KEYS = (
|
||||
"version",
|
||||
"error_strategy",
|
||||
"default_value",
|
||||
"retry_config",
|
||||
"parent_node_id",
|
||||
"isInLoop",
|
||||
"loop_id",
|
||||
"isInIteration",
|
||||
"iteration_id",
|
||||
)
|
||||
|
||||
|
||||
class AppRuntimeUpgradeService:
|
||||
"""Upgrades a Classic-runtime app to Sandboxed runtime by cloning and converting.
|
||||
|
||||
Holds an active SQLAlchemy session; the caller is responsible for commit/rollback.
|
||||
"""
|
||||
|
||||
session: Session
|
||||
|
||||
def __init__(self, session: Session) -> None:
|
||||
self.session = session
|
||||
|
||||
def upgrade(self, app_model: App, account: Any) -> dict[str, Any]:
|
||||
"""Clone *app_model* and upgrade the clone to sandboxed runtime.
|
||||
|
||||
Returns:
|
||||
dict with keys: result, new_app_id, converted_agents, skipped_agents.
|
||||
"""
|
||||
workflow = self._get_draft_workflow(app_model)
|
||||
if not workflow:
|
||||
return {"result": "no_draft"}
|
||||
|
||||
if workflow.get_feature(WorkflowFeatures.SANDBOX).enabled:
|
||||
return {"result": "already_sandboxed"}
|
||||
|
||||
new_app = self._clone_app(app_model, account)
|
||||
new_workflow = self._get_draft_workflow(new_app)
|
||||
if not new_workflow:
|
||||
return {"result": "no_draft"}
|
||||
|
||||
graph = json.loads(new_workflow.graph) if new_workflow.graph else {}
|
||||
nodes = graph.get("nodes", [])
|
||||
|
||||
converted, skipped = _convert_agent_nodes(nodes)
|
||||
_enable_computer_use_for_existing_llm_nodes(nodes)
|
||||
|
||||
llm_node_ids = {n["id"] for n in nodes if n.get("data", {}).get("type") == "llm"}
|
||||
_rewrite_variable_references(nodes, llm_node_ids)
|
||||
|
||||
new_workflow.graph = json.dumps(graph)
|
||||
|
||||
features = json.loads(new_workflow.features) if new_workflow.features else {}
|
||||
features.setdefault("sandbox", {})["enabled"] = True
|
||||
new_workflow.features = json.dumps(features)
|
||||
|
||||
return {
|
||||
"result": "success",
|
||||
"new_app_id": str(new_app.id),
|
||||
"converted_agents": converted,
|
||||
"skipped_agents": skipped,
|
||||
}
|
||||
|
||||
def _get_draft_workflow(self, app_model: App) -> Workflow | None:
|
||||
stmt = select(Workflow).where(
|
||||
Workflow.tenant_id == app_model.tenant_id,
|
||||
Workflow.app_id == app_model.id,
|
||||
Workflow.version == "draft",
|
||||
)
|
||||
return self.session.scalar(stmt)
|
||||
|
||||
def _clone_app(self, app_model: App, account: Any) -> App:
|
||||
dsl_service = AppDslService(self.session)
|
||||
yaml_content = dsl_service.export_dsl(app_model=app_model, include_secret=True)
|
||||
result = dsl_service.import_app(
|
||||
account=account,
|
||||
import_mode=ImportMode.YAML_CONTENT,
|
||||
yaml_content=yaml_content,
|
||||
name=f"{app_model.name} (Sandboxed)",
|
||||
)
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
new_app = self.session.scalar(stmt)
|
||||
if not new_app:
|
||||
raise RuntimeError(f"Cloned app not found: {result.app_id}")
|
||||
return new_app
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Pure conversion functions (no DB access)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def _convert_agent_nodes(nodes: list[dict[str, Any]]) -> tuple[int, int]:
|
||||
"""Convert Agent nodes to LLM nodes in-place. Returns (converted_count, skipped_count)."""
|
||||
converted = 0
|
||||
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") != "agent":
|
||||
continue
|
||||
|
||||
node_id = node.get("id", "?")
|
||||
node["data"] = _agent_data_to_llm_data(data)
|
||||
logger.info("Converted agent node %s to LLM", node_id)
|
||||
converted += 1
|
||||
|
||||
return converted, 0
|
||||
|
||||
|
||||
def _agent_data_to_llm_data(agent_data: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Map an Agent node's data dict to an LLM node's data dict.
|
||||
|
||||
Always returns a valid LLM data dict. If the agent has no model selected,
|
||||
produces an empty LLM node with agent mode (computer_use) enabled.
|
||||
"""
|
||||
params = agent_data.get("agent_parameters") or {}
|
||||
|
||||
model_param = params.get("model", {}) if isinstance(params, dict) else {}
|
||||
model_value = model_param.get("value") if isinstance(model_param, dict) else None
|
||||
|
||||
if isinstance(model_value, dict) and model_value.get("provider") and model_value.get("model"):
|
||||
model_config = {
|
||||
"provider": model_value["provider"],
|
||||
"name": model_value["model"],
|
||||
"mode": model_value.get("mode", "chat"),
|
||||
"completion_params": model_value.get("completion_params", {}),
|
||||
}
|
||||
else:
|
||||
model_config = {"provider": "", "name": "", "mode": "chat", "completion_params": {}}
|
||||
|
||||
tools_param = params.get("tools", {})
|
||||
tools_value = tools_param.get("value", []) if isinstance(tools_param, dict) else []
|
||||
tools_meta, tool_settings = _convert_tools(tools_value if isinstance(tools_value, list) else [])
|
||||
|
||||
instruction_param = params.get("instruction", {})
|
||||
instruction = instruction_param.get("value", "") if isinstance(instruction_param, dict) else ""
|
||||
|
||||
query_param = params.get("query", {})
|
||||
query_value = query_param.get("value", "") if isinstance(query_param, dict) else ""
|
||||
|
||||
has_tools = bool(tools_meta)
|
||||
prompt_template = _build_prompt_template(
|
||||
instruction,
|
||||
query_value,
|
||||
skill=has_tools,
|
||||
tools=tools_value if has_tools else None,
|
||||
)
|
||||
|
||||
max_iter_param = params.get("maximum_iterations", {})
|
||||
max_iterations = max_iter_param.get("value", 100) if isinstance(max_iter_param, dict) else 100
|
||||
|
||||
context_config = _extract_context(params)
|
||||
vision_config = _extract_vision(params)
|
||||
|
||||
llm_data: dict[str, Any] = {
|
||||
"type": "llm",
|
||||
"title": agent_data.get("title", "LLM"),
|
||||
"desc": agent_data.get("desc", ""),
|
||||
"model": model_config,
|
||||
"prompt_template": prompt_template,
|
||||
"prompt_config": {"jinja2_variables": []},
|
||||
"memory": agent_data.get("memory"),
|
||||
"context": context_config,
|
||||
"vision": vision_config,
|
||||
"computer_use": True,
|
||||
"structured_output_switch_on": False,
|
||||
"reasoning_format": "separated",
|
||||
"tools": tools_meta,
|
||||
"tool_settings": tool_settings,
|
||||
"max_iterations": max_iterations,
|
||||
}
|
||||
|
||||
for key in _PASSTHROUGH_KEYS:
|
||||
if key in agent_data:
|
||||
llm_data[key] = agent_data[key]
|
||||
|
||||
return llm_data
|
||||
|
||||
|
||||
def _extract_context(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract context config from agent_parameters for LLM node format.
|
||||
|
||||
Agent stores context as a variable selector in agent_parameters.context.value,
|
||||
e.g. ["knowledge_retrieval_node_id", "result"]. Maps to LLM ContextConfig.
|
||||
"""
|
||||
if not isinstance(params, dict):
|
||||
return {"enabled": False}
|
||||
|
||||
ctx_param = params.get("context", {})
|
||||
ctx_value = ctx_param.get("value") if isinstance(ctx_param, dict) else None
|
||||
|
||||
if isinstance(ctx_value, list) and len(ctx_value) >= 2 and all(isinstance(s, str) for s in ctx_value):
|
||||
return {"enabled": True, "variable_selector": ctx_value}
|
||||
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
def _extract_vision(params: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Extract vision config from agent_parameters for LLM node format."""
|
||||
if not isinstance(params, dict):
|
||||
return {"enabled": False}
|
||||
|
||||
vision_param = params.get("vision", {})
|
||||
vision_value = vision_param.get("value") if isinstance(vision_param, dict) else None
|
||||
|
||||
if isinstance(vision_value, dict) and vision_value.get("enabled"):
|
||||
return vision_value
|
||||
|
||||
if isinstance(vision_value, bool) and vision_value:
|
||||
return {"enabled": True}
|
||||
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
def _enable_computer_use_for_existing_llm_nodes(nodes: list[dict[str, Any]]) -> None:
|
||||
"""Enable computer_use for existing LLM nodes that have tools configured.
|
||||
|
||||
After upgrade, the sandbox runtime requires computer_use=true for tool calling.
|
||||
Existing LLM nodes from classic mode may have tools but computer_use=false.
|
||||
"""
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
if data.get("type") != "llm":
|
||||
continue
|
||||
|
||||
tools = data.get("tools", [])
|
||||
if tools and not data.get("computer_use"):
|
||||
data["computer_use"] = True
|
||||
logger.info("Enabled computer_use for LLM node %s with %d tools", node.get("id", "?"), len(tools))
|
||||
|
||||
|
||||
def _convert_tools(
|
||||
tools_input: list[dict[str, Any]],
|
||||
) -> tuple[list[dict[str, Any]], list[dict[str, Any]]]:
|
||||
"""Convert agent tool dicts to (ToolMetadata[], ToolSetting[]).
|
||||
|
||||
Agent tools in graph JSON already use provider_name/settings/parameters —
|
||||
the same field names as LLM ToolMetadata. We pass them through with defaults
|
||||
for any missing fields.
|
||||
"""
|
||||
tools_meta: list[dict[str, Any]] = []
|
||||
tool_settings: list[dict[str, Any]] = []
|
||||
|
||||
for ts in tools_input:
|
||||
if not isinstance(ts, dict):
|
||||
continue
|
||||
|
||||
provider_name = ts.get("provider_name", "")
|
||||
tool_name = ts.get("tool_name", "")
|
||||
tool_type = ts.get("type", "builtin")
|
||||
|
||||
tools_meta.append(
|
||||
{
|
||||
"enabled": True,
|
||||
"type": tool_type,
|
||||
"provider_name": provider_name,
|
||||
"tool_name": tool_name,
|
||||
"plugin_unique_identifier": ts.get("plugin_unique_identifier"),
|
||||
"credential_id": ts.get("credential_id"),
|
||||
"parameters": ts.get("parameters", {}),
|
||||
"settings": ts.get("settings", {}) or ts.get("tool_configuration", {}),
|
||||
"extra": ts.get("extra", {}),
|
||||
}
|
||||
)
|
||||
|
||||
tool_settings.append(
|
||||
{
|
||||
"type": tool_type,
|
||||
"provider": provider_name,
|
||||
"tool_name": tool_name,
|
||||
"enabled": True,
|
||||
}
|
||||
)
|
||||
|
||||
return tools_meta, tool_settings
|
||||
|
||||
|
||||
def _build_prompt_template(
|
||||
instruction: Any,
|
||||
query: Any,
|
||||
*,
|
||||
skill: bool = False,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Build LLM prompt_template from Agent instruction and query values.
|
||||
|
||||
When *skill* is True each message gets ``"skill": True`` so the sandbox
|
||||
engine treats the prompt as a skill document.
|
||||
|
||||
When *tools* is provided, tool reference placeholders
|
||||
(``§[tool].[provider].[name].[uuid]§``) are appended to the system
|
||||
message and the corresponding ``ToolReference`` entries are placed in the
|
||||
message's ``metadata.tools`` dict so the skill assembler can resolve them.
|
||||
Tools from the same provider are grouped into a single token list.
|
||||
"""
|
||||
messages: list[dict[str, Any]] = []
|
||||
|
||||
system_text = instruction if isinstance(instruction, str) else (str(instruction) if instruction else "")
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
if tools:
|
||||
tool_refs: dict[str, dict[str, Any]] = {}
|
||||
provider_groups: dict[str, list[str]] = {}
|
||||
for ts in tools:
|
||||
if not isinstance(ts, dict):
|
||||
continue
|
||||
tool_uuid = str(uuid.uuid4())
|
||||
provider_id = ts.get("provider_name", "")
|
||||
tool_name = ts.get("tool_name", "")
|
||||
tool_type = ts.get("type", "builtin")
|
||||
|
||||
token = f"§[tool].[{provider_id}].[{tool_name}].[{tool_uuid}]§"
|
||||
provider_groups.setdefault(provider_id, []).append(token)
|
||||
tool_refs[tool_uuid] = {
|
||||
"type": tool_type,
|
||||
"configuration": {"fields": []},
|
||||
"enabled": True,
|
||||
**({"credential_id": ts.get("credential_id")} if ts.get("credential_id") else {}),
|
||||
}
|
||||
|
||||
if provider_groups:
|
||||
group_texts: list[str] = []
|
||||
for tokens in provider_groups.values():
|
||||
if len(tokens) == 1:
|
||||
group_texts.append(tokens[0])
|
||||
else:
|
||||
group_texts.append("[" + ",".join(tokens) + "]")
|
||||
all_tools_text = " ".join(group_texts)
|
||||
system_text = f"{system_text}\n\n{all_tools_text}" if system_text else all_tools_text
|
||||
metadata = {"tools": tool_refs, "files": []}
|
||||
|
||||
if system_text:
|
||||
msg: dict[str, Any] = {"role": "system", "text": system_text, "skill": skill}
|
||||
if metadata:
|
||||
msg["metadata"] = metadata
|
||||
messages.append(msg)
|
||||
|
||||
if isinstance(query, list) and len(query) >= 2:
|
||||
template_ref = "{{#" + ".".join(str(s) for s in query) + "#}}"
|
||||
messages.append({"role": "user", "text": template_ref, "skill": skill})
|
||||
elif query:
|
||||
messages.append({"role": "user", "text": str(query), "skill": skill})
|
||||
|
||||
if not messages:
|
||||
messages.append({"role": "user", "text": "", "skill": skill})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _rewrite_variable_references(nodes: list[dict[str, Any]], llm_ids: set[str]) -> None:
|
||||
"""Recursively walk all node data and rewrite variable references for LLM nodes.
|
||||
|
||||
Handles two forms:
|
||||
- Structured selectors: [node_id, "text"] → [node_id, "generation", "content"]
|
||||
- Template strings: {{#node_id.text#}} → {{#node_id.generation.content#}}
|
||||
"""
|
||||
if not llm_ids:
|
||||
return
|
||||
|
||||
escaped_ids = [re.escape(nid) for nid in llm_ids]
|
||||
patterns: list[tuple[re.Pattern[str], str]] = []
|
||||
for old_name, new_path in _VAR_REWRITES.items():
|
||||
pattern = re.compile(r"\{\{#(" + "|".join(escaped_ids) + r")\." + re.escape(old_name) + r"#\}\}")
|
||||
replacement = r"{{#\1." + ".".join(new_path) + r"#}}"
|
||||
patterns.append((pattern, replacement))
|
||||
|
||||
for node in nodes:
|
||||
data = node.get("data", {})
|
||||
_walk_and_rewrite(data, llm_ids, patterns)
|
||||
|
||||
|
||||
def _walk_and_rewrite(
|
||||
obj: Any,
|
||||
llm_ids: set[str],
|
||||
template_patterns: list[tuple[re.Pattern[str], str]],
|
||||
) -> Any:
|
||||
"""Recursively rewrite variable references in a nested data structure."""
|
||||
if isinstance(obj, dict):
|
||||
for key, value in obj.items():
|
||||
obj[key] = _walk_and_rewrite(value, llm_ids, template_patterns)
|
||||
return obj
|
||||
|
||||
if isinstance(obj, list):
|
||||
if _is_variable_selector(obj, llm_ids):
|
||||
return _rewrite_selector(obj)
|
||||
for i, item in enumerate(obj):
|
||||
obj[i] = _walk_and_rewrite(item, llm_ids, template_patterns)
|
||||
return obj
|
||||
|
||||
if isinstance(obj, str):
|
||||
for pattern, replacement in template_patterns:
|
||||
obj = pattern.sub(replacement, obj)
|
||||
return obj
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
def _is_variable_selector(lst: list, llm_ids: set[str]) -> bool:
|
||||
"""Check if a list is a structured variable selector pointing to an LLM node output."""
|
||||
if len(lst) < 2:
|
||||
return False
|
||||
if not all(isinstance(s, str) for s in lst):
|
||||
return False
|
||||
return lst[0] in llm_ids and lst[1] in _VAR_REWRITES
|
||||
|
||||
|
||||
def _rewrite_selector(selector: list[str]) -> list[str]:
|
||||
"""Rewrite [node_id, "text"] → [node_id, "generation", "content"]."""
|
||||
old_field = selector[1]
|
||||
new_path = _VAR_REWRITES[old_field]
|
||||
return [selector[0]] + new_path + selector[2:]
|
||||
@@ -10,6 +10,7 @@ from sqlalchemy import select
|
||||
|
||||
from configs import dify_config
|
||||
from constants.model_template import default_app_templates
|
||||
from services.workflow.graph_factory import WorkflowGraphFactory
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.errors.error import LLMBadRequestError, ProviderTokenNotInitError
|
||||
from core.model_manager import ModelManager
|
||||
@@ -52,6 +53,8 @@ class AppService:
|
||||
filters.append(App.mode == AppMode.ADVANCED_CHAT)
|
||||
elif args["mode"] == "agent-chat":
|
||||
filters.append(App.mode == AppMode.AGENT_CHAT)
|
||||
elif args["mode"] == "agent":
|
||||
filters.append(App.mode == AppMode.AGENT)
|
||||
|
||||
if args.get("is_created_by_me", False):
|
||||
filters.append(App.created_by == user_id)
|
||||
@@ -169,6 +172,10 @@ class AppService:
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if app_mode == AppMode.AGENT:
|
||||
model_dict = default_model_config.get("model") if default_model_config else None
|
||||
self._init_agent_workflow(app, account, model_dict)
|
||||
|
||||
app_was_created.send(app, account=account)
|
||||
|
||||
if FeatureService.get_system_features().webapp_auth.enabled:
|
||||
@@ -180,6 +187,34 @@ class AppService:
|
||||
|
||||
return app
|
||||
|
||||
@staticmethod
|
||||
def _init_agent_workflow(app: App, account: Any, model_dict: dict | None) -> None:
|
||||
"""Create the default single-agent-node workflow for a new Agent app."""
|
||||
from services.workflow_service import WorkflowService
|
||||
|
||||
model_config = model_dict or {
|
||||
"provider": "openai",
|
||||
"name": "gpt-4o",
|
||||
"mode": "chat",
|
||||
"completion_params": {},
|
||||
}
|
||||
|
||||
graph = WorkflowGraphFactory.create_single_agent_graph(
|
||||
model_config=model_config,
|
||||
is_chat=True,
|
||||
)
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
workflow_service.sync_draft_workflow(
|
||||
app_model=app,
|
||||
graph=graph,
|
||||
features={},
|
||||
unique_hash=None,
|
||||
account=account,
|
||||
environment_variables=[],
|
||||
conversation_variables=[],
|
||||
)
|
||||
|
||||
def get_app(self, app: App) -> App:
|
||||
"""
|
||||
Get App
|
||||
|
||||
37
api/services/llm_generation_service.py
Normal file
37
api/services/llm_generation_service.py
Normal file
@@ -0,0 +1,37 @@
|
||||
"""
|
||||
LLM Generation Detail Service.
|
||||
|
||||
Provides methods to query and attach generation details to workflow node executions
|
||||
and messages, avoiding N+1 query problems.
|
||||
"""
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.llm_generation_entities import LLMGenerationDetailData
|
||||
from models import LLMGenerationDetail
|
||||
|
||||
|
||||
class LLMGenerationService:
|
||||
"""Service for handling LLM generation details."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
def get_generation_detail_for_message(self, message_id: str) -> LLMGenerationDetailData | None:
|
||||
"""Query generation detail for a specific message."""
|
||||
stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id == message_id)
|
||||
detail = self._session.scalars(stmt).first()
|
||||
return detail.to_domain_model() if detail else None
|
||||
|
||||
def get_generation_details_for_messages(
|
||||
self,
|
||||
message_ids: list[str],
|
||||
) -> dict[str, LLMGenerationDetailData]:
|
||||
"""Batch query generation details for multiple messages."""
|
||||
if not message_ids:
|
||||
return {}
|
||||
|
||||
stmt = select(LLMGenerationDetail).where(LLMGenerationDetail.message_id.in_(message_ids))
|
||||
details = self._session.scalars(stmt).all()
|
||||
return {detail.message_id: detail.to_domain_model() for detail in details if detail.message_id}
|
||||
153
api/services/storage_ticket_service.py
Normal file
153
api/services/storage_ticket_service.py
Normal file
@@ -0,0 +1,153 @@
|
||||
"""Storage ticket service for generating opaque download/upload URLs.
|
||||
|
||||
This service provides a ticket-based approach for file access. Instead of exposing
|
||||
the real storage key in URLs, it generates a random UUID token and stores the mapping
|
||||
in Redis with a TTL.
|
||||
|
||||
Usage:
|
||||
from services.storage_ticket_service import StorageTicketService
|
||||
|
||||
# Generate a download ticket
|
||||
url = StorageTicketService.create_download_url("path/to/file.txt", expires_in=300)
|
||||
|
||||
# Generate an upload ticket
|
||||
url = StorageTicketService.create_upload_url("path/to/file.txt", expires_in=300, max_bytes=10*1024*1024)
|
||||
|
||||
URL format:
|
||||
{FILES_API_URL}/files/storage-files/{token}
|
||||
|
||||
The token is validated by looking up the Redis key, which contains:
|
||||
- op: "download" or "upload"
|
||||
- storage_key: the real storage path
|
||||
- max_bytes: (upload only) maximum allowed upload size
|
||||
- filename: suggested filename for Content-Disposition header
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
TICKET_KEY_PREFIX = "storage_files"
|
||||
DEFAULT_DOWNLOAD_TTL = 300 # 5 minutes
|
||||
DEFAULT_UPLOAD_TTL = 300 # 5 minutes
|
||||
DEFAULT_MAX_UPLOAD_BYTES = 100 * 1024 * 1024 # 100MB
|
||||
|
||||
|
||||
class StorageTicket(BaseModel):
|
||||
"""Represents a storage access ticket."""
|
||||
|
||||
op: Literal["download", "upload"]
|
||||
storage_key: str
|
||||
max_bytes: int | None = None # upload only
|
||||
filename: str | None = None # suggested filename for download
|
||||
|
||||
|
||||
class StorageTicketService:
|
||||
"""Service for creating and validating storage access tickets."""
|
||||
|
||||
@classmethod
|
||||
def create_download_url(
|
||||
cls,
|
||||
storage_key: str,
|
||||
*,
|
||||
expires_in: int = DEFAULT_DOWNLOAD_TTL,
|
||||
filename: str | None = None,
|
||||
) -> str:
|
||||
"""Create a download ticket and return the URL.
|
||||
|
||||
Args:
|
||||
storage_key: The real storage path
|
||||
expires_in: TTL in seconds (default 300)
|
||||
filename: Suggested filename for Content-Disposition header
|
||||
|
||||
Returns:
|
||||
Full URL with token
|
||||
"""
|
||||
if filename is None:
|
||||
filename = storage_key.rsplit("/", 1)[-1]
|
||||
|
||||
ticket = StorageTicket(op="download", storage_key=storage_key, filename=filename)
|
||||
token = cls._store_ticket(ticket, expires_in)
|
||||
return cls._build_url(token)
|
||||
|
||||
@classmethod
|
||||
def create_upload_url(
|
||||
cls,
|
||||
storage_key: str,
|
||||
*,
|
||||
expires_in: int = DEFAULT_UPLOAD_TTL,
|
||||
max_bytes: int = DEFAULT_MAX_UPLOAD_BYTES,
|
||||
) -> str:
|
||||
"""Create an upload ticket and return the URL.
|
||||
|
||||
Args:
|
||||
storage_key: The real storage path
|
||||
expires_in: TTL in seconds (default 300)
|
||||
max_bytes: Maximum allowed upload size in bytes
|
||||
|
||||
Returns:
|
||||
Full URL with token
|
||||
"""
|
||||
ticket = StorageTicket(op="upload", storage_key=storage_key, max_bytes=max_bytes)
|
||||
token = cls._store_ticket(ticket, expires_in)
|
||||
return cls._build_url(token)
|
||||
|
||||
@classmethod
|
||||
def get_ticket(cls, token: str) -> StorageTicket | None:
|
||||
"""Retrieve a ticket by token.
|
||||
|
||||
Args:
|
||||
token: The UUID token from the URL
|
||||
|
||||
Returns:
|
||||
StorageTicket if found and valid, None otherwise
|
||||
"""
|
||||
key = cls._ticket_key(token)
|
||||
try:
|
||||
data = redis_client.get(key)
|
||||
if data is None:
|
||||
return None
|
||||
if isinstance(data, bytes):
|
||||
data = data.decode("utf-8")
|
||||
return StorageTicket.model_validate_json(data)
|
||||
except Exception:
|
||||
logger.warning("Failed to retrieve storage ticket: %s", token, exc_info=True)
|
||||
return None
|
||||
|
||||
@classmethod
|
||||
def _store_ticket(cls, ticket: StorageTicket, ttl: int) -> str:
|
||||
"""Store a ticket in Redis and return the token."""
|
||||
token = str(uuid4())
|
||||
key = cls._ticket_key(token)
|
||||
value = ticket.model_dump_json()
|
||||
redis_client.setex(key, ttl, value)
|
||||
return token
|
||||
|
||||
@classmethod
|
||||
def _ticket_key(cls, token: str) -> str:
|
||||
"""Generate Redis key for a token."""
|
||||
return f"{TICKET_KEY_PREFIX}:{token}"
|
||||
|
||||
@classmethod
|
||||
def _build_url(cls, token: str) -> str:
|
||||
"""Build the full URL for a token.
|
||||
|
||||
FILES_API_URL is dedicated to sandbox runtime file access (agentbox/e2b/etc.).
|
||||
This endpoint must be routable from the runtime environment.
|
||||
"""
|
||||
base_url = dify_config.FILES_API_URL.strip()
|
||||
if not base_url:
|
||||
raise ValueError(
|
||||
"FILES_API_URL is required for sandbox runtime file access. "
|
||||
"Set FILES_API_URL to a URL reachable by your sandbox runtime. "
|
||||
"For public sandbox environments (e.g. e2b), use a public domain or IP."
|
||||
)
|
||||
base_url = base_url.rstrip("/")
|
||||
return f"{base_url}/files/storage-files/{token}"
|
||||
@@ -152,6 +152,29 @@ class TriggerLogResponse(BaseModel):
|
||||
model_config = ConfigDict(use_enum_values=True)
|
||||
|
||||
|
||||
class NestedNodeParameterSchema(BaseModel):
|
||||
"""Schema for a single parameter in a nested node."""
|
||||
|
||||
name: str
|
||||
type: str = "string"
|
||||
description: str = ""
|
||||
|
||||
|
||||
class NestedNodeGraphRequest(BaseModel):
|
||||
"""Request for generating a nested node graph."""
|
||||
|
||||
parent_node_id: str
|
||||
parameter_key: str
|
||||
context_source: list[str] = Field(default_factory=list)
|
||||
parameter_schema: NestedNodeParameterSchema
|
||||
|
||||
|
||||
class NestedNodeGraphResponse(BaseModel):
|
||||
"""Response containing the generated nested node graph."""
|
||||
|
||||
graph: dict[str, Any]
|
||||
|
||||
|
||||
class WorkflowScheduleCFSPlanEntity(BaseModel):
|
||||
"""
|
||||
CFS plan entity.
|
||||
|
||||
113
api/services/workflow/graph_factory.py
Normal file
113
api/services/workflow/graph_factory.py
Normal file
@@ -0,0 +1,113 @@
|
||||
"""Factory for programmatically building workflow graphs.
|
||||
|
||||
Used by AppService to auto-generate single-node workflow graphs when
|
||||
creating a new Agent app (AppMode.AGENT).
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE
|
||||
|
||||
|
||||
class WorkflowGraphFactory:
|
||||
"""Builds workflow graph dicts for special app creation flows."""
|
||||
|
||||
@staticmethod
|
||||
def create_single_agent_graph(
|
||||
model_config: dict[str, Any],
|
||||
is_chat: bool = True,
|
||||
) -> dict[str, Any]:
|
||||
"""Create a minimal start -> agent_v2 -> answer/end graph.
|
||||
|
||||
Args:
|
||||
model_config: Model configuration dict with provider, name, mode, completion_params.
|
||||
is_chat: If True, creates chatflow (with answer node); otherwise workflow (with end node).
|
||||
|
||||
Returns:
|
||||
Graph dict with nodes and edges, ready for WorkflowService.sync_draft_workflow().
|
||||
"""
|
||||
agent_node_data: dict[str, Any] = {
|
||||
"type": AGENT_V2_NODE_TYPE,
|
||||
"title": "Agent",
|
||||
"model": model_config,
|
||||
"prompt_template": [
|
||||
{"role": "system", "text": "You are a helpful assistant."},
|
||||
{"role": "user", "text": "{{#sys.query#}}"},
|
||||
],
|
||||
"tools": [],
|
||||
"max_iterations": 10,
|
||||
"agent_strategy": "auto",
|
||||
"context": {"enabled": False},
|
||||
"vision": {"enabled": False},
|
||||
}
|
||||
|
||||
if is_chat:
|
||||
agent_node_data["memory"] = {"window": {"enabled": True, "size": 50}}
|
||||
|
||||
nodes: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "custom",
|
||||
"data": {"type": "start", "title": "Start", "variables": []},
|
||||
"position": {"x": 80, "y": 282},
|
||||
},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "custom",
|
||||
"data": agent_node_data,
|
||||
"position": {"x": 400, "y": 282},
|
||||
},
|
||||
]
|
||||
|
||||
if is_chat:
|
||||
nodes.append(
|
||||
{
|
||||
"id": "answer",
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": "answer",
|
||||
"title": "Answer",
|
||||
"answer": "{{#agent.text#}}",
|
||||
},
|
||||
"position": {"x": 720, "y": 282},
|
||||
}
|
||||
)
|
||||
end_node_id = "answer"
|
||||
else:
|
||||
nodes.append(
|
||||
{
|
||||
"id": "end",
|
||||
"type": "custom",
|
||||
"data": {
|
||||
"type": "end",
|
||||
"title": "End",
|
||||
"outputs": [
|
||||
{
|
||||
"value_selector": ["agent", "text"],
|
||||
"variable": "result",
|
||||
}
|
||||
],
|
||||
},
|
||||
"position": {"x": 720, "y": 282},
|
||||
}
|
||||
)
|
||||
end_node_id = "end"
|
||||
|
||||
edges: list[dict[str, str]] = [
|
||||
{
|
||||
"id": "start-agent",
|
||||
"source": "start",
|
||||
"target": "agent",
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
},
|
||||
{
|
||||
"id": f"agent-{end_node_id}",
|
||||
"source": "agent",
|
||||
"target": end_node_id,
|
||||
"sourceHandle": "source",
|
||||
"targetHandle": "target",
|
||||
},
|
||||
]
|
||||
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
157
api/services/workflow/nested_node_graph_service.py
Normal file
157
api/services/workflow/nested_node_graph_service.py
Normal file
@@ -0,0 +1,157 @@
|
||||
"""
|
||||
Service for generating Nested Node LLM graph structures.
|
||||
|
||||
This service creates graph structures containing LLM nodes configured for
|
||||
extracting values from list[PromptMessage] variables.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from graphon.enums import BuiltinNodeTypes
|
||||
from graphon.model_runtime.entities import LLMMode
|
||||
from services.model_provider_service import ModelProviderService
|
||||
from services.workflow.entities import NestedNodeGraphRequest, NestedNodeGraphResponse, NestedNodeParameterSchema
|
||||
|
||||
|
||||
class NestedNodeGraphService:
|
||||
"""Service for generating Nested Node LLM graph structures."""
|
||||
|
||||
def __init__(self, session: Session):
|
||||
self._session = session
|
||||
|
||||
def generate_nested_node_id(self, node_id: str, parameter_name: str) -> str:
|
||||
"""Generate nested node ID following the naming convention.
|
||||
|
||||
Format: {node_id}_ext_{parameter_name}
|
||||
"""
|
||||
return f"{node_id}_ext_{parameter_name}"
|
||||
|
||||
def generate_nested_node_graph(self, tenant_id: str, request: NestedNodeGraphRequest) -> NestedNodeGraphResponse:
|
||||
"""Generate a complete graph structure containing a Nested Node LLM node.
|
||||
|
||||
Args:
|
||||
tenant_id: The tenant ID for fetching default model config
|
||||
request: The nested node graph generation request
|
||||
|
||||
Returns:
|
||||
Complete graph structure with nodes, edges, and viewport
|
||||
"""
|
||||
node_id = self.generate_nested_node_id(request.parent_node_id, request.parameter_key)
|
||||
model_config = self._get_default_model_config(tenant_id)
|
||||
node = self._build_nested_node_llm_node(
|
||||
node_id=node_id,
|
||||
parent_node_id=request.parent_node_id,
|
||||
context_source=request.context_source,
|
||||
parameter_schema=request.parameter_schema,
|
||||
model_config=model_config,
|
||||
)
|
||||
|
||||
graph = {
|
||||
"nodes": [node],
|
||||
"edges": [],
|
||||
"viewport": {},
|
||||
}
|
||||
|
||||
return NestedNodeGraphResponse(graph=graph)
|
||||
|
||||
def _get_default_model_config(self, tenant_id: str) -> dict[str, Any]:
|
||||
"""Get the default LLM model configuration for the tenant."""
|
||||
model_provider_service = ModelProviderService()
|
||||
default_model = model_provider_service.get_default_model_of_model_type(
|
||||
tenant_id=tenant_id,
|
||||
model_type="llm",
|
||||
)
|
||||
|
||||
if default_model:
|
||||
return {
|
||||
"provider": default_model.provider.provider,
|
||||
"name": default_model.model,
|
||||
"mode": LLMMode.CHAT.value,
|
||||
"completion_params": {},
|
||||
}
|
||||
|
||||
# Fallback to empty config if no default model is configured
|
||||
return {
|
||||
"provider": "",
|
||||
"name": "",
|
||||
"mode": LLMMode.CHAT.value,
|
||||
"completion_params": {},
|
||||
}
|
||||
|
||||
def _build_nested_node_llm_node(
|
||||
self,
|
||||
*,
|
||||
node_id: str,
|
||||
parent_node_id: str,
|
||||
context_source: list[str],
|
||||
parameter_schema: NestedNodeParameterSchema,
|
||||
model_config: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
"""Build the Nested Node LLM node structure.
|
||||
|
||||
The node uses:
|
||||
- $context in prompt_template to reference the PromptMessage list
|
||||
- structured_output for extracting the specific parameter
|
||||
- parent_node_id to associate with the parent node
|
||||
"""
|
||||
prompt_template = [
|
||||
{
|
||||
"role": "system",
|
||||
"text": "Extract the required parameter value from the conversation context above.",
|
||||
"skill": False,
|
||||
},
|
||||
{"$context": context_source},
|
||||
{"role": "user", "text": "", "skill": False},
|
||||
]
|
||||
|
||||
structured_output = {
|
||||
"schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
parameter_schema.name: {
|
||||
"type": parameter_schema.type,
|
||||
"description": parameter_schema.description,
|
||||
}
|
||||
},
|
||||
"required": [parameter_schema.name],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
}
|
||||
|
||||
return {
|
||||
"id": node_id,
|
||||
"position": {"x": 0, "y": 0},
|
||||
"data": {
|
||||
"type": BuiltinNodeTypes.LLM,
|
||||
# BaseNodeData fields
|
||||
"title": f"NestedNode: {parameter_schema.name}",
|
||||
"desc": f"Extract {parameter_schema.name} from conversation context",
|
||||
"version": "1",
|
||||
"error_strategy": None,
|
||||
"default_value": None,
|
||||
"retry_config": {"max_retries": 0},
|
||||
"parent_node_id": parent_node_id,
|
||||
# LLMNodeData fields
|
||||
"model": model_config,
|
||||
"prompt_template": prompt_template,
|
||||
"prompt_config": {"jinja2_variables": []},
|
||||
"memory": None,
|
||||
"context": {
|
||||
"enabled": False,
|
||||
"variable_selector": None,
|
||||
},
|
||||
"vision": {
|
||||
"enabled": False,
|
||||
"configs": {
|
||||
"variable_selector": ["sys", "files"],
|
||||
"detail": "high",
|
||||
},
|
||||
},
|
||||
"structured_output_enabled": True,
|
||||
"structured_output": structured_output,
|
||||
"computer_use": False,
|
||||
"tool_settings": [],
|
||||
},
|
||||
}
|
||||
328
api/services/workflow/virtual_workflow.py
Normal file
328
api/services/workflow/virtual_workflow.py
Normal file
@@ -0,0 +1,328 @@
|
||||
"""Virtual Workflow Synthesizer for transparent old-app upgrade.
|
||||
|
||||
Converts an old App's AppModelConfig into an in-memory Workflow object
|
||||
with a single agent-v2 node, without persisting to the database.
|
||||
This allows legacy apps (chat/completion/agent-chat) to run through
|
||||
the Agent V2 workflow engine transparently.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from typing import Any
|
||||
from uuid import uuid4
|
||||
|
||||
from core.workflow.nodes.agent_v2.entities import AGENT_V2_NODE_TYPE
|
||||
from models.model import App, AppMode, AppModelConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class VirtualWorkflowSynthesizer:
|
||||
"""Synthesize in-memory Workflow from legacy AppModelConfig."""
|
||||
|
||||
@staticmethod
|
||||
def synthesize(app: App) -> Any:
|
||||
"""Convert old app config to a virtual Workflow object.
|
||||
|
||||
Returns a Workflow-like object (not persisted to DB) that can be
|
||||
passed to AdvancedChatAppGenerator.generate().
|
||||
"""
|
||||
from models.workflow import Workflow, WorkflowType
|
||||
|
||||
config = app.app_model_config
|
||||
if not config:
|
||||
raise ValueError("App has no model config")
|
||||
|
||||
model_dict = _extract_model_config(config)
|
||||
prompt_template = _build_prompt_template(config, app.mode)
|
||||
tools = _extract_tools(config)
|
||||
agent_strategy = _extract_strategy(config)
|
||||
max_iterations = _extract_max_iterations(config)
|
||||
context = _build_context_config(config)
|
||||
vision = _build_vision_config(config)
|
||||
is_chat = app.mode != AppMode.COMPLETION
|
||||
|
||||
agent_node_data: dict[str, Any] = {
|
||||
"type": AGENT_V2_NODE_TYPE,
|
||||
"title": "Agent",
|
||||
"model": model_dict,
|
||||
"prompt_template": prompt_template,
|
||||
"tools": tools,
|
||||
"max_iterations": max_iterations,
|
||||
"agent_strategy": agent_strategy,
|
||||
"context": context,
|
||||
"vision": vision,
|
||||
}
|
||||
if is_chat:
|
||||
agent_node_data["memory"] = {"window": {"enabled": True, "size": 50}}
|
||||
|
||||
graph = _build_graph(agent_node_data, is_chat)
|
||||
|
||||
workflow = Workflow()
|
||||
workflow.id = str(uuid4())
|
||||
workflow.tenant_id = app.tenant_id
|
||||
workflow.app_id = app.id
|
||||
workflow.type = WorkflowType.CHAT if is_chat else WorkflowType.WORKFLOW
|
||||
workflow.version = "virtual"
|
||||
workflow.graph = json.dumps(graph)
|
||||
workflow.features = json.dumps(_build_features(config))
|
||||
workflow.created_by = app.created_by
|
||||
workflow.updated_by = app.updated_by
|
||||
|
||||
return workflow
|
||||
|
||||
@staticmethod
|
||||
def ensure_workflow(app: App) -> Any:
|
||||
"""Ensure the old app has a workflow, creating one if needed.
|
||||
|
||||
On first call for a legacy app, synthesizes a workflow from its
|
||||
AppModelConfig and persists it as a draft. On subsequent calls,
|
||||
returns the existing draft. This is a one-time lazy upgrade:
|
||||
the app gets a real workflow that can be edited in the workflow editor.
|
||||
|
||||
The app's workflow_id is NOT updated (preserving its legacy state),
|
||||
but the workflow is findable via app_id + version="draft".
|
||||
"""
|
||||
from models.workflow import Workflow
|
||||
|
||||
from extensions.ext_database import db
|
||||
|
||||
existing = db.session.query(Workflow).filter_by(
|
||||
app_id=app.id, version="draft"
|
||||
).first()
|
||||
if existing:
|
||||
return existing
|
||||
|
||||
workflow = VirtualWorkflowSynthesizer.synthesize(app)
|
||||
workflow.version = "draft"
|
||||
|
||||
db.session.add(workflow)
|
||||
db.session.commit()
|
||||
logger.info("Created draft workflow %s for legacy app %s", workflow.id, app.id)
|
||||
return workflow
|
||||
|
||||
|
||||
def _extract_model_config(config: AppModelConfig) -> dict[str, Any]:
|
||||
if config.model:
|
||||
try:
|
||||
return json.loads(config.model)
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
|
||||
|
||||
def _build_prompt_template(config: AppModelConfig, mode: str) -> list[dict[str, str]]:
|
||||
messages: list[dict[str, str]] = []
|
||||
|
||||
if config.prompt_type and config.prompt_type.value == "advanced":
|
||||
if config.chat_prompt_config:
|
||||
try:
|
||||
chat_config = json.loads(config.chat_prompt_config)
|
||||
if isinstance(chat_config, dict) and "prompt" in chat_config:
|
||||
prompts = chat_config["prompt"]
|
||||
if isinstance(prompts, list):
|
||||
for p in prompts:
|
||||
if isinstance(p, dict) and "role" in p and "text" in p:
|
||||
messages.append({"role": p["role"], "text": p["text"]})
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if not messages:
|
||||
pre_prompt = config.pre_prompt or ""
|
||||
if pre_prompt:
|
||||
messages.append({"role": "system", "text": pre_prompt})
|
||||
|
||||
if mode == AppMode.COMPLETION:
|
||||
messages.append({"role": "user", "text": "{{#sys.query#}}"})
|
||||
else:
|
||||
messages.append({"role": "user", "text": "{{#sys.query#}}"})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def _extract_tools(config: AppModelConfig) -> list[dict[str, Any]]:
|
||||
if not config.agent_mode:
|
||||
return []
|
||||
try:
|
||||
agent_mode = json.loads(config.agent_mode) if isinstance(config.agent_mode, str) else config.agent_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return []
|
||||
|
||||
if not isinstance(agent_mode, dict) or not agent_mode.get("enabled"):
|
||||
return []
|
||||
|
||||
tools_config = agent_mode.get("tools", [])
|
||||
result: list[dict[str, Any]] = []
|
||||
|
||||
for tool in tools_config:
|
||||
if not isinstance(tool, dict):
|
||||
continue
|
||||
if not tool.get("enabled", True):
|
||||
continue
|
||||
|
||||
provider_type = tool.get("provider_type", "builtin")
|
||||
provider_id = tool.get("provider_id", "")
|
||||
tool_name = tool.get("tool_name", "")
|
||||
|
||||
if not tool_name:
|
||||
continue
|
||||
|
||||
result.append({
|
||||
"enabled": True,
|
||||
"type": provider_type,
|
||||
"provider_name": provider_id,
|
||||
"tool_name": tool_name,
|
||||
"parameters": tool.get("tool_parameters", {}),
|
||||
"settings": {},
|
||||
})
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def _extract_strategy(config: AppModelConfig) -> str:
|
||||
if not config.agent_mode:
|
||||
return "auto"
|
||||
try:
|
||||
agent_mode = json.loads(config.agent_mode) if isinstance(config.agent_mode, str) else config.agent_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return "auto"
|
||||
|
||||
strategy = agent_mode.get("strategy", "")
|
||||
mapping = {
|
||||
"function_call": "function-calling",
|
||||
"react": "chain-of-thought",
|
||||
}
|
||||
return mapping.get(strategy, "auto")
|
||||
|
||||
|
||||
def _extract_max_iterations(config: AppModelConfig) -> int:
|
||||
if not config.agent_mode:
|
||||
return 10
|
||||
try:
|
||||
agent_mode = json.loads(config.agent_mode) if isinstance(config.agent_mode, str) else config.agent_mode
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
return 10
|
||||
return agent_mode.get("max_iteration", 10)
|
||||
|
||||
|
||||
def _build_context_config(config: AppModelConfig) -> dict[str, Any]:
|
||||
if config.dataset_configs:
|
||||
try:
|
||||
dc = json.loads(config.dataset_configs) if isinstance(config.dataset_configs, str) else config.dataset_configs
|
||||
if isinstance(dc, dict) and dc.get("datasets", {}).get("datasets", []):
|
||||
return {"enabled": True}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
def _build_vision_config(config: AppModelConfig) -> dict[str, Any]:
|
||||
if config.file_upload:
|
||||
try:
|
||||
fu = json.loads(config.file_upload) if isinstance(config.file_upload, str) else config.file_upload
|
||||
if isinstance(fu, dict) and fu.get("image", {}).get("enabled"):
|
||||
return {"enabled": True}
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
return {"enabled": False}
|
||||
|
||||
|
||||
def _build_graph(agent_data: dict[str, Any], is_chat: bool) -> dict[str, Any]:
|
||||
nodes: list[dict[str, Any]] = [
|
||||
{
|
||||
"id": "start",
|
||||
"type": "custom",
|
||||
"data": {"type": "start", "title": "Start", "variables": []},
|
||||
"position": {"x": 80, "y": 282},
|
||||
},
|
||||
{
|
||||
"id": "agent",
|
||||
"type": "custom",
|
||||
"data": agent_data,
|
||||
"position": {"x": 400, "y": 282},
|
||||
},
|
||||
]
|
||||
|
||||
if is_chat:
|
||||
nodes.append({
|
||||
"id": "answer",
|
||||
"type": "custom",
|
||||
"data": {"type": "answer", "title": "Answer", "answer": "{{#agent.text#}}"},
|
||||
"position": {"x": 720, "y": 282},
|
||||
})
|
||||
end_id = "answer"
|
||||
else:
|
||||
nodes.append({
|
||||
"id": "end",
|
||||
"type": "custom",
|
||||
"data": {"type": "end", "title": "End", "outputs": [{"value_selector": ["agent", "text"], "variable": "result"}]},
|
||||
"position": {"x": 720, "y": 282},
|
||||
})
|
||||
end_id = "end"
|
||||
|
||||
edges = [
|
||||
{"id": "start-agent", "source": "start", "target": "agent", "sourceHandle": "source", "targetHandle": "target"},
|
||||
{"id": f"agent-{end_id}", "source": "agent", "target": end_id, "sourceHandle": "source", "targetHandle": "target"},
|
||||
]
|
||||
|
||||
return {"nodes": nodes, "edges": edges}
|
||||
|
||||
|
||||
def _build_features(config: AppModelConfig) -> dict[str, Any]:
|
||||
"""Extract app-level features from AppModelConfig for the synthesized workflow."""
|
||||
features: dict[str, Any] = {}
|
||||
|
||||
if config.opening_statement:
|
||||
features["opening_statement"] = config.opening_statement
|
||||
|
||||
if config.suggested_questions:
|
||||
try:
|
||||
sq = json.loads(config.suggested_questions) if isinstance(config.suggested_questions, str) else config.suggested_questions
|
||||
if sq:
|
||||
features["suggested_questions"] = sq
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if config.sensitive_word_avoidance:
|
||||
try:
|
||||
swa = json.loads(config.sensitive_word_avoidance) if isinstance(config.sensitive_word_avoidance, str) else config.sensitive_word_avoidance
|
||||
if swa and swa.get("enabled"):
|
||||
features["sensitive_word_avoidance"] = swa
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if config.more_like_this:
|
||||
try:
|
||||
mlt = json.loads(config.more_like_this) if isinstance(config.more_like_this, str) else config.more_like_this
|
||||
if mlt and mlt.get("enabled"):
|
||||
features["more_like_this"] = mlt
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if config.speech_to_text:
|
||||
try:
|
||||
stt = json.loads(config.speech_to_text) if isinstance(config.speech_to_text, str) else config.speech_to_text
|
||||
if stt and stt.get("enabled"):
|
||||
features["speech_to_text"] = stt
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if config.text_to_speech:
|
||||
try:
|
||||
tts = json.loads(config.text_to_speech) if isinstance(config.text_to_speech, str) else config.text_to_speech
|
||||
if tts and tts.get("enabled"):
|
||||
features["text_to_speech"] = tts
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
if config.retriever_resource:
|
||||
try:
|
||||
rr = json.loads(config.retriever_resource) if isinstance(config.retriever_resource, str) else config.retriever_resource
|
||||
if rr and rr.get("enabled"):
|
||||
features["retriever_resource"] = rr
|
||||
except (json.JSONDecodeError, TypeError):
|
||||
pass
|
||||
|
||||
return features
|
||||
391
api/services/workflow_collaboration_service.py
Normal file
391
api/services/workflow_collaboration_service.py
Normal file
@@ -0,0 +1,391 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
|
||||
from models.account import Account
|
||||
from repositories.workflow_collaboration_repository import WorkflowCollaborationRepository, WorkflowSessionInfo
|
||||
|
||||
|
||||
class WorkflowCollaborationService:
|
||||
def __init__(self, repository: WorkflowCollaborationRepository, socketio) -> None:
|
||||
self._repository = repository
|
||||
self._socketio = socketio
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"{self.__class__.__name__}(repository={self._repository})"
|
||||
|
||||
def save_session(self, sid: str, user: Account) -> None:
|
||||
self._socketio.save_session(
|
||||
sid,
|
||||
{
|
||||
"user_id": user.id,
|
||||
"username": user.name,
|
||||
"avatar": user.avatar,
|
||||
},
|
||||
)
|
||||
|
||||
def register_session(self, workflow_id: str, sid: str) -> tuple[str, bool] | None:
|
||||
session = self._socketio.get_session(sid)
|
||||
user_id = session.get("user_id")
|
||||
if not user_id:
|
||||
return None
|
||||
|
||||
session_info: WorkflowSessionInfo = {
|
||||
"user_id": str(user_id),
|
||||
"username": str(session.get("username", "Unknown")),
|
||||
"avatar": session.get("avatar"),
|
||||
"sid": sid,
|
||||
"connected_at": int(time.time()),
|
||||
"graph_active": True,
|
||||
"active_skill_file_id": None,
|
||||
}
|
||||
|
||||
self._repository.set_session_info(workflow_id, session_info)
|
||||
|
||||
leader_sid = self.get_or_set_leader(workflow_id, sid)
|
||||
is_leader = leader_sid == sid if leader_sid else False
|
||||
|
||||
self._socketio.enter_room(sid, workflow_id)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
|
||||
return str(user_id), is_leader
|
||||
|
||||
def disconnect_session(self, sid: str) -> None:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
|
||||
self._repository.delete_session(workflow_id, sid)
|
||||
|
||||
self.handle_leader_disconnect(workflow_id, sid)
|
||||
if active_skill_file_id:
|
||||
self.handle_skill_leader_disconnect(workflow_id, active_skill_file_id, sid)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
|
||||
def relay_collaboration_event(self, sid: str, data: Mapping[str, object]) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
user_id = mapping["user_id"]
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
event_type = data.get("type")
|
||||
event_data = data.get("data")
|
||||
timestamp = data.get("timestamp", int(time.time()))
|
||||
|
||||
if not event_type:
|
||||
return {"msg": "invalid event type"}, 400
|
||||
|
||||
if event_type == "graph_view_active":
|
||||
is_active = False
|
||||
if isinstance(event_data, dict):
|
||||
is_active = bool(event_data.get("active") or False)
|
||||
self._repository.set_graph_active(workflow_id, sid, is_active)
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
self.broadcast_online_users(workflow_id)
|
||||
return {"msg": "graph_view_active_updated"}, 200
|
||||
|
||||
if event_type == "skill_file_active":
|
||||
file_id = None
|
||||
is_active = False
|
||||
if isinstance(event_data, dict):
|
||||
file_id = event_data.get("file_id")
|
||||
is_active = bool(event_data.get("active") or False)
|
||||
|
||||
if not file_id or not isinstance(file_id, str):
|
||||
return {"msg": "invalid skill_file_active payload"}, 400
|
||||
|
||||
previous_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
|
||||
next_file_id = file_id if is_active else None
|
||||
|
||||
if previous_file_id == next_file_id:
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
return {"msg": "skill_file_active_unchanged"}, 200
|
||||
|
||||
self._repository.set_active_skill_file(workflow_id, sid, next_file_id)
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
if previous_file_id:
|
||||
self._ensure_skill_leader(workflow_id, previous_file_id)
|
||||
if next_file_id:
|
||||
self._ensure_skill_leader(workflow_id, next_file_id, preferred_sid=sid)
|
||||
|
||||
return {"msg": "skill_file_active_updated"}, 200
|
||||
|
||||
if event_type == "sync_request":
|
||||
leader_sid = self._repository.get_current_leader(workflow_id)
|
||||
if leader_sid and (
|
||||
self.is_session_active(workflow_id, leader_sid)
|
||||
and self._repository.is_graph_active(workflow_id, leader_sid)
|
||||
):
|
||||
target_sid = leader_sid
|
||||
else:
|
||||
if leader_sid:
|
||||
self._repository.delete_leader(workflow_id)
|
||||
target_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
|
||||
if target_sid:
|
||||
self._repository.set_leader(workflow_id, target_sid)
|
||||
self.broadcast_leader_change(workflow_id, target_sid)
|
||||
if not target_sid:
|
||||
return {"msg": "no_active_leader"}, 200
|
||||
|
||||
self._socketio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=target_sid,
|
||||
)
|
||||
return {"msg": "sync_request_forwarded"}, 200
|
||||
|
||||
self._socketio.emit(
|
||||
"collaboration_update",
|
||||
{"type": event_type, "userId": user_id, "data": event_data, "timestamp": timestamp},
|
||||
room=workflow_id,
|
||||
skip_sid=sid,
|
||||
)
|
||||
|
||||
return {"msg": "event_broadcasted"}, 200
|
||||
|
||||
def relay_graph_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
self._socketio.emit("graph_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "graph_update_broadcasted"}, 200
|
||||
|
||||
def relay_skill_event(self, sid: str, data: object) -> tuple[dict[str, str], int]:
|
||||
mapping = self._repository.get_sid_mapping(sid)
|
||||
if not mapping:
|
||||
return {"msg": "unauthorized"}, 401
|
||||
|
||||
workflow_id = mapping["workflow_id"]
|
||||
self.refresh_session_state(workflow_id, sid)
|
||||
|
||||
self._socketio.emit("skill_update", data, room=workflow_id, skip_sid=sid)
|
||||
|
||||
return {"msg": "skill_update_broadcasted"}, 200
|
||||
|
||||
def get_or_set_leader(self, workflow_id: str, sid: str) -> str | None:
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
|
||||
if current_leader:
|
||||
if self.is_session_active(workflow_id, current_leader) and self._repository.is_graph_active(
|
||||
workflow_id, current_leader
|
||||
):
|
||||
return current_leader
|
||||
self._repository.delete_session(workflow_id, current_leader)
|
||||
self._repository.delete_leader(workflow_id)
|
||||
|
||||
new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
|
||||
if not new_leader_sid:
|
||||
return None
|
||||
|
||||
was_set = self._repository.set_leader_if_absent(workflow_id, new_leader_sid)
|
||||
|
||||
if was_set:
|
||||
if current_leader:
|
||||
self.broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
return new_leader_sid
|
||||
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
if current_leader:
|
||||
return current_leader
|
||||
|
||||
return new_leader_sid
|
||||
|
||||
def handle_leader_disconnect(self, workflow_id: str, disconnected_sid: str) -> None:
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
if not current_leader:
|
||||
return
|
||||
|
||||
if current_leader != disconnected_sid:
|
||||
return
|
||||
|
||||
new_leader_sid = self._select_graph_leader(workflow_id)
|
||||
if new_leader_sid:
|
||||
self._repository.set_leader(workflow_id, new_leader_sid)
|
||||
self.broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
else:
|
||||
self._repository.delete_leader(workflow_id)
|
||||
self.broadcast_leader_change(workflow_id, None)
|
||||
|
||||
def handle_skill_leader_disconnect(self, workflow_id: str, file_id: str, disconnected_sid: str) -> None:
|
||||
current_leader = self._repository.get_skill_leader(workflow_id, file_id)
|
||||
if not current_leader:
|
||||
return
|
||||
|
||||
if current_leader != disconnected_sid:
|
||||
return
|
||||
|
||||
new_leader_sid = self._select_skill_leader(workflow_id, file_id)
|
||||
if new_leader_sid:
|
||||
self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
|
||||
self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)
|
||||
else:
|
||||
self._repository.delete_skill_leader(workflow_id, file_id)
|
||||
self.broadcast_skill_leader_change(workflow_id, file_id, None)
|
||||
|
||||
def broadcast_leader_change(self, workflow_id: str, new_leader_sid: str | None) -> None:
|
||||
for sid in self._repository.get_session_sids(workflow_id):
|
||||
try:
|
||||
is_leader = new_leader_sid is not None and sid == new_leader_sid
|
||||
self._socketio.emit("status", {"isLeader": is_leader}, room=sid)
|
||||
except Exception:
|
||||
logging.exception("Failed to emit leader status to session %s", sid)
|
||||
|
||||
def broadcast_skill_leader_change(self, workflow_id: str, file_id: str, new_leader_sid: str | None) -> None:
|
||||
for sid in self._repository.get_session_sids(workflow_id):
|
||||
try:
|
||||
is_leader = new_leader_sid is not None and sid == new_leader_sid
|
||||
self._socketio.emit("skill_status", {"file_id": file_id, "isLeader": is_leader}, room=sid)
|
||||
except Exception:
|
||||
logging.exception("Failed to emit skill leader status to session %s", sid)
|
||||
|
||||
def get_current_leader(self, workflow_id: str) -> str | None:
|
||||
return self._repository.get_current_leader(workflow_id)
|
||||
|
||||
def _prune_inactive_sessions(self, workflow_id: str) -> list[WorkflowSessionInfo]:
|
||||
"""Remove inactive sessions from storage and return active sessions only."""
|
||||
sessions = self._repository.list_sessions(workflow_id)
|
||||
if not sessions:
|
||||
return []
|
||||
|
||||
active_sessions: list[WorkflowSessionInfo] = []
|
||||
stale_sids: list[str] = []
|
||||
for session in sessions:
|
||||
sid = session["sid"]
|
||||
if self.is_session_active(workflow_id, sid):
|
||||
active_sessions.append(session)
|
||||
else:
|
||||
stale_sids.append(sid)
|
||||
|
||||
for sid in stale_sids:
|
||||
self._repository.delete_session(workflow_id, sid)
|
||||
|
||||
return active_sessions
|
||||
|
||||
def broadcast_online_users(self, workflow_id: str) -> None:
|
||||
users = self._prune_inactive_sessions(workflow_id)
|
||||
users.sort(key=lambda x: x.get("connected_at") or 0)
|
||||
|
||||
leader_sid = self.get_current_leader(workflow_id)
|
||||
previous_leader = leader_sid
|
||||
active_sids = {user["sid"] for user in users}
|
||||
if leader_sid and leader_sid not in active_sids:
|
||||
self._repository.delete_leader(workflow_id)
|
||||
leader_sid = None
|
||||
|
||||
if not leader_sid and users:
|
||||
leader_sid = self._select_graph_leader(workflow_id)
|
||||
if leader_sid:
|
||||
self._repository.set_leader(workflow_id, leader_sid)
|
||||
|
||||
if leader_sid != previous_leader:
|
||||
self.broadcast_leader_change(workflow_id, leader_sid)
|
||||
|
||||
self._socketio.emit(
|
||||
"online_users",
|
||||
{"workflow_id": workflow_id, "users": users, "leader": leader_sid},
|
||||
room=workflow_id,
|
||||
)
|
||||
|
||||
def refresh_session_state(self, workflow_id: str, sid: str) -> None:
|
||||
self._repository.refresh_session_state(workflow_id, sid)
|
||||
self._ensure_leader(workflow_id, sid)
|
||||
active_skill_file_id = self._repository.get_active_skill_file_id(workflow_id, sid)
|
||||
if active_skill_file_id:
|
||||
self._ensure_skill_leader(workflow_id, active_skill_file_id, preferred_sid=sid)
|
||||
|
||||
def _ensure_leader(self, workflow_id: str, sid: str) -> None:
|
||||
current_leader = self._repository.get_current_leader(workflow_id)
|
||||
if (
|
||||
current_leader
|
||||
and self.is_session_active(workflow_id, current_leader)
|
||||
and self._repository.is_graph_active(workflow_id, current_leader)
|
||||
):
|
||||
self._repository.expire_leader(workflow_id)
|
||||
return
|
||||
|
||||
if current_leader:
|
||||
self._repository.delete_leader(workflow_id)
|
||||
|
||||
new_leader_sid = self._select_graph_leader(workflow_id, preferred_sid=sid)
|
||||
if not new_leader_sid:
|
||||
self.broadcast_leader_change(workflow_id, None)
|
||||
return
|
||||
|
||||
self._repository.set_leader(workflow_id, new_leader_sid)
|
||||
self.broadcast_leader_change(workflow_id, new_leader_sid)
|
||||
|
||||
def _ensure_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> None:
|
||||
current_leader = self._repository.get_skill_leader(workflow_id, file_id)
|
||||
active_sids = self._repository.get_active_skill_session_sids(workflow_id, file_id)
|
||||
if current_leader and self.is_session_active(workflow_id, current_leader):
|
||||
if current_leader in active_sids or not active_sids:
|
||||
self._repository.expire_skill_leader(workflow_id, file_id)
|
||||
return
|
||||
|
||||
if current_leader:
|
||||
self._repository.delete_skill_leader(workflow_id, file_id)
|
||||
|
||||
new_leader_sid = self._select_skill_leader(workflow_id, file_id, preferred_sid=preferred_sid)
|
||||
if not new_leader_sid:
|
||||
self.broadcast_skill_leader_change(workflow_id, file_id, None)
|
||||
return
|
||||
|
||||
self._repository.set_skill_leader(workflow_id, file_id, new_leader_sid)
|
||||
self.broadcast_skill_leader_change(workflow_id, file_id, new_leader_sid)
|
||||
|
||||
def _select_graph_leader(self, workflow_id: str, preferred_sid: str | None = None) -> str | None:
|
||||
session_sids = [
|
||||
session["sid"]
|
||||
for session in self._repository.list_sessions(workflow_id)
|
||||
if session.get("graph_active") and self.is_session_active(workflow_id, session["sid"])
|
||||
]
|
||||
if not session_sids:
|
||||
return None
|
||||
if preferred_sid and preferred_sid in session_sids:
|
||||
return preferred_sid
|
||||
return session_sids[0]
|
||||
|
||||
def _select_skill_leader(self, workflow_id: str, file_id: str, preferred_sid: str | None = None) -> str | None:
|
||||
session_sids = [
|
||||
sid
|
||||
for sid in self._repository.get_active_skill_session_sids(workflow_id, file_id)
|
||||
if self.is_session_active(workflow_id, sid)
|
||||
]
|
||||
if not session_sids:
|
||||
return None
|
||||
if preferred_sid and preferred_sid in session_sids:
|
||||
return preferred_sid
|
||||
return session_sids[0]
|
||||
|
||||
def is_session_active(self, workflow_id: str, sid: str) -> bool:
|
||||
if not sid:
|
||||
return False
|
||||
|
||||
try:
|
||||
if not self._socketio.manager.is_connected(sid, "/"):
|
||||
return False
|
||||
except AttributeError:
|
||||
return False
|
||||
|
||||
if not self._repository.session_exists(workflow_id, sid):
|
||||
return False
|
||||
|
||||
if not self._repository.sid_mapping_exists(sid):
|
||||
return False
|
||||
|
||||
return True
|
||||
468
api/services/workflow_comment_service.py
Normal file
468
api/services/workflow_comment_service.py
Normal file
@@ -0,0 +1,468 @@
|
||||
import logging
|
||||
from collections.abc import Sequence
|
||||
|
||||
from sqlalchemy import desc, select
|
||||
from sqlalchemy.orm import Session, selectinload
|
||||
from werkzeug.exceptions import Forbidden, NotFound
|
||||
|
||||
from configs import dify_config
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.helper import uuid_value
|
||||
from models import App, TenantAccountJoin, WorkflowComment, WorkflowCommentMention, WorkflowCommentReply
|
||||
from models.account import Account
|
||||
from tasks.mail_workflow_comment_task import send_workflow_comment_mention_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowCommentService:
|
||||
"""Service for managing workflow comments."""
|
||||
|
||||
@staticmethod
|
||||
def _validate_content(content: str) -> None:
|
||||
if len(content.strip()) == 0:
|
||||
raise ValueError("Comment content cannot be empty")
|
||||
|
||||
if len(content) > 1000:
|
||||
raise ValueError("Comment content cannot exceed 1000 characters")
|
||||
|
||||
@staticmethod
|
||||
def _filter_valid_mentioned_user_ids(mentioned_user_ids: Sequence[str]) -> list[str]:
|
||||
"""Return deduplicated UUID user IDs in the order provided."""
|
||||
unique_user_ids: list[str] = []
|
||||
seen: set[str] = set()
|
||||
for user_id in mentioned_user_ids:
|
||||
if not isinstance(user_id, str):
|
||||
continue
|
||||
if not uuid_value(user_id):
|
||||
continue
|
||||
if user_id in seen:
|
||||
continue
|
||||
seen.add(user_id)
|
||||
unique_user_ids.append(user_id)
|
||||
return unique_user_ids
|
||||
|
||||
@staticmethod
|
||||
def _format_comment_excerpt(content: str, max_length: int = 200) -> str:
|
||||
"""Trim comment content for email display."""
|
||||
trimmed = content.strip()
|
||||
if len(trimmed) <= max_length:
|
||||
return trimmed
|
||||
if max_length <= 3:
|
||||
return trimmed[:max_length]
|
||||
return f"{trimmed[: max_length - 3].rstrip()}..."
|
||||
|
||||
@staticmethod
|
||||
def _build_mention_email_payloads(
|
||||
session: Session,
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
mentioner_id: str,
|
||||
mentioned_user_ids: Sequence[str],
|
||||
content: str,
|
||||
) -> list[dict[str, str]]:
|
||||
"""Prepare email payloads for mentioned users, including the workflow app link."""
|
||||
if not mentioned_user_ids:
|
||||
return []
|
||||
|
||||
candidate_user_ids = [user_id for user_id in mentioned_user_ids if user_id != mentioner_id]
|
||||
if not candidate_user_ids:
|
||||
return []
|
||||
|
||||
app_name = session.scalar(select(App.name).where(App.id == app_id, App.tenant_id == tenant_id)) or "Dify app"
|
||||
commenter_name = session.scalar(select(Account.name).where(Account.id == mentioner_id)) or "Dify user"
|
||||
comment_excerpt = WorkflowCommentService._format_comment_excerpt(content)
|
||||
base_url = dify_config.CONSOLE_WEB_URL.rstrip("/")
|
||||
app_url = f"{base_url}/app/{app_id}/workflow"
|
||||
|
||||
accounts = session.scalars(
|
||||
select(Account)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == tenant_id, Account.id.in_(candidate_user_ids))
|
||||
).all()
|
||||
|
||||
payloads: list[dict[str, str]] = []
|
||||
for account in accounts:
|
||||
payloads.append(
|
||||
{
|
||||
"language": account.interface_language or "en-US",
|
||||
"to": account.email,
|
||||
"mentioned_name": account.name or account.email,
|
||||
"commenter_name": commenter_name,
|
||||
"app_name": app_name,
|
||||
"comment_content": comment_excerpt,
|
||||
"app_url": app_url,
|
||||
}
|
||||
)
|
||||
return payloads
|
||||
|
||||
@staticmethod
|
||||
def _dispatch_mention_emails(payloads: Sequence[dict[str, str]]) -> None:
|
||||
"""Enqueue mention notification emails."""
|
||||
for payload in payloads:
|
||||
send_workflow_comment_mention_email_task.delay(**payload)
|
||||
|
||||
@staticmethod
|
||||
def get_comments(tenant_id: str, app_id: str) -> Sequence[WorkflowComment]:
|
||||
"""Get all comments for a workflow."""
|
||||
with Session(db.engine) as session:
|
||||
# Get all comments with eager loading
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(WorkflowComment.tenant_id == tenant_id, WorkflowComment.app_id == app_id)
|
||||
.order_by(desc(WorkflowComment.created_at))
|
||||
)
|
||||
|
||||
comments = session.scalars(stmt).all()
|
||||
|
||||
# Batch preload all Account objects to avoid N+1 queries
|
||||
WorkflowCommentService._preload_accounts(session, comments)
|
||||
|
||||
return comments
|
||||
|
||||
@staticmethod
|
||||
def _preload_accounts(session: Session, comments: Sequence[WorkflowComment]) -> None:
|
||||
"""Batch preload Account objects for comments, replies, and mentions."""
|
||||
# Collect all user IDs
|
||||
user_ids: set[str] = set()
|
||||
for comment in comments:
|
||||
user_ids.add(comment.created_by)
|
||||
if comment.resolved_by:
|
||||
user_ids.add(comment.resolved_by)
|
||||
user_ids.update(reply.created_by for reply in comment.replies)
|
||||
user_ids.update(mention.mentioned_user_id for mention in comment.mentions)
|
||||
|
||||
if not user_ids:
|
||||
return
|
||||
|
||||
# Batch query all accounts
|
||||
accounts = session.scalars(select(Account).where(Account.id.in_(user_ids))).all()
|
||||
account_map = {str(account.id): account for account in accounts}
|
||||
|
||||
# Cache accounts on objects
|
||||
for comment in comments:
|
||||
comment.cache_created_by_account(account_map.get(comment.created_by))
|
||||
comment.cache_resolved_by_account(account_map.get(comment.resolved_by) if comment.resolved_by else None)
|
||||
for reply in comment.replies:
|
||||
reply.cache_created_by_account(account_map.get(reply.created_by))
|
||||
for mention in comment.mentions:
|
||||
mention.cache_mentioned_user_account(account_map.get(mention.mentioned_user_id))
|
||||
|
||||
@staticmethod
|
||||
def get_comment(tenant_id: str, app_id: str, comment_id: str, session: Session | None = None) -> WorkflowComment:
|
||||
"""Get a specific comment."""
|
||||
|
||||
def _get_comment(session: Session) -> WorkflowComment:
|
||||
stmt = (
|
||||
select(WorkflowComment)
|
||||
.options(selectinload(WorkflowComment.replies), selectinload(WorkflowComment.mentions))
|
||||
.where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
# Preload accounts to avoid N+1 queries
|
||||
WorkflowCommentService._preload_accounts(session, [comment])
|
||||
|
||||
return comment
|
||||
|
||||
if session is not None:
|
||||
return _get_comment(session)
|
||||
else:
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
return _get_comment(session)
|
||||
|
||||
@staticmethod
|
||||
def create_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
created_by: str,
|
||||
content: str,
|
||||
position_x: float,
|
||||
position_y: float,
|
||||
mentioned_user_ids: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Create a new workflow comment and send mention notification emails."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine) as session:
|
||||
comment = WorkflowComment(
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
position_x=position_x,
|
||||
position_y=position_y,
|
||||
content=content,
|
||||
created_by=created_by,
|
||||
)
|
||||
|
||||
session.add(comment)
|
||||
session.flush() # Get the comment ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
|
||||
for user_id in mentioned_user_ids:
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention, not reply mention
|
||||
mentioned_user_id=user_id,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
mentioner_id=created_by,
|
||||
mentioned_user_ids=mentioned_user_ids,
|
||||
content=content,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
|
||||
|
||||
# Return only what we need - id and created_at
|
||||
return {"id": comment.id, "created_at": comment.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_comment(
|
||||
tenant_id: str,
|
||||
app_id: str,
|
||||
comment_id: str,
|
||||
user_id: str,
|
||||
content: str,
|
||||
position_x: float | None = None,
|
||||
position_y: float | None = None,
|
||||
mentioned_user_ids: list[str] | None = None,
|
||||
) -> dict:
|
||||
"""Update a workflow comment and notify newly mentioned users."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Get comment with validation
|
||||
stmt = select(WorkflowComment).where(
|
||||
WorkflowComment.id == comment_id,
|
||||
WorkflowComment.tenant_id == tenant_id,
|
||||
WorkflowComment.app_id == app_id,
|
||||
)
|
||||
comment = session.scalar(stmt)
|
||||
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
# Only the creator can update the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can update it")
|
||||
|
||||
# Update comment fields
|
||||
comment.content = content
|
||||
if position_x is not None:
|
||||
comment.position_x = position_x
|
||||
if position_y is not None:
|
||||
comment.position_y = position_y
|
||||
|
||||
# Update mentions - first remove existing mentions for this comment only (not replies)
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(
|
||||
WorkflowCommentMention.comment_id == comment.id,
|
||||
WorkflowCommentMention.reply_id.is_(None), # Only comment mentions, not reply mentions
|
||||
)
|
||||
).all()
|
||||
existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add new mentions
|
||||
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
|
||||
new_mentioned_user_ids = [
|
||||
user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
|
||||
]
|
||||
for user_id_str in mentioned_user_ids:
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=comment.id,
|
||||
reply_id=None, # This is a comment mention
|
||||
mentioned_user_id=user_id_str,
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
|
||||
session=session,
|
||||
tenant_id=tenant_id,
|
||||
app_id=app_id,
|
||||
mentioner_id=user_id,
|
||||
mentioned_user_ids=new_mentioned_user_ids,
|
||||
content=content,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
|
||||
|
||||
return {"id": comment.id, "updated_at": comment.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> None:
|
||||
"""Delete a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
|
||||
# Only the creator can delete the comment
|
||||
if comment.created_by != user_id:
|
||||
raise Forbidden("Only the comment creator can delete it")
|
||||
|
||||
# Delete associated mentions (both comment and reply mentions)
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.comment_id == comment_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Delete associated replies
|
||||
replies = session.scalars(
|
||||
select(WorkflowCommentReply).where(WorkflowCommentReply.comment_id == comment_id)
|
||||
).all()
|
||||
for reply in replies:
|
||||
session.delete(reply)
|
||||
|
||||
session.delete(comment)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def resolve_comment(tenant_id: str, app_id: str, comment_id: str, user_id: str) -> WorkflowComment:
|
||||
"""Resolve a workflow comment."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
comment = WorkflowCommentService.get_comment(tenant_id, app_id, comment_id, session)
|
||||
if comment.resolved:
|
||||
return comment
|
||||
|
||||
comment.resolved = True
|
||||
comment.resolved_at = naive_utc_now()
|
||||
comment.resolved_by = user_id
|
||||
session.commit()
|
||||
|
||||
return comment
|
||||
|
||||
@staticmethod
|
||||
def create_reply(
|
||||
comment_id: str, content: str, created_by: str, mentioned_user_ids: list[str] | None = None
|
||||
) -> dict:
|
||||
"""Add a reply to a workflow comment and notify mentioned users."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
# Check if comment exists
|
||||
comment = session.get(WorkflowComment, comment_id)
|
||||
if not comment:
|
||||
raise NotFound("Comment not found")
|
||||
|
||||
reply = WorkflowCommentReply(comment_id=comment_id, content=content, created_by=created_by)
|
||||
|
||||
session.add(reply)
|
||||
session.flush() # Get the reply ID for mentions
|
||||
|
||||
# Create mentions if specified
|
||||
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
|
||||
for user_id in mentioned_user_ids:
|
||||
# Create mention linking to specific reply
|
||||
mention = WorkflowCommentMention(comment_id=comment_id, reply_id=reply.id, mentioned_user_id=user_id)
|
||||
session.add(mention)
|
||||
|
||||
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
|
||||
session=session,
|
||||
tenant_id=comment.tenant_id,
|
||||
app_id=comment.app_id,
|
||||
mentioner_id=created_by,
|
||||
mentioned_user_ids=mentioned_user_ids,
|
||||
content=content,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
|
||||
|
||||
return {"id": reply.id, "created_at": reply.created_at}
|
||||
|
||||
@staticmethod
|
||||
def update_reply(reply_id: str, user_id: str, content: str, mentioned_user_ids: list[str] | None = None) -> dict:
|
||||
"""Update a comment reply and notify newly mentioned users."""
|
||||
WorkflowCommentService._validate_content(content)
|
||||
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can update the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can update it")
|
||||
|
||||
reply.content = content
|
||||
|
||||
# Update mentions - first remove existing mentions for this reply
|
||||
existing_mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply.id)
|
||||
).all()
|
||||
existing_mentioned_user_ids = {mention.mentioned_user_id for mention in existing_mentions}
|
||||
for mention in existing_mentions:
|
||||
session.delete(mention)
|
||||
|
||||
# Add mentions
|
||||
mentioned_user_ids = WorkflowCommentService._filter_valid_mentioned_user_ids(mentioned_user_ids or [])
|
||||
new_mentioned_user_ids = [
|
||||
user_id for user_id in mentioned_user_ids if user_id not in existing_mentioned_user_ids
|
||||
]
|
||||
for user_id_str in mentioned_user_ids:
|
||||
mention = WorkflowCommentMention(
|
||||
comment_id=reply.comment_id, reply_id=reply.id, mentioned_user_id=user_id_str
|
||||
)
|
||||
session.add(mention)
|
||||
|
||||
mention_email_payloads: list[dict[str, str]] = []
|
||||
comment = session.get(WorkflowComment, reply.comment_id)
|
||||
if comment:
|
||||
mention_email_payloads = WorkflowCommentService._build_mention_email_payloads(
|
||||
session=session,
|
||||
tenant_id=comment.tenant_id,
|
||||
app_id=comment.app_id,
|
||||
mentioner_id=user_id,
|
||||
mentioned_user_ids=new_mentioned_user_ids,
|
||||
content=content,
|
||||
)
|
||||
|
||||
session.commit()
|
||||
session.refresh(reply) # Refresh to get updated timestamp
|
||||
WorkflowCommentService._dispatch_mention_emails(mention_email_payloads)
|
||||
|
||||
return {"id": reply.id, "updated_at": reply.updated_at}
|
||||
|
||||
@staticmethod
|
||||
def delete_reply(reply_id: str, user_id: str) -> None:
|
||||
"""Delete a comment reply."""
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
reply = session.get(WorkflowCommentReply, reply_id)
|
||||
if not reply:
|
||||
raise NotFound("Reply not found")
|
||||
|
||||
# Only the creator can delete the reply
|
||||
if reply.created_by != user_id:
|
||||
raise Forbidden("Only the reply creator can delete it")
|
||||
|
||||
# Delete associated mentions first
|
||||
mentions = session.scalars(
|
||||
select(WorkflowCommentMention).where(WorkflowCommentMention.reply_id == reply_id)
|
||||
).all()
|
||||
for mention in mentions:
|
||||
session.delete(mention)
|
||||
|
||||
session.delete(reply)
|
||||
session.commit()
|
||||
|
||||
@staticmethod
|
||||
def validate_comment_access(comment_id: str, tenant_id: str, app_id: str) -> WorkflowComment:
|
||||
"""Validate that a comment belongs to the specified tenant and app."""
|
||||
return WorkflowCommentService.get_comment(tenant_id, app_id, comment_id)
|
||||
@@ -1423,7 +1423,7 @@ class WorkflowService:
|
||||
|
||||
def validate_features_structure(self, app_model: App, features: dict):
|
||||
match app_model.mode:
|
||||
case AppMode.ADVANCED_CHAT:
|
||||
case AppMode.ADVANCED_CHAT | AppMode.AGENT:
|
||||
return AdvancedChatAppConfigManager.config_validate(
|
||||
tenant_id=app_model.tenant_id, config=features, only_structure_validate=True
|
||||
)
|
||||
|
||||
@@ -183,7 +183,13 @@ class _AppRunner:
|
||||
pause_state_config: PauseStateLayerConfig,
|
||||
):
|
||||
exec_params = self._exec_params
|
||||
if exec_params.app_mode == AppMode.ADVANCED_CHAT:
|
||||
if exec_params.app_mode in {
|
||||
AppMode.ADVANCED_CHAT,
|
||||
AppMode.AGENT,
|
||||
AppMode.CHAT,
|
||||
AppMode.AGENT_CHAT,
|
||||
AppMode.COMPLETION,
|
||||
}:
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
|
||||
65
api/tasks/mail_workflow_comment_task.py
Normal file
65
api/tasks/mail_workflow_comment_task.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
import click
|
||||
from celery import shared_task
|
||||
|
||||
from extensions.ext_mail import mail
|
||||
from libs.email_i18n import EmailType, get_email_i18n_service
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@shared_task(queue="mail")
|
||||
def send_workflow_comment_mention_email_task(
|
||||
language: str,
|
||||
to: str,
|
||||
mentioned_name: str,
|
||||
commenter_name: str,
|
||||
app_name: str,
|
||||
comment_content: str,
|
||||
app_url: str,
|
||||
):
|
||||
"""
|
||||
Send workflow comment mention email with internationalization support.
|
||||
|
||||
Args:
|
||||
language: Language code for email localization
|
||||
to: Recipient email address
|
||||
mentioned_name: Name of the mentioned user
|
||||
commenter_name: Name of the comment author
|
||||
app_name: Name of the app where the comment was made
|
||||
comment_content: Comment content excerpt
|
||||
app_url: Link to the app workflow page
|
||||
"""
|
||||
if not mail.is_inited():
|
||||
return
|
||||
|
||||
logger.info(click.style(f"Start workflow comment mention mail to {to}", fg="green"))
|
||||
start_at = time.perf_counter()
|
||||
|
||||
try:
|
||||
email_service = get_email_i18n_service()
|
||||
email_service.send_email(
|
||||
email_type=EmailType.WORKFLOW_COMMENT_MENTION,
|
||||
language_code=language,
|
||||
to=to,
|
||||
template_context={
|
||||
"to": to,
|
||||
"mentioned_name": mentioned_name,
|
||||
"commenter_name": commenter_name,
|
||||
"app_name": app_name,
|
||||
"comment_content": comment_content,
|
||||
"app_url": app_url,
|
||||
},
|
||||
)
|
||||
|
||||
end_at = time.perf_counter()
|
||||
logger.info(
|
||||
click.style(
|
||||
f"Send workflow comment mention mail to {to} succeeded: latency: {end_at - start_at}",
|
||||
fg="green",
|
||||
)
|
||||
)
|
||||
except Exception:
|
||||
logger.exception("workflow comment mention email to %s failed", to)
|
||||
@@ -100,6 +100,8 @@ class TestAppGenerateService:
|
||||
mock_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DEFAULT_ACTIVE_REQUESTS = 100
|
||||
mock_dify_config.APP_DAILY_RATE_LIMIT = 1000
|
||||
mock_dify_config.AGENT_V2_TRANSPARENT_UPGRADE = False
|
||||
mock_dify_config.AGENT_V2_REPLACES_LLM = False
|
||||
|
||||
mock_global_dify_config.BILLING_ENABLED = False
|
||||
mock_global_dify_config.APP_MAX_ACTIVE_REQUESTS = 100
|
||||
|
||||
@@ -146,7 +146,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.sessionmaker", lambda **kwargs: SimpleNamespace()
|
||||
@@ -576,7 +576,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Workflow not found"):
|
||||
@@ -640,7 +640,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="App not found"):
|
||||
@@ -713,7 +713,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
generator._generate_worker(
|
||||
@@ -797,7 +797,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.AdvancedChatAppRunner", _Runner)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
generator._generate_worker(
|
||||
@@ -878,7 +878,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.dify_config", SimpleNamespace(DEBUG=True))
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
generator._generate_worker(
|
||||
@@ -1069,7 +1069,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
monkeypatch.setattr("core.app.apps.advanced_chat.app_generator.Session", _Session)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
|
||||
generator._generate_worker(
|
||||
@@ -1131,7 +1131,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.sessionmaker",
|
||||
@@ -1210,7 +1210,7 @@ class TestAdvancedChatAppGeneratorInternals:
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.db",
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None)),
|
||||
SimpleNamespace(engine=object(), session=SimpleNamespace(close=lambda: None, scalar=lambda *a, **kw: None)),
|
||||
)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.advanced_chat.app_generator.sessionmaker",
|
||||
|
||||
@@ -136,8 +136,8 @@ class TestAgentChatAppRunnerRun:
|
||||
@pytest.mark.parametrize(
|
||||
("mode", "expected_runner"),
|
||||
[
|
||||
(LLMMode.CHAT, "CotChatAgentRunner"),
|
||||
(LLMMode.COMPLETION, "CotCompletionAgentRunner"),
|
||||
(LLMMode.CHAT, "AgentAppRunner"),
|
||||
(LLMMode.COMPLETION, "AgentAppRunner"),
|
||||
],
|
||||
)
|
||||
def test_run_chain_of_thought_modes(self, runner, mocker, mode, expected_runner):
|
||||
@@ -196,7 +196,8 @@ class TestAgentChatAppRunnerRun:
|
||||
runner_instance.run.assert_called_once()
|
||||
runner._handle_invoke_result.assert_called_once()
|
||||
|
||||
def test_run_invalid_llm_mode_raises(self, runner, mocker):
|
||||
def test_run_invalid_llm_mode_proceeds(self, runner, mocker):
|
||||
"""With unified AgentAppRunner, invalid LLM mode no longer raises ValueError."""
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = AgentEntity(provider="p", model="m", strategy=AgentEntity.Strategy.CHAIN_OF_THOUGHT)
|
||||
@@ -239,8 +240,16 @@ class TestAgentChatAppRunnerRun:
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
def test_run_function_calling_strategy_selected_by_features(self, runner, mocker):
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
@@ -286,7 +295,7 @@ class TestAgentChatAppRunnerRun:
|
||||
)
|
||||
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.FunctionCallAgentRunner", runner_cls)
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)
|
||||
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
@@ -366,7 +375,8 @@ class TestAgentChatAppRunnerRun:
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), mocker.MagicMock(id="conv"), mocker.MagicMock(id="msg"))
|
||||
|
||||
def test_run_invalid_agent_strategy_raises(self, runner, mocker):
|
||||
def test_run_invalid_agent_strategy_defaults_to_react(self, runner, mocker):
|
||||
"""With StrategyFactory, invalid strategy defaults to ReAct instead of raising ValueError."""
|
||||
app_record = mocker.MagicMock(id="app1", tenant_id="tenant")
|
||||
app_config = mocker.MagicMock(app_id="app1", tenant_id="tenant", prompt_template=mocker.MagicMock())
|
||||
app_config.agent = mocker.MagicMock(strategy="invalid", provider="p", model="m")
|
||||
@@ -409,5 +419,13 @@ class TestAgentChatAppRunnerRun:
|
||||
side_effect=[app_record, conversation, message],
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
runner_cls = mocker.MagicMock()
|
||||
mocker.patch("core.app.apps.agent_chat.app_runner.AgentAppRunner", runner_cls)
|
||||
runner_instance = mocker.MagicMock()
|
||||
runner_cls.return_value = runner_instance
|
||||
runner_instance.run.return_value = []
|
||||
mocker.patch.object(runner, "_handle_invoke_result")
|
||||
|
||||
runner.run(generate_entity, mocker.MagicMock(), conversation, message)
|
||||
|
||||
runner_instance.run.assert_called_once()
|
||||
|
||||
@@ -0,0 +1,332 @@
|
||||
"""Basic tests for Agent V2 node — Phase 1 + 2 validation.
|
||||
|
||||
Tests:
|
||||
1. Module imports resolve without errors
|
||||
2. AgentV2Node self-registers in the graphon Node registry
|
||||
3. DifyNodeFactory kwargs mapping includes agent-v2
|
||||
4. StrategyFactory selects correct strategy based on model features
|
||||
5. AgentV2NodeData validates with and without tools
|
||||
"""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestPhase1Imports:
|
||||
"""Verify Phase 1 (Agent Patterns) modules import correctly."""
|
||||
|
||||
def test_entities_import(self):
|
||||
from core.agent.entities import AgentLog, AgentResult, ExecutionContext
|
||||
|
||||
assert ExecutionContext is not None
|
||||
assert AgentLog is not None
|
||||
assert AgentResult is not None
|
||||
|
||||
def test_entities_backward_compatible(self):
|
||||
from core.agent.entities import (
|
||||
AgentEntity,
|
||||
AgentInvokeMessage,
|
||||
AgentPromptEntity,
|
||||
AgentScratchpadUnit,
|
||||
AgentToolEntity,
|
||||
)
|
||||
|
||||
assert AgentEntity is not None
|
||||
assert AgentToolEntity is not None
|
||||
assert AgentPromptEntity is not None
|
||||
assert AgentScratchpadUnit is not None
|
||||
assert AgentInvokeMessage is not None
|
||||
|
||||
def test_patterns_module_import(self):
|
||||
from core.agent.patterns import (
|
||||
AgentPattern,
|
||||
FunctionCallStrategy,
|
||||
ReActStrategy,
|
||||
StrategyFactory,
|
||||
)
|
||||
|
||||
assert AgentPattern is not None
|
||||
assert FunctionCallStrategy is not None
|
||||
assert ReActStrategy is not None
|
||||
assert StrategyFactory is not None
|
||||
|
||||
def test_patterns_inheritance(self):
|
||||
from core.agent.patterns import AgentPattern, FunctionCallStrategy, ReActStrategy
|
||||
|
||||
assert issubclass(FunctionCallStrategy, AgentPattern)
|
||||
assert issubclass(ReActStrategy, AgentPattern)
|
||||
|
||||
|
||||
class TestPhase2Imports:
|
||||
"""Verify Phase 2 (Agent V2 Node) modules import correctly."""
|
||||
|
||||
def test_entities_import(self):
|
||||
from core.workflow.nodes.agent_v2.entities import (
|
||||
AGENT_V2_NODE_TYPE,
|
||||
AgentV2NodeData,
|
||||
ContextConfig,
|
||||
ToolMetadata,
|
||||
VisionConfig,
|
||||
)
|
||||
|
||||
assert AGENT_V2_NODE_TYPE == "agent-v2"
|
||||
assert AgentV2NodeData is not None
|
||||
assert ToolMetadata is not None
|
||||
|
||||
def test_node_import(self):
|
||||
from core.workflow.nodes.agent_v2.node import AgentV2Node
|
||||
|
||||
assert AgentV2Node is not None
|
||||
assert AgentV2Node.node_type == "agent-v2"
|
||||
|
||||
def test_tool_manager_import(self):
|
||||
from core.workflow.nodes.agent_v2.tool_manager import AgentV2ToolManager
|
||||
|
||||
assert AgentV2ToolManager is not None
|
||||
|
||||
def test_event_adapter_import(self):
|
||||
from core.workflow.nodes.agent_v2.event_adapter import AgentV2EventAdapter
|
||||
|
||||
assert AgentV2EventAdapter is not None
|
||||
|
||||
|
||||
class TestNodeRegistration:
|
||||
"""Verify AgentV2Node self-registers in the graphon Node registry."""
|
||||
|
||||
def test_agent_v2_in_registry(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
assert "agent-v2" in registry, f"agent-v2 not found in registry. Available: {list(registry.keys())}"
|
||||
|
||||
def test_agent_v2_latest_version(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
agent_v2_versions = registry.get("agent-v2", {})
|
||||
assert "latest" in agent_v2_versions
|
||||
assert "1" in agent_v2_versions
|
||||
|
||||
from core.workflow.nodes.agent_v2.node import AgentV2Node
|
||||
|
||||
assert agent_v2_versions["latest"] is AgentV2Node
|
||||
assert agent_v2_versions["1"] is AgentV2Node
|
||||
|
||||
def test_old_agent_still_registered(self):
|
||||
"""Old Agent node must not be affected by Agent V2."""
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
assert "agent" in registry, "Old agent node must still be registered"
|
||||
|
||||
def test_resolve_workflow_node_class(self):
|
||||
from core.workflow.node_factory import register_nodes, resolve_workflow_node_class
|
||||
from core.workflow.nodes.agent_v2.node import AgentV2Node
|
||||
|
||||
register_nodes()
|
||||
|
||||
resolved = resolve_workflow_node_class(node_type="agent-v2", node_version="1")
|
||||
assert resolved is AgentV2Node
|
||||
|
||||
resolved_latest = resolve_workflow_node_class(node_type="agent-v2", node_version="latest")
|
||||
assert resolved_latest is AgentV2Node
|
||||
|
||||
|
||||
class TestNodeFactoryKwargs:
|
||||
"""Verify DifyNodeFactory includes agent-v2 in kwargs mapping."""
|
||||
|
||||
def test_agent_v2_node_type_in_factory(self):
|
||||
from core.workflow.node_factory import AGENT_V2_NODE_TYPE
|
||||
|
||||
assert AGENT_V2_NODE_TYPE == "agent-v2"
|
||||
|
||||
|
||||
class TestStrategyFactory:
|
||||
"""Verify StrategyFactory selects correct strategy."""
|
||||
|
||||
def test_fc_selected_for_tool_call_model(self):
|
||||
from graphon.model_runtime.entities.model_entities import ModelFeature
|
||||
|
||||
from core.agent.patterns import FunctionCallStrategy, StrategyFactory
|
||||
|
||||
assert ModelFeature.TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES
|
||||
assert ModelFeature.MULTI_TOOL_CALL in StrategyFactory.TOOL_CALL_FEATURES
|
||||
|
||||
def test_factory_has_create_strategy(self):
|
||||
from core.agent.patterns import StrategyFactory
|
||||
|
||||
assert callable(getattr(StrategyFactory, "create_strategy", None))
|
||||
|
||||
|
||||
class TestAgentV2NodeData:
|
||||
"""Verify AgentV2NodeData validation."""
|
||||
|
||||
def test_minimal_data(self):
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
data = AgentV2NodeData(
|
||||
title="Test Agent",
|
||||
model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
prompt_template=[{"role": "system", "text": "You are helpful."}, {"role": "user", "text": "Hello"}],
|
||||
context={"enabled": False},
|
||||
)
|
||||
assert data.type == "agent-v2"
|
||||
assert data.tool_call_enabled is False
|
||||
assert data.max_iterations == 10
|
||||
assert data.agent_strategy == "auto"
|
||||
|
||||
def test_data_with_tools(self):
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
data = AgentV2NodeData(
|
||||
title="Test Agent with Tools",
|
||||
model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
prompt_template=[{"role": "user", "text": "Search for {{query}}"}],
|
||||
context={"enabled": False},
|
||||
tools=[
|
||||
{
|
||||
"enabled": True,
|
||||
"type": "builtin",
|
||||
"provider_name": "google",
|
||||
"tool_name": "google_search",
|
||||
}
|
||||
],
|
||||
max_iterations=5,
|
||||
agent_strategy="function-calling",
|
||||
)
|
||||
assert data.tool_call_enabled is True
|
||||
assert data.max_iterations == 5
|
||||
assert data.agent_strategy == "function-calling"
|
||||
assert len(data.tools) == 1
|
||||
|
||||
def test_data_with_disabled_tools(self):
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
data = AgentV2NodeData(
|
||||
title="Test Agent",
|
||||
model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
prompt_template=[{"role": "user", "text": "Hello"}],
|
||||
context={"enabled": False},
|
||||
tools=[
|
||||
{
|
||||
"enabled": False,
|
||||
"type": "builtin",
|
||||
"provider_name": "google",
|
||||
"tool_name": "google_search",
|
||||
}
|
||||
],
|
||||
)
|
||||
assert data.tool_call_enabled is False
|
||||
|
||||
def test_data_with_memory(self):
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
data = AgentV2NodeData(
|
||||
title="Test Agent",
|
||||
model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
prompt_template=[{"role": "user", "text": "Hello"}],
|
||||
context={"enabled": False},
|
||||
memory={"window": {"enabled": True, "size": 50}},
|
||||
)
|
||||
assert data.memory is not None
|
||||
assert data.memory.window.enabled is True
|
||||
assert data.memory.window.size == 50
|
||||
|
||||
def test_data_with_vision(self):
|
||||
from core.workflow.nodes.agent_v2.entities import AgentV2NodeData
|
||||
|
||||
data = AgentV2NodeData(
|
||||
title="Test Agent",
|
||||
model={"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
prompt_template=[{"role": "user", "text": "Hello"}],
|
||||
context={"enabled": False},
|
||||
vision={"enabled": True},
|
||||
)
|
||||
assert data.vision.enabled is True
|
||||
|
||||
|
||||
class TestExecutionContext:
|
||||
"""Verify ExecutionContext entity."""
|
||||
|
||||
def test_create_minimal(self):
|
||||
from core.agent.entities import ExecutionContext
|
||||
|
||||
ctx = ExecutionContext.create_minimal(user_id="user-123")
|
||||
assert ctx.user_id == "user-123"
|
||||
assert ctx.app_id is None
|
||||
|
||||
def test_to_dict(self):
|
||||
from core.agent.entities import ExecutionContext
|
||||
|
||||
ctx = ExecutionContext(user_id="u1", app_id="a1", tenant_id="t1")
|
||||
d = ctx.to_dict()
|
||||
assert d["user_id"] == "u1"
|
||||
assert d["app_id"] == "a1"
|
||||
assert d["tenant_id"] == "t1"
|
||||
assert d["conversation_id"] is None
|
||||
|
||||
def test_with_updates(self):
|
||||
from core.agent.entities import ExecutionContext
|
||||
|
||||
ctx = ExecutionContext(user_id="u1")
|
||||
ctx2 = ctx.with_updates(app_id="a1", conversation_id="c1")
|
||||
assert ctx2.user_id == "u1"
|
||||
assert ctx2.app_id == "a1"
|
||||
assert ctx2.conversation_id == "c1"
|
||||
|
||||
|
||||
class TestAgentLog:
|
||||
"""Verify AgentLog entity."""
|
||||
|
||||
def test_create_log(self):
|
||||
from core.agent.entities import AgentLog
|
||||
|
||||
log = AgentLog(
|
||||
label="Round 1",
|
||||
log_type=AgentLog.LogType.ROUND,
|
||||
status=AgentLog.LogStatus.START,
|
||||
data={"key": "value"},
|
||||
)
|
||||
assert log.id is not None
|
||||
assert log.label == "Round 1"
|
||||
assert log.log_type == "round"
|
||||
assert log.status == "start"
|
||||
assert log.parent_id is None
|
||||
|
||||
def test_log_types(self):
|
||||
from core.agent.entities import AgentLog
|
||||
|
||||
assert AgentLog.LogType.ROUND == "round"
|
||||
assert AgentLog.LogType.THOUGHT == "thought"
|
||||
assert AgentLog.LogType.TOOL_CALL == "tool_call"
|
||||
|
||||
|
||||
class TestAgentResult:
|
||||
"""Verify AgentResult entity."""
|
||||
|
||||
def test_default_result(self):
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
result = AgentResult()
|
||||
assert result.text == ""
|
||||
assert result.files == []
|
||||
assert result.usage is None
|
||||
assert result.finish_reason is None
|
||||
|
||||
def test_result_with_data(self):
|
||||
from core.agent.entities import AgentResult
|
||||
|
||||
result = AgentResult(text="Hello world", finish_reason="stop")
|
||||
assert result.text == "Hello world"
|
||||
assert result.finish_reason == "stop"
|
||||
@@ -0,0 +1,132 @@
|
||||
"""Tests for Phase 3 — Agent app type support."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAppModeAgent:
|
||||
"""Verify AppMode.AGENT is properly defined."""
|
||||
|
||||
def test_agent_mode_exists(self):
|
||||
from models.model import AppMode
|
||||
|
||||
assert hasattr(AppMode, "AGENT")
|
||||
assert AppMode.AGENT == "agent"
|
||||
|
||||
def test_agent_mode_value_of(self):
|
||||
from models.model import AppMode
|
||||
|
||||
mode = AppMode.value_of("agent")
|
||||
assert mode == AppMode.AGENT
|
||||
|
||||
def test_all_original_modes_still_work(self):
|
||||
from models.model import AppMode
|
||||
|
||||
for val in ["completion", "workflow", "chat", "advanced-chat", "agent-chat", "channel", "rag-pipeline"]:
|
||||
mode = AppMode.value_of(val)
|
||||
assert mode.value == val
|
||||
|
||||
|
||||
class TestDefaultAppTemplate:
|
||||
"""Verify AGENT template is defined."""
|
||||
|
||||
def test_agent_template_exists(self):
|
||||
from constants.model_template import default_app_templates
|
||||
from models.model import AppMode
|
||||
|
||||
assert AppMode.AGENT in default_app_templates
|
||||
template = default_app_templates[AppMode.AGENT]
|
||||
assert template["app"]["mode"] == AppMode.AGENT
|
||||
assert template["app"]["enable_site"] is True
|
||||
assert "model_config" in template
|
||||
|
||||
def test_all_original_templates_exist(self):
|
||||
from constants.model_template import default_app_templates
|
||||
from models.model import AppMode
|
||||
|
||||
for mode in [AppMode.WORKFLOW, AppMode.COMPLETION, AppMode.CHAT, AppMode.ADVANCED_CHAT, AppMode.AGENT_CHAT]:
|
||||
assert mode in default_app_templates
|
||||
|
||||
|
||||
class TestWorkflowGraphFactory:
|
||||
"""Verify WorkflowGraphFactory creates valid graphs."""
|
||||
|
||||
def test_create_chat_graph(self):
|
||||
from services.workflow.graph_factory import WorkflowGraphFactory
|
||||
|
||||
model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=True)
|
||||
|
||||
assert "nodes" in graph
|
||||
assert "edges" in graph
|
||||
assert len(graph["nodes"]) == 3
|
||||
assert len(graph["edges"]) == 2
|
||||
|
||||
node_types = [n["data"]["type"] for n in graph["nodes"]]
|
||||
assert "start" in node_types
|
||||
assert "agent-v2" in node_types
|
||||
assert "answer" in node_types
|
||||
|
||||
agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2")
|
||||
assert agent_node["data"]["model"] == model_config
|
||||
assert agent_node["data"]["memory"] is not None
|
||||
|
||||
def test_create_workflow_graph(self):
|
||||
from services.workflow.graph_factory import WorkflowGraphFactory
|
||||
|
||||
model_config = {"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}}
|
||||
graph = WorkflowGraphFactory.create_single_agent_graph(model_config, is_chat=False)
|
||||
|
||||
node_types = [n["data"]["type"] for n in graph["nodes"]]
|
||||
assert "end" in node_types
|
||||
assert "answer" not in node_types
|
||||
|
||||
agent_node = next(n for n in graph["nodes"] if n["data"]["type"] == "agent-v2")
|
||||
assert agent_node["data"].get("memory") is None
|
||||
|
||||
def test_edge_connectivity(self):
|
||||
from services.workflow.graph_factory import WorkflowGraphFactory
|
||||
|
||||
graph = WorkflowGraphFactory.create_single_agent_graph(
|
||||
{"provider": "openai", "name": "gpt-4o", "mode": "chat", "completion_params": {}},
|
||||
is_chat=True,
|
||||
)
|
||||
|
||||
edges = graph["edges"]
|
||||
sources = {e["source"] for e in edges}
|
||||
targets = {e["target"] for e in edges}
|
||||
assert "start" in sources
|
||||
assert "agent" in sources
|
||||
assert "agent" in targets
|
||||
assert "answer" in targets
|
||||
|
||||
|
||||
class TestConsoleAppController:
|
||||
"""Verify Console API allows 'agent' mode."""
|
||||
|
||||
def test_allow_create_app_modes(self):
|
||||
from controllers.console.app.app import ALLOW_CREATE_APP_MODES
|
||||
|
||||
assert "agent" in ALLOW_CREATE_APP_MODES
|
||||
assert "chat" in ALLOW_CREATE_APP_MODES
|
||||
assert "agent-chat" in ALLOW_CREATE_APP_MODES
|
||||
|
||||
|
||||
class TestAppGenerateServiceHasAgentCase:
|
||||
"""Verify the generate() method has an AppMode.AGENT case."""
|
||||
|
||||
def test_generate_method_exists(self):
|
||||
from services.app_generate_service import AppGenerateService
|
||||
|
||||
assert hasattr(AppGenerateService, "generate")
|
||||
|
||||
def test_agent_mode_import(self):
|
||||
"""Verify AppMode.AGENT can be used in match statement context."""
|
||||
from models.model import AppMode
|
||||
|
||||
mode = AppMode.AGENT
|
||||
match mode:
|
||||
case AppMode.AGENT:
|
||||
result = "agent"
|
||||
case _:
|
||||
result = "other"
|
||||
assert result == "agent"
|
||||
@@ -0,0 +1,115 @@
|
||||
"""Tests for Phase 7 — New/old agent node parallel compatibility."""
|
||||
|
||||
import pytest
|
||||
|
||||
|
||||
class TestAgentV2DefaultConfig:
|
||||
"""Verify Agent V2 node provides default block configuration."""
|
||||
|
||||
def test_has_default_config(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
agent_v2_cls = registry["agent-v2"]["latest"]
|
||||
config = agent_v2_cls.get_default_config()
|
||||
|
||||
assert config, "Agent V2 should have a default config"
|
||||
assert config["type"] == "agent-v2"
|
||||
assert "config" in config
|
||||
assert "prompt_templates" in config["config"]
|
||||
assert "agent_strategy" in config["config"]
|
||||
assert config["config"]["agent_strategy"] == "auto"
|
||||
assert config["config"]["max_iterations"] == 10
|
||||
|
||||
def test_old_agent_no_default_config(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
agent_cls = registry["agent"]["latest"]
|
||||
config = agent_cls.get_default_config()
|
||||
assert config == {} or config is None or not config
|
||||
|
||||
|
||||
class TestParallelNodeRegistration:
|
||||
"""Verify both agent and agent-v2 coexist in the registry."""
|
||||
|
||||
def test_both_registered(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
assert "agent" in registry
|
||||
assert "agent-v2" in registry
|
||||
|
||||
def test_different_classes(self):
|
||||
from core.workflow.node_factory import register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
from graphon.nodes.base.node import Node
|
||||
|
||||
registry = Node.get_node_type_classes_mapping()
|
||||
old_cls = registry["agent"]["latest"]
|
||||
new_cls = registry["agent-v2"]["latest"]
|
||||
assert old_cls is not new_cls
|
||||
|
||||
def test_default_configs_list_contains_agent_v2(self):
|
||||
"""Verify agent-v2 appears in the full default block configs list.
|
||||
|
||||
Instead of instantiating WorkflowService (which requires Flask/DB),
|
||||
we replicate the same iteration logic over the node registry.
|
||||
"""
|
||||
from core.workflow.node_factory import LATEST_VERSION, get_node_type_classes_mapping, register_nodes
|
||||
|
||||
register_nodes()
|
||||
|
||||
types_with_config: set[str] = set()
|
||||
for node_type, mapping in get_node_type_classes_mapping().items():
|
||||
node_cls = mapping.get(LATEST_VERSION)
|
||||
if node_cls:
|
||||
cfg = node_cls.get_default_config()
|
||||
if cfg and isinstance(cfg, dict):
|
||||
types_with_config.add(cfg.get("type", ""))
|
||||
|
||||
assert "agent-v2" in types_with_config
|
||||
|
||||
|
||||
class TestAgentModeWorkflowAccess:
|
||||
"""Verify AGENT mode is allowed in workflow-related API mode checks."""
|
||||
|
||||
def test_workflow_controller_allows_agent(self):
|
||||
"""Check that the workflow.py source allows AppMode.AGENT."""
|
||||
import inspect
|
||||
|
||||
from controllers.console.app import workflow
|
||||
|
||||
source = inspect.getsource(workflow)
|
||||
assert "AppMode.AGENT" in source
|
||||
|
||||
def test_service_api_chat_allows_agent(self):
|
||||
"""Check that service API chat endpoint allows AGENT mode."""
|
||||
import inspect
|
||||
|
||||
from controllers.service_api.app import completion
|
||||
|
||||
source = inspect.getsource(completion)
|
||||
assert "AppMode.AGENT" in source
|
||||
|
||||
def test_service_api_conversation_allows_agent(self):
|
||||
import inspect
|
||||
|
||||
from controllers.service_api.app import conversation
|
||||
|
||||
source = inspect.getsource(conversation)
|
||||
assert "AppMode.AGENT" in source
|
||||
@@ -97,6 +97,7 @@ class TestAppModelValidation:
|
||||
"workflow",
|
||||
"advanced-chat",
|
||||
"agent-chat",
|
||||
"agent",
|
||||
"channel",
|
||||
"rag-pipeline",
|
||||
}
|
||||
|
||||
@@ -140,10 +140,10 @@ const AppDetailLayout: FC<IAppDetailLayoutProps> = (props) => {
|
||||
router.replace(`/app/${appId}/overview`)
|
||||
return
|
||||
}
|
||||
if ((res.mode === AppModeEnum.WORKFLOW || res.mode === AppModeEnum.ADVANCED_CHAT) && (pathname).endsWith('configuration')) {
|
||||
if ((res.mode === AppModeEnum.WORKFLOW || res.mode === AppModeEnum.ADVANCED_CHAT || res.mode === AppModeEnum.AGENT) && (pathname).endsWith('configuration')) {
|
||||
router.replace(`/app/${appId}/workflow`)
|
||||
}
|
||||
else if ((res.mode !== AppModeEnum.WORKFLOW && res.mode !== AppModeEnum.ADVANCED_CHAT) && (pathname).endsWith('workflow')) {
|
||||
else if ((res.mode !== AppModeEnum.WORKFLOW && res.mode !== AppModeEnum.ADVANCED_CHAT && res.mode !== AppModeEnum.AGENT) && (pathname).endsWith('workflow')) {
|
||||
router.replace(`/app/${appId}/configuration`)
|
||||
}
|
||||
else {
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
'use client'
|
||||
|
||||
import type { AppIconSelection } from '../../base/app-icon-picker'
|
||||
import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill } from '@remixicon/react'
|
||||
import { RiArrowRightLine, RiArrowRightSLine, RiExchange2Fill, RiRobot2Fill } from '@remixicon/react'
|
||||
|
||||
import { useDebounceFn, useKeyPress } from 'ahooks'
|
||||
import { useCallback, useEffect, useRef, useState } from 'react'
|
||||
@@ -145,6 +145,19 @@ function CreateApp({ onClose, onSuccess, onCreateFromTemplate, defaultAppMode }:
|
||||
setAppMode(AppModeEnum.ADVANCED_CHAT)
|
||||
}}
|
||||
/>
|
||||
<AppTypeCard
|
||||
active={appMode === AppModeEnum.AGENT}
|
||||
title={t('types.agent', { ns: 'app' })}
|
||||
description={t('newApp.agentV2ShortDescription', { ns: 'app' })}
|
||||
icon={(
|
||||
<div className="flex h-6 w-6 items-center justify-center rounded-md bg-components-icon-bg-violet-solid">
|
||||
<RiRobot2Fill className="h-4 w-4 text-components-avatar-shape-fill-stop-100" />
|
||||
</div>
|
||||
)}
|
||||
onClick={() => {
|
||||
setAppMode(AppModeEnum.AGENT)
|
||||
}}
|
||||
/>
|
||||
</div>
|
||||
</div>
|
||||
<div>
|
||||
@@ -357,6 +370,10 @@ function AppPreview({ mode }: { mode: AppModeEnum }) {
|
||||
title: t('types.workflow', { ns: 'app' }),
|
||||
description: t('newApp.workflowUserDescription', { ns: 'app' }),
|
||||
},
|
||||
[AppModeEnum.AGENT]: {
|
||||
title: t('types.agent', { ns: 'app' }),
|
||||
description: t('newApp.agentV2ShortDescription', { ns: 'app' }),
|
||||
},
|
||||
}
|
||||
const previewInfo = modeToPreviewInfoMap[mode]
|
||||
return (
|
||||
@@ -377,6 +394,7 @@ function AppScreenShot({ mode, show }: { mode: AppModeEnum, show: boolean }) {
|
||||
[AppModeEnum.AGENT_CHAT]: 'Agent',
|
||||
[AppModeEnum.COMPLETION]: 'TextGenerator',
|
||||
[AppModeEnum.WORKFLOW]: 'Workflow',
|
||||
[AppModeEnum.AGENT]: 'Agent',
|
||||
}
|
||||
return (
|
||||
<picture>
|
||||
|
||||
@@ -67,6 +67,7 @@ const DEFAULT_ICON_MAP: Record<BlockEnum, React.ComponentType<{ className: strin
|
||||
[BlockEnum.DocExtractor]: DocsExtractor,
|
||||
[BlockEnum.ListFilter]: ListFilter,
|
||||
[BlockEnum.Agent]: Agent,
|
||||
[BlockEnum.AgentV2]: Agent,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBase,
|
||||
[BlockEnum.DataSource]: Datasource,
|
||||
[BlockEnum.DataSourceEmpty]: () => null,
|
||||
@@ -116,6 +117,7 @@ const ICON_CONTAINER_BG_COLOR_MAP: Record<string, string> = {
|
||||
[BlockEnum.DocExtractor]: 'bg-util-colors-green-green-500',
|
||||
[BlockEnum.ListFilter]: 'bg-util-colors-cyan-cyan-500',
|
||||
[BlockEnum.Agent]: 'bg-util-colors-indigo-indigo-500',
|
||||
[BlockEnum.AgentV2]: 'bg-util-colors-violet-violet-500',
|
||||
[BlockEnum.HumanInput]: 'bg-util-colors-cyan-cyan-500',
|
||||
[BlockEnum.KnowledgeBase]: 'bg-util-colors-warning-warning-500',
|
||||
[BlockEnum.DataSource]: 'bg-components-icon-bg-midnight-solid',
|
||||
|
||||
@@ -51,6 +51,7 @@ const singleRunFormParamsHooks: Record<BlockEnum, any> = {
|
||||
[BlockEnum.ParameterExtractor]: useParameterExtractorSingleRunFormParams,
|
||||
[BlockEnum.Iteration]: useIterationSingleRunFormParams,
|
||||
[BlockEnum.Agent]: useAgentSingleRunFormParams,
|
||||
[BlockEnum.AgentV2]: undefined,
|
||||
[BlockEnum.DocExtractor]: useDocExtractorSingleRunFormParams,
|
||||
[BlockEnum.Loop]: useLoopSingleRunFormParams,
|
||||
[BlockEnum.Start]: useStartSingleRunFormParams,
|
||||
@@ -90,6 +91,7 @@ const getDataForCheckMoreHooks: Record<BlockEnum, any> = {
|
||||
[BlockEnum.ParameterExtractor]: undefined,
|
||||
[BlockEnum.Iteration]: undefined,
|
||||
[BlockEnum.Agent]: undefined,
|
||||
[BlockEnum.AgentV2]: undefined,
|
||||
[BlockEnum.DocExtractor]: undefined,
|
||||
[BlockEnum.Loop]: undefined,
|
||||
[BlockEnum.Start]: undefined,
|
||||
|
||||
61
web/app/components/workflow/nodes/agent-v2/node.tsx
Normal file
61
web/app/components/workflow/nodes/agent-v2/node.tsx
Normal file
@@ -0,0 +1,61 @@
|
||||
import type { FC } from 'react'
|
||||
import type { NodeProps } from '../../types'
|
||||
import type { AgentV2NodeType } from './types'
|
||||
import { memo, useMemo } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import { RiRobot2Line, RiToolsFill } from '@remixicon/react'
|
||||
import { Group, GroupLabel } from '../_base/components/group'
|
||||
import { SettingItem } from '../_base/components/setting-item'
|
||||
|
||||
const strategyLabels: Record<string, string> = {
|
||||
auto: 'Auto',
|
||||
'function-calling': 'Function Calling',
|
||||
'chain-of-thought': 'ReAct (Chain of Thought)',
|
||||
}
|
||||
|
||||
const AgentV2Node: FC<NodeProps<AgentV2NodeType>> = ({ id, data }) => {
|
||||
const { t } = useTranslation()
|
||||
|
||||
const modelName = data.model?.name || ''
|
||||
const modelProvider = data.model?.provider || ''
|
||||
const strategy = data.agent_strategy || 'auto'
|
||||
const enabledTools = useMemo(() => (data.tools || []).filter(t => t.enabled), [data.tools])
|
||||
const maxIter = data.max_iterations || 10
|
||||
|
||||
return (
|
||||
<div className="mb-1 space-y-1 px-3">
|
||||
<SettingItem label={t('workflow.nodes.llm.model')}>
|
||||
<span className="system-xs-medium text-text-secondary truncate">
|
||||
{modelName || 'Not configured'}
|
||||
</span>
|
||||
</SettingItem>
|
||||
<SettingItem label="Strategy">
|
||||
<span className="system-xs-medium text-text-secondary">
|
||||
{strategyLabels[strategy] || strategy}
|
||||
</span>
|
||||
</SettingItem>
|
||||
{enabledTools.length > 0 && (
|
||||
<Group label={<GroupLabel className="mt-1"><RiToolsFill className="mr-1 inline h-3 w-3" />Tools ({enabledTools.length})</GroupLabel>}>
|
||||
<div className="flex flex-wrap gap-1">
|
||||
{enabledTools.slice(0, 6).map((tool, i) => (
|
||||
<span key={i} className="inline-flex items-center rounded bg-components-badge-bg-gray px-1.5 py-0.5 text-[11px] text-text-tertiary">
|
||||
{tool.tool_name}
|
||||
</span>
|
||||
))}
|
||||
{enabledTools.length > 6 && (
|
||||
<span className="text-[11px] text-text-quaternary">+{enabledTools.length - 6}</span>
|
||||
)}
|
||||
</div>
|
||||
</Group>
|
||||
)}
|
||||
{maxIter !== 10 && (
|
||||
<SettingItem label="Max Iterations">
|
||||
<span className="system-xs-medium text-text-secondary">{maxIter}</span>
|
||||
</SettingItem>
|
||||
)}
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
AgentV2Node.displayName = 'AgentV2Node'
|
||||
export default memo(AgentV2Node)
|
||||
139
web/app/components/workflow/nodes/agent-v2/panel.tsx
Normal file
139
web/app/components/workflow/nodes/agent-v2/panel.tsx
Normal file
@@ -0,0 +1,139 @@
|
||||
import type { FC } from 'react'
|
||||
import type { AgentV2NodeType } from './types'
|
||||
import type { NodePanelProps } from '@/app/components/workflow/types'
|
||||
import { memo, useCallback } from 'react'
|
||||
import { useTranslation } from 'react-i18next'
|
||||
import Field from '@/app/components/workflow/nodes/_base/components/field'
|
||||
import Split from '@/app/components/workflow/nodes/_base/components/split'
|
||||
import { useNodeDataUpdate } from '../../hooks/use-node-data-update'
|
||||
|
||||
const strategyOptions = [
|
||||
{ value: 'auto', label: 'Auto (based on model capability)' },
|
||||
{ value: 'function-calling', label: 'Function Calling' },
|
||||
{ value: 'chain-of-thought', label: 'ReAct (Chain of Thought)' },
|
||||
]
|
||||
|
||||
const Panel: FC<NodePanelProps<AgentV2NodeType>> = ({ id, data }) => {
|
||||
const { t } = useTranslation()
|
||||
const { handleNodeDataUpdate } = useNodeDataUpdate()
|
||||
|
||||
const updateData = useCallback((patch: Partial<AgentV2NodeType>) => {
|
||||
handleNodeDataUpdate({ id, data: patch as any })
|
||||
}, [id, handleNodeDataUpdate])
|
||||
|
||||
const inputs = data as AgentV2NodeType
|
||||
|
||||
return (
|
||||
<div className="space-y-4 px-4 pb-4 pt-2">
|
||||
{/* Model */}
|
||||
<Field title={t('workflow.nodes.llm.model')}>
|
||||
<div className="rounded-lg border border-divider-subtle px-3 py-2 text-[13px] text-text-secondary">
|
||||
{inputs.model?.name
|
||||
? `${inputs.model.provider?.split('/').pop()} / ${inputs.model.name}`
|
||||
: 'Not configured'}
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
<Split />
|
||||
|
||||
{/* Strategy */}
|
||||
<Field title="Agent Strategy">
|
||||
<select
|
||||
className="w-full rounded-lg border border-components-input-border-active bg-transparent px-3 py-1.5 text-[13px] text-text-secondary"
|
||||
value={inputs.agent_strategy || 'auto'}
|
||||
onChange={e => updateData({ agent_strategy: e.target.value as any })}
|
||||
>
|
||||
{strategyOptions.map(opt => (
|
||||
<option key={opt.value} value={opt.value}>{opt.label}</option>
|
||||
))}
|
||||
</select>
|
||||
</Field>
|
||||
|
||||
{/* Max Iterations */}
|
||||
<Field title="Max Iterations">
|
||||
<input
|
||||
type="number"
|
||||
min={1}
|
||||
max={99}
|
||||
className="w-full rounded-lg border border-components-input-border-active bg-transparent px-3 py-1.5 text-[13px] text-text-secondary"
|
||||
value={inputs.max_iterations || 10}
|
||||
onChange={e => updateData({ max_iterations: parseInt(e.target.value) || 10 })}
|
||||
/>
|
||||
</Field>
|
||||
|
||||
<Split />
|
||||
|
||||
{/* Tools */}
|
||||
<Field title={`Tools (${(inputs.tools || []).filter(t => t.enabled).length})`}>
|
||||
<div className="space-y-2">
|
||||
{(inputs.tools || []).map((tool, idx) => (
|
||||
<div key={idx} className="flex items-center justify-between rounded-lg border border-divider-subtle px-3 py-2">
|
||||
<div className="flex items-center gap-2">
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={tool.enabled}
|
||||
onChange={e => {
|
||||
const tools = [...(inputs.tools || [])]
|
||||
tools[idx] = { ...tools[idx], enabled: e.target.checked }
|
||||
updateData({ tools })
|
||||
}}
|
||||
className="h-4 w-4"
|
||||
/>
|
||||
<span className="text-[13px] text-text-secondary">{tool.tool_name}</span>
|
||||
</div>
|
||||
<span className="text-[11px] text-text-quaternary">{tool.provider_name?.split('/').pop()}</span>
|
||||
</div>
|
||||
))}
|
||||
{(inputs.tools || []).length === 0 && (
|
||||
<div className="py-3 text-center text-[13px] text-text-quaternary">
|
||||
No tools configured
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
<Split />
|
||||
|
||||
{/* Memory */}
|
||||
<Field title="Memory">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[13px] text-text-secondary">Window Size</span>
|
||||
<input
|
||||
type="number"
|
||||
min={1}
|
||||
max={200}
|
||||
className="w-20 rounded-lg border border-components-input-border-active bg-transparent px-2 py-1 text-center text-[13px] text-text-secondary"
|
||||
value={inputs.memory?.window?.size || 50}
|
||||
onChange={e => updateData({
|
||||
memory: {
|
||||
role_prefix: inputs.memory?.role_prefix,
|
||||
query_prompt_template: inputs.memory?.query_prompt_template,
|
||||
window: { enabled: true, size: parseInt(e.target.value) || 50 },
|
||||
},
|
||||
})}
|
||||
/>
|
||||
</div>
|
||||
</Field>
|
||||
|
||||
<Split />
|
||||
|
||||
{/* Vision */}
|
||||
<Field title="Vision">
|
||||
<div className="flex items-center justify-between">
|
||||
<span className="text-[13px] text-text-secondary">Enable image understanding</span>
|
||||
<input
|
||||
type="checkbox"
|
||||
checked={inputs.vision?.enabled || false}
|
||||
onChange={e => updateData({
|
||||
vision: { ...inputs.vision, enabled: e.target.checked },
|
||||
})}
|
||||
className="h-4 w-4"
|
||||
/>
|
||||
</div>
|
||||
</Field>
|
||||
</div>
|
||||
)
|
||||
}
|
||||
|
||||
Panel.displayName = 'AgentV2Panel'
|
||||
export default memo(Panel)
|
||||
32
web/app/components/workflow/nodes/agent-v2/types.ts
Normal file
32
web/app/components/workflow/nodes/agent-v2/types.ts
Normal file
@@ -0,0 +1,32 @@
|
||||
import type { CommonNodeType, Memory, ModelConfig, PromptItem, ValueSelector, VisionSetting } from '@/app/components/workflow/types'
|
||||
|
||||
export type ToolMetadata = {
|
||||
enabled: boolean
|
||||
type: string
|
||||
provider_name: string
|
||||
tool_name: string
|
||||
plugin_unique_identifier?: string
|
||||
credential_id?: string
|
||||
parameters: Record<string, any>
|
||||
settings: Record<string, any>
|
||||
extra: Record<string, any>
|
||||
}
|
||||
|
||||
export type AgentV2NodeType = CommonNodeType & {
|
||||
model: ModelConfig
|
||||
prompt_template: PromptItem[] | PromptItem
|
||||
tools: ToolMetadata[]
|
||||
max_iterations: number
|
||||
agent_strategy: 'auto' | 'function-calling' | 'chain-of-thought'
|
||||
memory?: Memory
|
||||
context: {
|
||||
enabled: boolean
|
||||
variable_selector?: ValueSelector
|
||||
}
|
||||
vision: {
|
||||
enabled: boolean
|
||||
configs?: VisionSetting
|
||||
}
|
||||
structured_output_enabled?: boolean
|
||||
structured_output?: Record<string, any>
|
||||
}
|
||||
@@ -2,6 +2,8 @@ import type { ComponentType } from 'react'
|
||||
import { BlockEnum } from '../types'
|
||||
import AgentNode from './agent/node'
|
||||
import AgentPanel from './agent/panel'
|
||||
import AgentV2Node from './agent-v2/node'
|
||||
import AgentV2Panel from './agent-v2/panel'
|
||||
import AnswerNode from './answer/node'
|
||||
import AnswerPanel from './answer/panel'
|
||||
import AssignerNode from './assigner/node'
|
||||
@@ -72,6 +74,7 @@ export const NodeComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.DocExtractor]: DocExtractorNode,
|
||||
[BlockEnum.ListFilter]: ListFilterNode,
|
||||
[BlockEnum.Agent]: AgentNode,
|
||||
[BlockEnum.AgentV2]: AgentV2Node,
|
||||
[BlockEnum.DataSource]: DataSourceNode,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBaseNode,
|
||||
[BlockEnum.HumanInput]: HumanInputNode,
|
||||
@@ -101,6 +104,7 @@ export const PanelComponentMap: Record<string, ComponentType<any>> = {
|
||||
[BlockEnum.DocExtractor]: DocExtractorPanel,
|
||||
[BlockEnum.ListFilter]: ListFilterPanel,
|
||||
[BlockEnum.Agent]: AgentPanel,
|
||||
[BlockEnum.AgentV2]: AgentV2Panel,
|
||||
[BlockEnum.DataSource]: DataSourcePanel,
|
||||
[BlockEnum.KnowledgeBase]: KnowledgeBasePanel,
|
||||
[BlockEnum.HumanInput]: HumanInputPanel,
|
||||
|
||||
@@ -46,6 +46,7 @@ export enum BlockEnum {
|
||||
IterationStart = 'iteration-start',
|
||||
Assigner = 'assigner', // is now named as VariableAssigner
|
||||
Agent = 'agent',
|
||||
AgentV2 = 'agent-v2',
|
||||
Loop = 'loop',
|
||||
LoopStart = 'loop-start',
|
||||
LoopEnd = 'loop-end',
|
||||
|
||||
@@ -135,6 +135,7 @@
|
||||
"newApp.advancedUserDescription": "Workflow with additional memory features and a chatbot interface.",
|
||||
"newApp.agentAssistant": "New Agent Assistant",
|
||||
"newApp.agentShortDescription": "Intelligent agent with reasoning and autonomous tool use",
|
||||
"newApp.agentV2ShortDescription": "Next-gen agent with tools, sandbox, and workflow integration",
|
||||
"newApp.agentUserDescription": "An intelligent agent capable of iterative reasoning and autonomous tool use to achieve task goals.",
|
||||
"newApp.appCreateDSLErrorPart1": "A significant difference in DSL versions has been detected. Forcing the import may cause the application to malfunction.",
|
||||
"newApp.appCreateDSLErrorPart2": "Do you want to continue?",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"blocks.agent": "Agent",
|
||||
"blocks.agent-v2": "Agent V2",
|
||||
"blocks.answer": "Answer",
|
||||
"blocks.assigner": "Variable Assigner",
|
||||
"blocks.code": "Code",
|
||||
@@ -31,6 +32,7 @@
|
||||
"blocks.variable-aggregator": "Variable Aggregator",
|
||||
"blocks.variable-assigner": "Variable Aggregator",
|
||||
"blocksAbout.agent": "Invoking large language models to answer questions or process natural language",
|
||||
"blocksAbout.agent-v2": "Next-gen agent with LLM, tools, sandbox execution, and configurable strategies",
|
||||
"blocksAbout.answer": "Define the reply content of a chat conversation",
|
||||
"blocksAbout.assigner": "The variable assignment node is used for assigning values to writable variables(like conversation variables).",
|
||||
"blocksAbout.code": "Execute a piece of Python or NodeJS code to implement custom logic",
|
||||
|
||||
@@ -135,6 +135,7 @@
|
||||
"newApp.advancedUserDescription": "基于工作流编排,适用于定义等复杂流程的多轮对话场景,具有记忆功能。",
|
||||
"newApp.agentAssistant": "新的智能助手",
|
||||
"newApp.agentShortDescription": "具备推理与自主工具调用的智能助手",
|
||||
"newApp.agentV2ShortDescription": "新一代 Agent,支持工具调用、沙箱执行和 Workflow 集成",
|
||||
"newApp.agentUserDescription": "能够迭代式的规划推理、自主工具调用,直至完成任务目标的智能助手。",
|
||||
"newApp.appCreateDSLErrorPart1": "检测到 DSL 版本差异较大,强制导入应用可能无法正常运行。",
|
||||
"newApp.appCreateDSLErrorPart2": "是否继续?",
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
{
|
||||
"blocks.agent": "Agent",
|
||||
"blocks.agent-v2": "Agent V2",
|
||||
"blocks.answer": "直接回复",
|
||||
"blocks.assigner": "变量赋值",
|
||||
"blocks.code": "代码执行",
|
||||
@@ -31,6 +32,7 @@
|
||||
"blocks.variable-aggregator": "变量聚合器",
|
||||
"blocks.variable-assigner": "变量赋值器",
|
||||
"blocksAbout.agent": "调用大型语言模型回答问题或处理自然语言",
|
||||
"blocksAbout.agent-v2": "新一代 Agent,支持 LLM、工具调用、沙箱执行和可配置策略",
|
||||
"blocksAbout.answer": "定义一个聊天对话的回复内容",
|
||||
"blocksAbout.assigner": "变量赋值节点用于向可写入变量(例如会话变量)进行变量赋值。",
|
||||
"blocksAbout.code": "执行一段 Python 或 NodeJS 代码实现自定义逻辑",
|
||||
|
||||
@@ -44,8 +44,9 @@ export enum AppModeEnum {
|
||||
CHAT = 'chat',
|
||||
ADVANCED_CHAT = 'advanced-chat',
|
||||
AGENT_CHAT = 'agent-chat',
|
||||
AGENT = 'agent',
|
||||
}
|
||||
export const AppModes = [AppModeEnum.COMPLETION, AppModeEnum.WORKFLOW, AppModeEnum.CHAT, AppModeEnum.ADVANCED_CHAT, AppModeEnum.AGENT_CHAT] as const
|
||||
export const AppModes = [AppModeEnum.COMPLETION, AppModeEnum.WORKFLOW, AppModeEnum.CHAT, AppModeEnum.ADVANCED_CHAT, AppModeEnum.AGENT_CHAT, AppModeEnum.AGENT] as const
|
||||
|
||||
/**
|
||||
* Variable type
|
||||
|
||||
@@ -8,7 +8,7 @@ export const getRedirectionPath = (
|
||||
return `/app/${app.id}/overview`
|
||||
}
|
||||
else {
|
||||
if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT)
|
||||
if (app.mode === AppModeEnum.WORKFLOW || app.mode === AppModeEnum.ADVANCED_CHAT || app.mode === AppModeEnum.AGENT)
|
||||
return `/app/${app.id}/workflow`
|
||||
else
|
||||
return `/app/${app.id}/configuration`
|
||||
|
||||
Reference in New Issue
Block a user