mirror of
https://github.com/langgenius/dify.git
synced 2026-02-06 08:08:57 +00:00
Compare commits
39 Commits
deploy/dev
...
feat/hitl
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3d0ff9463f | ||
|
|
b893d2df82 | ||
|
|
79b6117d80 | ||
|
|
d2ef434dec | ||
|
|
d9530f7bb7 | ||
|
|
b24e6edada | ||
|
|
59a9cbbf78 | ||
|
|
45164ce33e | ||
|
|
095b3ee234 | ||
|
|
cb970e54da | ||
|
|
e04f2a0786 | ||
|
|
7202a24bcf | ||
|
|
be8f265e43 | ||
|
|
aaf83c2b4c | ||
|
|
9e54f086dc | ||
|
|
d898bcff90 | ||
|
|
b4cf146c85 | ||
|
|
8c31b69c8e | ||
|
|
b886b3f6c8 | ||
|
|
ef0d18bb61 | ||
|
|
f21782a9a3 | ||
|
|
e4455987e7 | ||
|
|
b2ceb41dd6 | ||
|
|
c56ad8e323 | ||
|
|
365f749ed5 | ||
|
|
f686197589 | ||
|
|
f584be9cf0 | ||
|
|
3bd228ddb7 | ||
|
|
0dfa59b1db | ||
|
|
1e344f773b | ||
|
|
bba2040a05 | ||
|
|
ad3be1e4d0 | ||
|
|
297dd832aa | ||
|
|
cc5705cb71 | ||
|
|
74b027c41a | ||
|
|
5f69470ebf | ||
|
|
ec7ccd800c | ||
|
|
f614153f30 | ||
|
|
8ca020e179 |
@@ -1 +0,0 @@
|
||||
../../.agents/skills/component-refactoring
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-code-review
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/frontend-testing
|
||||
@@ -1 +0,0 @@
|
||||
../../.agents/skills/orpc-contract-first
|
||||
7
.github/CODEOWNERS
vendored
7
.github/CODEOWNERS
vendored
@@ -24,6 +24,10 @@
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
|
||||
# Backend - Tests
|
||||
/api/tests/ @laipz8200 @QuantumGhost
|
||||
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
@@ -234,6 +238,9 @@
|
||||
# Frontend - Base Components
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Base Components Tests
|
||||
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
|
||||
23
.github/workflows/autofix.yml
vendored
23
.github/workflows/autofix.yml
vendored
@@ -79,29 +79,6 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install web dependencies
|
||||
run: |
|
||||
cd web
|
||||
pnpm install --frozen-lockfile
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm lint:fix || true
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
|
||||
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -8,6 +8,7 @@ on:
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "feat/hitl-backend"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test:coverage
|
||||
run: pnpm test:ci
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
@@ -717,3 +717,28 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||
|
||||
|
||||
# Redis URL used for PubSub between API and
|
||||
# celery worker
|
||||
# defaults to url constructed from `REDIS_*`
|
||||
# configurations
|
||||
PUBSUB_REDIS_URL=
|
||||
# Pub/sub channel type for streaming events.
|
||||
# valid options are:
|
||||
#
|
||||
# - pubsub: for normal Pub/Sub
|
||||
# - sharded: for sharded Pub/Sub
|
||||
#
|
||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
||||
# for large deployments.
|
||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while running
|
||||
# PubSub.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
# Human input timeout check interval in minutes
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1
|
||||
|
||||
@@ -36,6 +36,8 @@ ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
|
||||
|
||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||
name = Workflow Infrastructure Dependencies
|
||||
@@ -58,6 +60,8 @@ ignore_imports =
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
@@ -102,6 +106,8 @@ forbidden_modules =
|
||||
core.trigger
|
||||
core.variables
|
||||
ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.db.session_factory
|
||||
core.workflow.nodes.agent.agent_node -> models.tools
|
||||
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.workflow_entry -> core.app.workflow.layers.observability
|
||||
@@ -136,7 +142,6 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
@@ -145,6 +150,7 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
@@ -248,6 +254,7 @@ ignore_imports =
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.human_input.entities -> core.variables.consts
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
@@ -294,6 +301,8 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
|
||||
@@ -739,8 +739,10 @@ def upgrade_db():
|
||||
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception:
|
||||
except Exception as e:
|
||||
logger.exception("Failed to execute database migration")
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
lock.release()
|
||||
else:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
@@ -48,6 +49,16 @@ class SecurityConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
|
||||
description="Maximum number of web form submissions allowed per IP within the rate limit window",
|
||||
default=30,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
|
||||
description="Time window in seconds for web form submission rate limiting",
|
||||
default=60,
|
||||
)
|
||||
|
||||
LOGIN_DISABLED: bool = Field(
|
||||
description="Whether to disable login checks",
|
||||
default=False,
|
||||
@@ -82,6 +93,12 @@ class AppExecutionConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
|
||||
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
|
||||
default=int(timedelta(days=7).total_seconds()),
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1134,6 +1151,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable queue monitor task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
|
||||
description="Enable human input timeout check task",
|
||||
default=True,
|
||||
)
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
|
||||
description="Human input timeout check interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
|
||||
@@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
from .cache.redis_pubsub_config import RedisPubSubConfig
|
||||
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
@@ -317,6 +318,7 @@ class MiddlewareConfig(
|
||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
RedisPubSubConfig,
|
||||
# configs of storage and storage providers
|
||||
StorageConfig,
|
||||
AliyunOSSStorageConfig,
|
||||
|
||||
96
api/configs/middleware/cache/redis_pubsub_config.py
vendored
Normal file
96
api/configs/middleware/cache/redis_pubsub_config.py
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class RedisConfigDefaults(Protocol):
|
||||
REDIS_HOST: str
|
||||
REDIS_PORT: int
|
||||
REDIS_USERNAME: str | None
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for Redis pub/sub streaming.
|
||||
"""
|
||||
|
||||
PUBSUB_REDIS_URL: str | None = Field(
|
||||
alias="PUBSUB_REDIS_URL",
|
||||
description=(
|
||||
"Redis connection URL for pub/sub streaming events between API "
|
||||
"and celery worker, defaults to url constructed from "
|
||||
"`REDIS_*` configurations"
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
||||
"recommended to enable this for large deployments."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
||||
description=(
|
||||
"Pub/sub channel type for streaming events. "
|
||||
"Valid options are:\n"
|
||||
"\n"
|
||||
" - pubsub: for normal Pub/Sub\n"
|
||||
" - sharded: for sharded Pub/Sub\n"
|
||||
"\n"
|
||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
||||
"for large deployments."
|
||||
),
|
||||
default="pubsub",
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
|
||||
username = defaults.REDIS_USERNAME or None
|
||||
password = defaults.REDIS_PASSWORD or None
|
||||
|
||||
userinfo = ""
|
||||
if username:
|
||||
userinfo = quote_plus(username)
|
||||
if password:
|
||||
password_part = quote_plus(password)
|
||||
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
def normalized_pubsub_redis_url(self) -> str:
|
||||
pubsub_redis_url = self.PUBSUB_REDIS_URL
|
||||
if pubsub_redis_url:
|
||||
cleaned = pubsub_redis_url.strip()
|
||||
pubsub_redis_url = cleaned or None
|
||||
|
||||
if pubsub_redis_url:
|
||||
return pubsub_redis_url
|
||||
|
||||
return self._build_default_pubsub_url()
|
||||
@@ -37,6 +37,7 @@ from . import (
|
||||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
ping,
|
||||
setup,
|
||||
@@ -171,6 +172,7 @@ __all__ = [
|
||||
"forgot_password",
|
||||
"generator",
|
||||
"hit_testing",
|
||||
"human_input_form",
|
||||
"init_validate",
|
||||
"installed_app",
|
||||
"load_balancing_config",
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Any, Literal, TypeAlias
|
||||
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
|
||||
|
||||
register_enum_models(console_ns, IconType)
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class AppListQuery(BaseModel):
|
||||
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
|
||||
@@ -499,6 +502,7 @@ class AppListApi(Resource):
|
||||
select(Workflow).where(
|
||||
Workflow.version == Workflow.VERSION_DRAFT,
|
||||
Workflow.app_id.in_(workflow_capable_app_ids),
|
||||
Workflow.tenant_id == current_tenant_id,
|
||||
)
|
||||
)
|
||||
.scalars()
|
||||
@@ -510,12 +514,14 @@ class AppListApi(Resource):
|
||||
NodeType.TRIGGER_PLUGIN,
|
||||
}
|
||||
for workflow in draft_workflows:
|
||||
node_id = None
|
||||
try:
|
||||
for _, node_data in workflow.walk_nodes():
|
||||
for node_id, node_data in workflow.walk_nodes():
|
||||
if node_data.get("type") in trigger_node_types:
|
||||
draft_trigger_app_ids.add(str(workflow.app_id))
|
||||
break
|
||||
except Exception:
|
||||
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
|
||||
continue
|
||||
|
||||
for app in app_pagination.items:
|
||||
|
||||
@@ -89,6 +89,7 @@ status_count_model = console_ns.model(
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
"paused": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -33,7 +33,7 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -207,6 +207,7 @@ message_detail_model = console_ns.model(
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
@@ -299,6 +300,7 @@ class ChatMessageListApi(Resource):
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
@@ -481,4 +483,5 @@ class MessageApi(Resource):
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
attach_message_extra_contents([message])
|
||||
return message
|
||||
|
||||
@@ -507,6 +507,179 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class HumanInputFormPreviewPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
|
||||
inputs: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
action: str = Field(..., description="Selected action ID")
|
||||
|
||||
|
||||
class HumanInputDeliveryTestPayload(BaseModel):
|
||||
delivery_method_id: str = Field(..., description="Delivery method ID")
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
reg(HumanInputFormPreviewPayload)
|
||||
reg(HumanInputFormSubmitPayload)
|
||||
reg(HumanInputDeliveryTestPayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class WorkflowDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
|
||||
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
||||
@console_ns.doc("test_workflow_draft_human_input_delivery")
|
||||
@console_ns.doc(description="Test human input delivery for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Test human input delivery
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
delivery_method_id=args.delivery_method_id,
|
||||
inputs=args.inputs,
|
||||
)
|
||||
return jsonable_encoder({})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_draft_workflow")
|
||||
|
||||
@@ -5,10 +5,15 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
@@ -27,9 +32,21 @@ from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
if not form_token:
|
||||
return None
|
||||
base_url = dify_config.APP_WEB_URL
|
||||
if not base_url:
|
||||
return None
|
||||
return f"{base_url.rstrip('/')}/form/{form_token}"
|
||||
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
@@ -440,3 +457,63 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"""Console API for getting workflow pause details."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/pause-details
|
||||
|
||||
Returns information about why and where the workflow is paused.
|
||||
"""
|
||||
|
||||
# Query WorkflowRun to determine if workflow is suspended
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
|
||||
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
}, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
||||
|
||||
# Build response
|
||||
paused_at = pause_entity.paused_at if pause_entity else None
|
||||
paused_nodes = []
|
||||
response = {
|
||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
||||
"paused_nodes": paused_nodes,
|
||||
}
|
||||
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
paused_nodes.append(
|
||||
{
|
||||
"node_id": reason.node_id,
|
||||
"node_title": reason.node_title,
|
||||
"pause_type": {
|
||||
"type": "human_input",
|
||||
"form_id": reason.form_id,
|
||||
"backstage_input_url": _build_backstage_input_url(reason.form_token),
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise AssertionError("unimplemented.")
|
||||
|
||||
return response, 200
|
||||
|
||||
217
api/controllers/console/human_input_form.py
Normal file
217
api/controllers/console/human_input_form.py
Normal file
@@ -0,0 +1,217 @@
|
||||
"""
|
||||
Console/Studio Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.human_input_service import Form, HumanInputService
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@console_ns.route("/form/human_input/<string:form_token>")
|
||||
class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
GET /console/api/form/human_input/<form_token>
|
||||
"""
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_definition_by_token_for_console(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
POST /console/api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
# So we need to assert it manually.
|
||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_user_id=current_user.id,
|
||||
)
|
||||
|
||||
return jsonify({})
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/events")
|
||||
class ConsoleWorkflowEventsApi(Resource):
|
||||
"""Console API for getting workflow execution events after resume."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
|
||||
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
|
||||
|
||||
with Session(expire_on_commit=False, bind=db.engine) as session:
|
||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode(app.mode),
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
||||
query = select(App).where(
|
||||
App.id == workflow_run.app_id,
|
||||
App.tenant_id == workflow_run.tenant_id,
|
||||
)
|
||||
app = session.scalars(query).first()
|
||||
if app is None:
|
||||
raise AssertionError(
|
||||
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
|
||||
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
|
||||
)
|
||||
|
||||
return app
|
||||
@@ -33,8 +33,9 @@ from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import TimestampField
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@@ -63,17 +64,32 @@ class WorkflowLogQuery(BaseModel):
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
@@ -30,6 +31,7 @@ from . import (
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
@@ -44,6 +46,7 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
"passport",
|
||||
@@ -52,4 +55,5 @@ __all__ = [
|
||||
"site",
|
||||
"web_ns",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
||||
@@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
code = 429
|
||||
|
||||
|
||||
class WebFormRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "web_form_rate_limit_exceeded"
|
||||
description = "Too many form requests. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
164
api/controllers/web/human_input_form.py
Normal file
164
api/controllers/web/human_input_form.py
Normal file
@@ -0,0 +1,164 @@
|
||||
"""
|
||||
Web App Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_access_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
# TODO(QuantumGhost): disable authorization for web app
|
||||
# form api temporarily
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
# class HumanInputFormApi(WebApiResource):
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by token.
|
||||
|
||||
GET /api/form/human_input/<form_token>
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
service.ensure_form_active(form)
|
||||
app_model, site = _get_app_site_from_form(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by token.
|
||||
|
||||
POST /api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
if (recipient_type := form.recipient_type) is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%", form.id)
|
||||
raise AssertionError("Recipient type is None")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return {}, 200
|
||||
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return app_model, site
|
||||
@@ -1,4 +1,6 @@
|
||||
from flask_restx import fields, marshal_with
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import Site
|
||||
from models.model import App, Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@@ -108,3 +110,14 @@ class AppSiteInfo:
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
112
api/controllers/web/workflow_events.py
Normal file
112
api/controllers/web/workflow_events.py
Normal file
@@ -0,0 +1,112 @@
|
||||
"""
|
||||
Web App Workflow Resume APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
class WorkflowEventsApi(WebApiResource):
|
||||
"""API for getting workflow execution events after resume."""
|
||||
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /api/workflow/<task_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
workflow_run_id = task_id
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(app_mode, workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
@@ -4,8 +4,8 @@ import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[False],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[True],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
@@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping,
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool = True,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=workflow_run_id,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Resume a paused advanced chat execution.
|
||||
"""
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=application_generate_entity.stream,
|
||||
pause_state_config=pause_state_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
@@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
conversation: Conversation | None = None,
|
||||
message: Message | None = None,
|
||||
stream: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
is_first_conversation = conversation is None
|
||||
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
@@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
# message_ = session.get(Message, message.id)
|
||||
# assert message_ is not None
|
||||
# message = message_
|
||||
# db.session.refresh(workflow)
|
||||
# db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
@@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
|
||||
@@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
invoke_from=invoke_from,
|
||||
user_from=user_from,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
|
||||
@@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
@@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._seed_task_state_from_message(message)
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
@@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_tenant_id = workflow.tenant_id
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
@@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._message_saved_on_pause = False
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def _seed_task_state_from_message(self, message: Message) -> None:
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
@@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow_id,
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
yield workflow_start_resp
|
||||
@@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
|
||||
yield from responses
|
||||
resolved_state: GraphRuntimeState | None = None
|
||||
try:
|
||||
resolved_state = self._ensure_graph_runtime_initialized()
|
||||
except ValueError:
|
||||
resolved_state = None
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
message = self._get_message(session=session)
|
||||
if message is not None:
|
||||
message.status = MessageStatus.PAUSED
|
||||
self._message_saved_on_pause = True
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
@@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
# Save message unless it has already been persisted on pause.
|
||||
if not self._message_saved_on_pause:
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""Handle message replace events."""
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
self._persist_human_input_extra_content(node_id=event.node_id)
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
|
||||
if not self._workflow_run_id or not self._message_id:
|
||||
return
|
||||
|
||||
if form_id is None:
|
||||
if node_id is None:
|
||||
return
|
||||
form_id = self._load_human_input_form_id(node_id=node_id)
|
||||
if form_id is None:
|
||||
logger.warning(
|
||||
"HumanInput form not found for workflow run %s node %s",
|
||||
self._workflow_run_id,
|
||||
node_id,
|
||||
)
|
||||
return
|
||||
|
||||
with self._database_session() as session:
|
||||
exists_stmt = select(HumanInputContent).where(
|
||||
HumanInputContent.workflow_run_id == self._workflow_run_id,
|
||||
HumanInputContent.message_id == self._message_id,
|
||||
HumanInputContent.form_id == form_id,
|
||||
)
|
||||
if session.scalar(exists_stmt) is not None:
|
||||
return
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
message_id=self._message_id,
|
||||
form_id=form_id,
|
||||
)
|
||||
session.add(content)
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
if form is None:
|
||||
return None
|
||||
return form.id
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
yield self._workflow_response_converter.handle_agent_log(
|
||||
@@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
@@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueMessageReplaceEvent: self._handle_message_replace_event,
|
||||
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
@@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
@@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
message = self._get_message(session=session)
|
||||
if message is None:
|
||||
return
|
||||
|
||||
if message.status == MessageStatus.PAUSED:
|
||||
message.status = MessageStatus.NORMAL
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
|
||||
@@ -5,9 +5,14 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
HumanInputFormFilledResponse,
|
||||
HumanInputFormTimeoutResponse,
|
||||
HumanInputRequiredResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
@@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowRun
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
@@ -191,6 +207,7 @@ class WorkflowResponseConverter:
|
||||
task_id: str,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
reason: WorkflowStartReason,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||
started_at = naive_utc_now()
|
||||
@@ -204,6 +221,7 @@ class WorkflowResponseConverter:
|
||||
workflow_id=workflow_id,
|
||||
inputs=self._workflow_inputs,
|
||||
created_at=int(started_at.timestamp()),
|
||||
reason=reason,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -264,6 +282,160 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_pause_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
task_id: str,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> list[StreamResponse]:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
started_at = self._workflow_started_at
|
||||
if started_at is None:
|
||||
raise ValueError(
|
||||
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
|
||||
)
|
||||
paused_at = naive_utc_now()
|
||||
elapsed_time = (paused_at - started_at).total_seconds()
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
if human_input_form_ids:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
with Session(bind=db.engine) as session:
|
||||
for form_id, expiration_time in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=reason.form_id,
|
||||
node_id=reason.node_id,
|
||||
node_title=reason.node_title,
|
||||
form_content=reason.form_content,
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=reason.display_in_ui,
|
||||
form_token=reason.form_token,
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=int(expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id=run_id,
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED.value,
|
||||
created_at=int(started_at.timestamp()),
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def human_input_form_filled_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
||||
) -> HumanInputFormFilledResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormFilledResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
),
|
||||
)
|
||||
|
||||
def human_input_form_timeout_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
||||
) -> HumanInputFormTimeoutResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormTimeoutResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormTimeoutResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
expiration_time=int(event.expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def workflow_run_result_to_finish_response(
|
||||
cls,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
creator_user: Account | EndUser,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
run_id = workflow_run.id
|
||||
elapsed_time = workflow_run.elapsed_time
|
||||
|
||||
encoded_outputs = workflow_run.outputs_dict
|
||||
finished_at = workflow_run.finished_at
|
||||
assert finished_at is not None
|
||||
|
||||
created_by: Mapping[str, object]
|
||||
user = creator_user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status.value,
|
||||
outputs=encoded_outputs,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=cls.fetch_files_from_node_outputs(encoded_outputs),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
@@ -592,7 +764,8 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
@classmethod
|
||||
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
@@ -601,7 +774,7 @@ class WorkflowResponseConverter:
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
@@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
@@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
created_new_conversation = conversation is None
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
@@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
|
||||
application_generate_entity.conversation_id = conversation.id
|
||||
application_generate_entity.is_new_conversation = created_new_conversation
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
@@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{workflow_run_id}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
||||
36
api/core/app/apps/message_generator.py
Normal file
36
api/core/app/apps/message_generator.py
Normal file
@@ -0,0 +1,36 @@
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class MessageGenerator:
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{str(workflow_run_id)}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
ping_interval: float = 10.0,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
ping_interval=ping_interval,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
70
api/core/app/apps/streaming_utils.py
Normal file
70
api/core/app/apps/streaming_utils.py
Normal file
@@ -0,0 +1,70 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
|
||||
|
||||
def stream_topic_events(
|
||||
*,
|
||||
topic: Topic,
|
||||
idle_timeout: float,
|
||||
ping_interval: float | None = None,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
# send a PING event immediately to prevent the connection staying in pending state for a long time.
|
||||
#
|
||||
# This simplify the debugging process as the DevTools in Chrome does not
|
||||
# provide complete curl command for pending connections.
|
||||
yield StreamEvent.PING.value
|
||||
|
||||
terminal_values = _normalize_terminal_events(terminal_events)
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
with topic.subscribe() as sub:
|
||||
# on_subscribe fires only after the Redis subscription is active.
|
||||
# This is used to gate task start and reduce pub/sub race for the first event.
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive(timeout=0.1)
|
||||
except SubscriptionClosedError:
|
||||
return
|
||||
if msg is None:
|
||||
current_time = time.time()
|
||||
if current_time - last_msg_time > idle_timeout:
|
||||
return
|
||||
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
|
||||
yield StreamEvent.PING.value
|
||||
last_ping_time = current_time
|
||||
continue
|
||||
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
event = json.loads(msg)
|
||||
yield event
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
|
||||
event_type = event.get("event")
|
||||
if event_type in terminal_values:
|
||||
return
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
if isinstance(item, StreamEvent):
|
||||
values.add(item.value)
|
||||
else:
|
||||
values.add(str(item))
|
||||
return values
|
||||
@@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.account import Account
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
extras = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
@@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
@TBD
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
"""
|
||||
pass
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=application_generate_entity.stream,
|
||||
variable_loader=variable_loader,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
"""
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
@@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
@@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
# Create a variable pool.
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
@@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
7
api/core/app/apps/workflow/errors.py
Normal file
7
api/core/app/apps/workflow/errors.py
Normal file
@@ -0,0 +1,7 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class WorkflowPausedInBlockingModeError(BaseHTTPException):
|
||||
error_code = "workflow_paused_in_blocking_mode"
|
||||
description = "Workflow execution paused for human input; blocking response mode is not supported."
|
||||
code = 400
|
||||
@@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
@@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at),
|
||||
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_execution_id = run_id
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
if event.reason == WorkflowStartReason.INITIAL:
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
reason=event.reason,
|
||||
)
|
||||
yield start_resp
|
||||
|
||||
@@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
yield from responses
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
self,
|
||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||
@@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _get_event_handlers(self) -> dict[type, Callable]:
|
||||
"""Get mapping of event types to their handlers using fluent pattern."""
|
||||
return {
|
||||
@@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
@@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueLoopCompletedEvent: self._handle_loop_completed_event,
|
||||
# Agent events
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
@@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
@@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
@@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
@@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(QueueWorkflowStartedEvent())
|
||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
@@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, GraphRunPausedEvent):
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._enqueue_human_input_notifications(event.reasons)
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormFilledEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormTimeoutEvent(
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
@@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
)
|
||||
|
||||
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
|
||||
for reason in reasons:
|
||||
if not isinstance(reason, HumanInputRequired):
|
||||
continue
|
||||
if not reason.form_id:
|
||||
continue
|
||||
try:
|
||||
dispatch_human_input_email_task.apply_async(
|
||||
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
|
||||
queue="mail",
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive logging
|
||||
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent):
|
||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
@@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: str | None = None
|
||||
is_new_conversation: bool = False
|
||||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
||||
@@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
PAUSE = "pause"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
@@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
"""QueueWorkflowStartedEvent entity."""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
@@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
|
||||
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
|
||||
|
||||
|
||||
class QueueHumanInputFormFilledEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormFilledEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormTimeoutEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
expiration_time: datetime
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage abstract entity
|
||||
@@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueWorkflowPausedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowPausedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PAUSE
|
||||
reasons: Sequence[PauseReason] = Field(default_factory=list)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_PAUSED = "workflow_paused"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
@@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
|
||||
workflow_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
||||
workflow_run_id: str
|
||||
@@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
total_steps: int
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int
|
||||
finished_at: int | None
|
||||
exceptions_count: int | None = 0
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
|
||||
@@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowPauseStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowPauseStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
workflow_run_id: str
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: str
|
||||
created_at: int
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormTimeoutResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
expiration_time: int
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class NodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
@@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int
|
||||
finished_at: int | None
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
@@ -103,6 +104,14 @@ class RateLimit:
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
|
||||
request_id = rate_limit.enter(request_id)
|
||||
yield
|
||||
if request_id is not None:
|
||||
rate_limit.exit(request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
|
||||
return self.generate_entity.entity
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PauseStateLayerConfig:
|
||||
"""Configuration container for instantiating pause persistence layers."""
|
||||
|
||||
session_factory: Engine | sessionmaker[Session]
|
||||
state_owner_user_id: str
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -82,10 +82,11 @@ class MessageCycleManager:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
is_first_message = self._application_generate_entity.is_new_conversation
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
thread: Thread | None = None
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
@@ -101,9 +102,10 @@ class MessageCycleManager:
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
return thread
|
||||
if is_first_message:
|
||||
self._application_generate_entity.is_new_conversation = False
|
||||
|
||||
return None
|
||||
return thread
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
|
||||
@@ -47,6 +47,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
@@ -68,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
@@ -122,6 +126,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
|
||||
54
api/core/entities/execution_extra_content.py
Normal file
54
api/core/entities/execution_extra_content.py
Normal file
@@ -0,0 +1,54 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
class HumanInputFormDefinition(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
|
||||
class HumanInputFormSubmissionData(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class HumanInputContent(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
workflow_run_id: str
|
||||
submitted: bool
|
||||
form_definition: HumanInputFormDefinition | None = None
|
||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
"HumanInputContent",
|
||||
"HumanInputFormDefinition",
|
||||
"HumanInputFormSubmissionData",
|
||||
]
|
||||
@@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
||||
@@ -6,7 +6,8 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
|
||||
from extensions.ext_redis import redis_client
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -43,28 +44,37 @@ def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
|
||||
return data.get("data", {}).get("plugins", [])
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
plugin_ids: list[str],
|
||||
) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration.model_validate(plugin))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def fetch_global_plugin_manifest(cache_key_prefix: str, cache_ttl: int) -> None:
|
||||
"""
|
||||
Fetch all plugin manifests from marketplace and cache them in Redis.
|
||||
This should be called once per check cycle to populate the instance-level cache.
|
||||
|
||||
Args:
|
||||
cache_key_prefix: Redis key prefix for caching plugin manifests
|
||||
cache_ttl: Cache TTL in seconds
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the HTTP request fails
|
||||
Exception: If any other error occurs during fetching or caching
|
||||
"""
|
||||
url = str(marketplace_api_url / "api/v1/dist/plugins/manifest.json")
|
||||
response = httpx.get(url, headers={"X-Dify-Version": dify_config.project.version}, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_json = response.json()
|
||||
plugins_data = raw_json.get("plugins", [])
|
||||
|
||||
# Parse and cache all plugin snapshots
|
||||
for plugin_data in plugins_data:
|
||||
plugin_snapshot = MarketplacePluginSnapshot.model_validate(plugin_data)
|
||||
redis_client.setex(
|
||||
name=f"{cache_key_prefix}{plugin_snapshot.plugin_id}",
|
||||
time=cache_ttl,
|
||||
value=plugin_snapshot.model_dump_json(),
|
||||
)
|
||||
|
||||
@@ -15,10 +15,7 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
GenerateNameTraceInfo,
|
||||
@@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import (
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.engine import db
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
@@ -469,6 +466,8 @@ class TraceTask:
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
|
||||
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.engine import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
|
||||
@@ -11,6 +12,7 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
@@ -101,6 +103,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -112,7 +119,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
@@ -159,6 +168,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -167,6 +181,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@@ -48,3 +48,15 @@ class MarketplacePluginDeclaration(BaseModel):
|
||||
if "tool" in data and not data["tool"]:
|
||||
del data["tool"]
|
||||
return data
|
||||
|
||||
|
||||
class MarketplacePluginSnapshot(BaseModel):
|
||||
org: str
|
||||
name: str
|
||||
latest_version: str
|
||||
latest_package_identifier: str
|
||||
latest_package_url: str
|
||||
|
||||
@computed_field
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.org}/{self.name}"
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
"""Repository implementations for data access."""
|
||||
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.workflow.repository package.
|
||||
"""
|
||||
from __future__ import annotations
|
||||
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowExecutionRepository",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
||||
553
api/core/repositories/human_input_repository.py
Normal file
553
api/core/repositories/human_input_repository.py
Normal file
@@ -0,0 +1,553 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
DeliveryMethodType,
|
||||
HumanInputFormKind,
|
||||
HumanInputFormStatus,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
ConsoleDeliveryPayload,
|
||||
ConsoleRecipientPayload,
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
StandaloneWebAppRecipientPayload,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputFormRecipient]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _WorkspaceMemberInfo:
|
||||
user_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
||||
def __init__(self, recipient_model: HumanInputFormRecipient):
|
||||
self._recipient_model = recipient_model
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._recipient_model.id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
if self._recipient_model.access_token is None:
|
||||
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
|
||||
return self._recipient_model.access_token
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
|
||||
self._form_model = form_model
|
||||
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
|
||||
self._web_app_recipient = next(
|
||||
(
|
||||
recipient
|
||||
for recipient in recipient_models
|
||||
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
),
|
||||
None,
|
||||
)
|
||||
self._console_recipient = next(
|
||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
self._submitted_data: Mapping[str, Any] | None = (
|
||||
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._form_model.id
|
||||
|
||||
@property
|
||||
def web_app_token(self):
|
||||
if self._console_recipient is not None:
|
||||
return self._console_recipient.access_token
|
||||
if self._web_app_recipient is None:
|
||||
return None
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return list(self._recipients)
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self._form_model.rendered_content
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self._form_model.selected_action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self._submitted_data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self._form_model.submitted_at is not None
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self._form_model.status
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self._form_model.expiration_time
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class HumanInputFormRecord:
|
||||
form_id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_kind: HumanInputFormKind
|
||||
definition: FormDefinition
|
||||
rendered_content: str
|
||||
created_at: datetime
|
||||
expiration_time: datetime
|
||||
status: HumanInputFormStatus
|
||||
selected_action_id: str | None
|
||||
submitted_data: Mapping[str, Any] | None
|
||||
submitted_at: datetime | None
|
||||
submission_user_id: str | None
|
||||
submission_end_user_id: str | None
|
||||
completed_by_recipient_id: str | None
|
||||
recipient_id: str | None
|
||||
recipient_type: RecipientType | None
|
||||
access_token: str | None
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
|
||||
) -> "HumanInputFormRecord":
|
||||
definition_payload = json.loads(form_model.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form_model.expiration_time
|
||||
return cls(
|
||||
form_id=form_model.id,
|
||||
workflow_run_id=form_model.workflow_run_id,
|
||||
node_id=form_model.node_id,
|
||||
tenant_id=form_model.tenant_id,
|
||||
app_id=form_model.app_id,
|
||||
form_kind=form_model.form_kind,
|
||||
definition=FormDefinition.model_validate(definition_payload),
|
||||
rendered_content=form_model.rendered_content,
|
||||
created_at=form_model.created_at,
|
||||
expiration_time=form_model.expiration_time,
|
||||
status=form_model.status,
|
||||
selected_action_id=form_model.selected_action_id,
|
||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
||||
submitted_at=form_model.submitted_at,
|
||||
submission_user_id=form_model.submission_user_id,
|
||||
submission_end_user_id=form_model.submission_end_user_id,
|
||||
completed_by_recipient_id=form_model.completed_by_recipient_id,
|
||||
recipient_id=recipient_model.id if recipient_model else None,
|
||||
recipient_type=recipient_model.recipient_type if recipient_model else None,
|
||||
access_token=recipient_model.access_token if recipient_model else None,
|
||||
)
|
||||
|
||||
|
||||
class _InvalidTimeoutStatusError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _delivery_method_to_model(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_method: DeliveryChannelConfig,
|
||||
) -> _DeliveryAndRecipients:
|
||||
delivery_id = str(uuidv7())
|
||||
delivery_model = HumanInputDelivery(
|
||||
id=delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=delivery_method.type,
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
recipients.extend(
|
||||
self._build_email_recipients(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients_config=email_recipients_config,
|
||||
)
|
||||
)
|
||||
|
||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
||||
|
||||
def _build_email_recipients(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipients_config: EmailRecipients,
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
member_user_ids = [
|
||||
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
|
||||
]
|
||||
external_emails = [
|
||||
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
|
||||
]
|
||||
if recipients_config.whole_workspace:
|
||||
members = self._query_all_workspace_members(session=session)
|
||||
else:
|
||||
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
|
||||
|
||||
return self._create_email_recipients_from_resolved(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
members=members,
|
||||
external_emails=external_emails,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_email_recipients_from_resolved(
|
||||
*,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
members: Sequence[_WorkspaceMemberInfo],
|
||||
external_emails: Sequence[str],
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for member in members:
|
||||
if not member.email:
|
||||
continue
|
||||
if member.email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(member.email)
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
|
||||
for email in external_emails:
|
||||
if not email:
|
||||
continue
|
||||
if email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=EmailExternalRecipientPayload(email=email),
|
||||
)
|
||||
)
|
||||
|
||||
return recipient_models
|
||||
|
||||
def _query_all_workspace_members(
|
||||
self,
|
||||
session: Session,
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def _query_workspace_members_by_ids(
|
||||
self,
|
||||
session: Session,
|
||||
restrict_to_user_ids: Sequence[str],
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
|
||||
if not unique_ids:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
start_time = naive_utc_now()
|
||||
node_expiration = form_config.expiration_time(start_time)
|
||||
form_definition = FormDefinition(
|
||||
form_content=form_config.form_content,
|
||||
inputs=form_config.inputs,
|
||||
user_actions=form_config.user_actions,
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
default_values=dict(params.resolved_default_values),
|
||||
display_in_ui=params.display_in_ui,
|
||||
node_title=form_config.title,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=params.app_id,
|
||||
workflow_run_id=params.workflow_execution_id,
|
||||
form_kind=params.form_kind,
|
||||
node_id=params.node_id,
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
created_at=start_time,
|
||||
)
|
||||
session.add(form_model)
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
for delivery in params.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_method=delivery,
|
||||
)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
recipient_models.extend(delivery_and_recipients.recipients)
|
||||
if params.console_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
|
||||
):
|
||||
console_delivery_id = str(uuidv7())
|
||||
console_delivery = HumanInputDelivery(
|
||||
id=console_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=console_delivery_id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(console_delivery)
|
||||
session.add(console_recipient)
|
||||
recipient_models.append(console_recipient)
|
||||
if params.backstage_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
|
||||
):
|
||||
backstage_delivery_id = str(uuidv7())
|
||||
backstage_delivery = HumanInputDelivery(
|
||||
id=backstage_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=backstage_delivery_id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(backstage_delivery)
|
||||
session.add(backstage_recipient)
|
||||
recipient_models.append(backstage_recipient)
|
||||
session.flush()
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
form_query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
HumanInputForm.tenant_id == self._tenant_id,
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def get_by_form_id_and_recipient_type(
|
||||
self,
|
||||
form_id: str,
|
||||
recipient_type: RecipientType,
|
||||
) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
||||
)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def mark_submitted(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
recipient_id: str | None,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
submission_user_id: str | None,
|
||||
submission_end_user_id: str | None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
|
||||
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.status = HumanInputFormStatus.SUBMITTED
|
||||
form_model.submission_user_id = submission_user_id
|
||||
form_model.submission_end_user_id = submission_end_user_id
|
||||
form_model.completed_by_recipient_id = recipient_id
|
||||
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
if recipient_model is not None:
|
||||
session.refresh(recipient_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
||||
|
||||
def mark_timeout(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
timeout_status: HumanInputFormStatus,
|
||||
reason: str | None = None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
|
||||
|
||||
# already handled or submitted
|
||||
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
|
||||
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
|
||||
raise FormNotFoundError(f"form already submitted, id={form_id}")
|
||||
|
||||
form_model.status = timeout_status
|
||||
form_model.selected_action_id = None
|
||||
form_model.submitted_data = None
|
||||
form_model.submission_user_id = None
|
||||
form_model.submission_end_user_id = None
|
||||
form_model.completed_by_recipient_id = None
|
||||
# Reason is recorded in status/error downstream; not stored on form.
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
@@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from core.tools.entities.tool_entities import ToolInvokeMeta
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class ToolProviderNotFoundError(ValueError):
|
||||
@@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
|
||||
error_code = "workflow_tool_human_input_not_supported"
|
||||
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
|
||||
code = 400
|
||||
|
||||
|
||||
class ToolEngineInvokeError(Exception):
|
||||
meta: ToolInvokeMeta
|
||||
|
||||
|
||||
@@ -3,6 +3,8 @@ from typing import Any
|
||||
|
||||
from core.app.app_config.entities import VariableEntity
|
||||
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
|
||||
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
|
||||
from core.workflow.enums import NodeType
|
||||
from core.workflow.nodes.base.entities import OutputVariableEntity
|
||||
|
||||
|
||||
@@ -45,6 +47,13 @@ class WorkflowToolConfigurationUtils:
|
||||
|
||||
return [outputs_by_variable[variable] for variable in variable_order]
|
||||
|
||||
@classmethod
|
||||
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
|
||||
nodes = graph.get("nodes", [])
|
||||
for node in nodes:
|
||||
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
|
||||
raise WorkflowToolHumanInputNotSupportedError()
|
||||
|
||||
@classmethod
|
||||
def check_is_synced(
|
||||
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]
|
||||
|
||||
@@ -98,6 +98,10 @@ class WorkflowTool(Tool):
|
||||
invoke_from=self.runtime.invoke_from,
|
||||
streaming=False,
|
||||
call_depth=self.workflow_call_depth + 1,
|
||||
# NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None`
|
||||
# because workflow pausing mechanisms (such as HumanInput) are not
|
||||
# supported within WorkflowTool execution context.
|
||||
pause_state_config=None,
|
||||
)
|
||||
assert isinstance(result, dict)
|
||||
data = result.get("data", {})
|
||||
|
||||
@@ -112,7 +112,7 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
|
||||
|
||||
class RAGPipelineVariable(BaseModel):
|
||||
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
|
||||
label: str = Field(description="label")
|
||||
description: str | None = Field(description="description", default="")
|
||||
variable: str = Field(description="variable key", default="")
|
||||
|
||||
@@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit
|
||||
from .graph_init_params import GraphInitParams
|
||||
from .workflow_execution import WorkflowExecution
|
||||
from .workflow_node_execution import WorkflowNodeExecution
|
||||
from .workflow_start_reason import WorkflowStartReason
|
||||
|
||||
__all__ = [
|
||||
"AgentNodeStrategyInit",
|
||||
"GraphInitParams",
|
||||
"WorkflowExecution",
|
||||
"WorkflowNodeExecution",
|
||||
"WorkflowStartReason",
|
||||
]
|
||||
|
||||
@@ -5,6 +5,16 @@ from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class GraphInitParams(BaseModel):
|
||||
"""GraphInitParams encapsulates the configurations and contextual information
|
||||
that remain constant throughout a single execution of the graph engine.
|
||||
|
||||
A single execution is defined as follows: as long as the execution has not reached
|
||||
its conclusion, it is considered one execution. For instance, if a workflow is suspended
|
||||
and later resumed, it is still regarded as a single execution, not two.
|
||||
|
||||
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
|
||||
"""
|
||||
|
||||
# init params
|
||||
tenant_id: str = Field(..., description="tenant / workspace id")
|
||||
app_id: str = Field(..., description="app id")
|
||||
|
||||
@@ -1,8 +1,11 @@
|
||||
from collections.abc import Mapping
|
||||
from enum import StrEnum, auto
|
||||
from typing import Annotated, Literal, TypeAlias
|
||||
from typing import Annotated, Any, Literal, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class PauseReasonType(StrEnum):
|
||||
HUMAN_INPUT_REQUIRED = auto()
|
||||
@@ -11,10 +14,31 @@ class PauseReasonType(StrEnum):
|
||||
|
||||
class HumanInputRequired(BaseModel):
|
||||
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
|
||||
|
||||
form_id: str
|
||||
# The identifier of the human input node causing the pause.
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
actions: list[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
node_id: str
|
||||
node_title: str
|
||||
|
||||
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
|
||||
# `output_variable_name` to their resolved values.
|
||||
#
|
||||
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
|
||||
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
|
||||
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
|
||||
# `resolved_default_values` is `{"name": "John"}`.
|
||||
#
|
||||
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
|
||||
# `HumanInputFormRecipient.access_token`.
|
||||
#
|
||||
# This field is `None` if webapp delivery is not set and not
|
||||
# in orchestrating mode.
|
||||
form_token: str | None = None
|
||||
|
||||
|
||||
class SchedulingPause(BaseModel):
|
||||
|
||||
8
api/core/workflow/entities/workflow_start_reason.py
Normal file
8
api/core/workflow/entities/workflow_start_reason.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from enum import StrEnum
|
||||
|
||||
|
||||
class WorkflowStartReason(StrEnum):
|
||||
"""Reason for workflow start events across graph/queue/SSE layers."""
|
||||
|
||||
INITIAL = "initial" # First start of a workflow run.
|
||||
RESUMPTION = "resumption" # Start triggered after resuming a paused run.
|
||||
15
api/core/workflow/graph_engine/_engine_utils.py
Normal file
15
api/core/workflow/graph_engine/_engine_utils.py
Normal file
@@ -0,0 +1,15 @@
|
||||
import time
|
||||
|
||||
|
||||
def get_timestamp() -> float:
|
||||
"""Retrieve a timestamp as a float point numer representing the number of seconds
|
||||
since the Unix epoch.
|
||||
|
||||
This function is primarily used to measure the execution time of the workflow engine.
|
||||
Since workflow execution may be paused and resumed on a different machine,
|
||||
`time.perf_counter` cannot be used as it is inconsistent across machines.
|
||||
|
||||
To address this, the function uses the wall clock as the time source.
|
||||
However, it assumes that the clocks of all servers are properly synchronized.
|
||||
"""
|
||||
return round(time.time())
|
||||
@@ -2,12 +2,14 @@
|
||||
GraphEngine configuration models.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
class GraphEngineConfig(BaseModel):
|
||||
"""Configuration for GraphEngine worker pool scaling."""
|
||||
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
min_workers: int = 1
|
||||
max_workers: int = 5
|
||||
scale_up_threshold: int = 3
|
||||
|
||||
@@ -192,9 +192,13 @@ class EventHandler:
|
||||
self._event_collector.collect(edge_event)
|
||||
|
||||
# Enqueue ready nodes
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
if self._graph_execution.is_paused:
|
||||
for node_id in ready_nodes:
|
||||
self._graph_runtime_state.register_deferred_node(node_id)
|
||||
else:
|
||||
for node_id in ready_nodes:
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
# Update execution tracking
|
||||
self._state_manager.finish_execution(event.node_id)
|
||||
|
||||
@@ -14,6 +14,7 @@ from collections.abc import Generator
|
||||
from typing import TYPE_CHECKING, cast, final
|
||||
|
||||
from core.workflow.context import capture_current_context
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import NodeExecutionType
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_events import (
|
||||
@@ -55,6 +56,9 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
_DEFAULT_CONFIG = GraphEngineConfig()
|
||||
|
||||
|
||||
@final
|
||||
class GraphEngine:
|
||||
"""
|
||||
@@ -70,7 +74,7 @@ class GraphEngine:
|
||||
graph: Graph,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
command_channel: CommandChannel,
|
||||
config: GraphEngineConfig,
|
||||
config: GraphEngineConfig = _DEFAULT_CONFIG,
|
||||
) -> None:
|
||||
"""Initialize the graph engine with all subsystems and dependencies."""
|
||||
# stop event
|
||||
@@ -234,7 +238,9 @@ class GraphEngine:
|
||||
self._graph_execution.paused = False
|
||||
self._graph_execution.pause_reasons = []
|
||||
|
||||
start_event = GraphRunStartedEvent()
|
||||
start_event = GraphRunStartedEvent(
|
||||
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
|
||||
)
|
||||
self._event_manager.notify_layers(start_event)
|
||||
yield start_event
|
||||
|
||||
@@ -303,15 +309,17 @@ class GraphEngine:
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_start()
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
|
||||
|
||||
def _start_execution(self, *, resume: bool = False) -> None:
|
||||
"""Start execution subsystems."""
|
||||
self._stop_event.clear()
|
||||
paused_nodes: list[str] = []
|
||||
deferred_nodes: list[str] = []
|
||||
if resume:
|
||||
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
|
||||
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
|
||||
|
||||
# Start worker pool (it calculates initial workers internally)
|
||||
self._worker_pool.start()
|
||||
@@ -327,7 +335,11 @@ class GraphEngine:
|
||||
self._state_manager.enqueue_node(root_node.id)
|
||||
self._state_manager.start_execution(root_node.id)
|
||||
else:
|
||||
for node_id in paused_nodes:
|
||||
seen_nodes: set[str] = set()
|
||||
for node_id in paused_nodes + deferred_nodes:
|
||||
if node_id in seen_nodes:
|
||||
continue
|
||||
seen_nodes.add(node_id)
|
||||
self._state_manager.enqueue_node(node_id)
|
||||
self._state_manager.start_execution(node_id)
|
||||
|
||||
@@ -345,8 +357,8 @@ class GraphEngine:
|
||||
for layer in self._layers:
|
||||
try:
|
||||
layer.on_graph_end(self._graph_execution.error)
|
||||
except Exception as e:
|
||||
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
|
||||
except Exception:
|
||||
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
|
||||
|
||||
# Public property accessors for attributes that need external access
|
||||
@property
|
||||
|
||||
@@ -224,6 +224,8 @@ class GraphStateManager:
|
||||
Returns:
|
||||
Number of executing nodes
|
||||
"""
|
||||
# This count is a best-effort snapshot and can change concurrently.
|
||||
# Only use it for pause-drain checks where scheduling is already frozen.
|
||||
with self._lock:
|
||||
return len(self._executing_nodes)
|
||||
|
||||
|
||||
@@ -83,12 +83,12 @@ class Dispatcher:
|
||||
"""Main dispatcher loop."""
|
||||
try:
|
||||
self._process_commands()
|
||||
paused = False
|
||||
while not self._stop_event.is_set():
|
||||
if (
|
||||
self._execution_coordinator.aborted
|
||||
or self._execution_coordinator.paused
|
||||
or self._execution_coordinator.execution_complete
|
||||
):
|
||||
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
|
||||
break
|
||||
if self._execution_coordinator.paused:
|
||||
paused = True
|
||||
break
|
||||
|
||||
self._execution_coordinator.check_scaling()
|
||||
@@ -101,13 +101,10 @@ class Dispatcher:
|
||||
time.sleep(0.1)
|
||||
|
||||
self._process_commands()
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
if paused:
|
||||
self._drain_events_until_idle()
|
||||
else:
|
||||
self._drain_event_queue()
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Dispatcher error")
|
||||
@@ -122,3 +119,24 @@ class Dispatcher:
|
||||
def _process_commands(self, event: GraphNodeEventBase | None = None):
|
||||
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
|
||||
self._execution_coordinator.process_commands()
|
||||
|
||||
def _drain_event_queue(self) -> None:
|
||||
while True:
|
||||
try:
|
||||
event = self._event_queue.get(block=False)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
except queue.Empty:
|
||||
break
|
||||
|
||||
def _drain_events_until_idle(self) -> None:
|
||||
while not self._stop_event.is_set():
|
||||
try:
|
||||
event = self._event_queue.get(timeout=0.1)
|
||||
self._event_handler.dispatch(event)
|
||||
self._event_queue.task_done()
|
||||
self._process_commands(event)
|
||||
except queue.Empty:
|
||||
if not self._execution_coordinator.has_executing_nodes():
|
||||
break
|
||||
self._drain_event_queue()
|
||||
|
||||
@@ -94,3 +94,11 @@ class ExecutionCoordinator:
|
||||
|
||||
self._worker_pool.stop()
|
||||
self._state_manager.clear_executing()
|
||||
|
||||
def has_executing_nodes(self) -> bool:
|
||||
"""Return True if any nodes are currently marked as executing."""
|
||||
# This check is only safe once execution has already paused.
|
||||
# Before pause, executing state can change concurrently, which makes the result unreliable.
|
||||
if not self._graph_execution.is_paused:
|
||||
raise AssertionError("has_executing_nodes should only be called after execution is paused")
|
||||
return self._state_manager.get_executing_count() > 0
|
||||
|
||||
@@ -38,6 +38,8 @@ from .loop import (
|
||||
from .node import (
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunPauseRequestedEvent,
|
||||
NodeRunRetrieverResourceEvent,
|
||||
NodeRunRetryEvent,
|
||||
@@ -60,6 +62,8 @@ __all__ = [
|
||||
"NodeRunAgentLogEvent",
|
||||
"NodeRunExceptionEvent",
|
||||
"NodeRunFailedEvent",
|
||||
"NodeRunHumanInputFormFilledEvent",
|
||||
"NodeRunHumanInputFormTimeoutEvent",
|
||||
"NodeRunIterationFailedEvent",
|
||||
"NodeRunIterationNextEvent",
|
||||
"NodeRunIterationStartedEvent",
|
||||
|
||||
@@ -1,11 +1,16 @@
|
||||
from pydantic import Field
|
||||
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.graph_events import BaseGraphEvent
|
||||
|
||||
|
||||
class GraphRunStartedEvent(BaseGraphEvent):
|
||||
pass
|
||||
# Reason is emitted for workflow start events and is always set.
|
||||
reason: WorkflowStartReason = Field(
|
||||
default=WorkflowStartReason.INITIAL,
|
||||
description="reason for workflow start",
|
||||
)
|
||||
|
||||
|
||||
class GraphRunSucceededEvent(BaseGraphEvent):
|
||||
|
||||
0
api/core/workflow/graph_events/human_input.py
Normal file
0
api/core/workflow/graph_events/human_input.py
Normal file
@@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
|
||||
retry_index: int = Field(..., description="which retry attempt is about to be performed")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form is submitted and before the node finishes."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
|
||||
action_id: str = Field(..., description="User action identifier chosen in the form.")
|
||||
action_text: str = Field(..., description="Display text of the chosen action button.")
|
||||
|
||||
|
||||
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
|
||||
"""Emitted when a HumanInput form times out."""
|
||||
|
||||
node_title: str = Field(..., description="HumanInput node title")
|
||||
expiration_time: datetime = Field(..., description="Form expiration time")
|
||||
|
||||
|
||||
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
|
||||
@@ -13,6 +13,8 @@ from .loop import (
|
||||
LoopSucceededEvent,
|
||||
)
|
||||
from .node import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
ModelInvokeCompletedEvent,
|
||||
PauseRequestedEvent,
|
||||
RunRetrieverResourceEvent,
|
||||
@@ -23,6 +25,8 @@ from .node import (
|
||||
|
||||
__all__ = [
|
||||
"AgentLogEvent",
|
||||
"HumanInputFormFilledEvent",
|
||||
"HumanInputFormTimeoutEvent",
|
||||
"IterationFailedEvent",
|
||||
"IterationNextEvent",
|
||||
"IterationStartedEvent",
|
||||
|
||||
@@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase):
|
||||
|
||||
class PauseRequestedEvent(NodeEventBase):
|
||||
reason: PauseReason = Field(..., description="pause reason")
|
||||
|
||||
|
||||
class HumanInputFormFilledEvent(NodeEventBase):
|
||||
"""Event emitted when a human input form is submitted."""
|
||||
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class HumanInputFormTimeoutEvent(NodeEventBase):
|
||||
"""Event emitted when a human input form times out."""
|
||||
|
||||
node_title: str
|
||||
expiration_time: datetime
|
||||
|
||||
@@ -2,7 +2,7 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Union, cast
|
||||
|
||||
from packaging.version import Version
|
||||
from pydantic import ValidationError
|
||||
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.agent.entities import AgentToolEntity
|
||||
from core.agent.plugin_entities import AgentStrategyParameter
|
||||
from core.db.session_factory import session_factory
|
||||
from core.file import File, FileTransferMethod
|
||||
from core.memory.token_buffer_memory import TokenBufferMemory
|
||||
from core.model_manager import ModelInstance, ModelManager
|
||||
@@ -49,6 +50,12 @@ from factories import file_factory
|
||||
from factories.agent_factory import get_plugin_agent_strategy
|
||||
from models import ToolFile
|
||||
from models.model import Conversation
|
||||
from models.tools import (
|
||||
ApiToolProvider,
|
||||
BuiltinToolProvider,
|
||||
MCPToolProvider,
|
||||
WorkflowToolProvider,
|
||||
)
|
||||
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
|
||||
|
||||
from .exc import (
|
||||
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
|
||||
value = cast(list[dict[str, Any]], value)
|
||||
tool_value = []
|
||||
for tool in value:
|
||||
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
|
||||
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
|
||||
setting_params = tool.get("settings", {})
|
||||
parameters = tool.get("parameters", {})
|
||||
manual_input_params = [key for key, value in parameters.items() if value is not None]
|
||||
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
|
||||
llm_usage=llm_usage,
|
||||
)
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
|
||||
provider_type_str = tool_config.get("type")
|
||||
if provider_type_str:
|
||||
return ToolProviderType(provider_type_str)
|
||||
|
||||
provider_id = tool_config.get("provider_name")
|
||||
if not provider_id:
|
||||
return ToolProviderType.BUILT_IN
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
provider_map: dict[
|
||||
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
|
||||
ToolProviderType,
|
||||
] = {
|
||||
WorkflowToolProvider: ToolProviderType.WORKFLOW,
|
||||
MCPToolProvider: ToolProviderType.MCP,
|
||||
ApiToolProvider: ToolProviderType.API,
|
||||
BuiltinToolProvider: ToolProviderType.BUILT_IN,
|
||||
}
|
||||
|
||||
for provider_model, provider_type in provider_map.items():
|
||||
stmt = select(provider_model).where(
|
||||
provider_model.id == provider_id,
|
||||
provider_model.tenant_id == tenant_id,
|
||||
)
|
||||
if session.scalar(stmt):
|
||||
return provider_type
|
||||
|
||||
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")
|
||||
|
||||
@@ -18,6 +18,8 @@ from core.workflow.graph_events import (
|
||||
GraphNodeEventBase,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
@@ -34,6 +36,8 @@ from core.workflow.graph_events import (
|
||||
)
|
||||
from core.workflow.node_events import (
|
||||
AgentLogEvent,
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
IterationFailedEvent,
|
||||
IterationNextEvent,
|
||||
IterationStartedEvent,
|
||||
@@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class Node(Generic[NodeDataT]):
|
||||
"""BaseNode serves as the foundational class for all node implementations.
|
||||
|
||||
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
|
||||
attribute to track files generated by the LLM). However, these states are not persisted
|
||||
when the workflow is suspended or resumed. If a node needs its state to be preserved
|
||||
across workflow suspension and resumption, it should include the relevant state data
|
||||
in its output.
|
||||
"""
|
||||
|
||||
node_type: ClassVar[NodeType]
|
||||
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
|
||||
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
|
||||
@@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
|
||||
return self._node_execution_id
|
||||
|
||||
def ensure_execution_id(self) -> str:
|
||||
if not self._node_execution_id:
|
||||
self._node_execution_id = str(uuid4())
|
||||
if self._node_execution_id:
|
||||
return self._node_execution_id
|
||||
|
||||
resumed_execution_id = self._restore_execution_id_from_runtime_state()
|
||||
if resumed_execution_id:
|
||||
self._node_execution_id = resumed_execution_id
|
||||
return self._node_execution_id
|
||||
|
||||
self._node_execution_id = str(uuid4())
|
||||
return self._node_execution_id
|
||||
|
||||
def _restore_execution_id_from_runtime_state(self) -> str | None:
|
||||
graph_execution = self.graph_runtime_state.graph_execution
|
||||
try:
|
||||
node_executions = graph_execution.node_executions
|
||||
except AttributeError:
|
||||
return None
|
||||
if not isinstance(node_executions, dict):
|
||||
return None
|
||||
node_execution = node_executions.get(self._node_id)
|
||||
if node_execution is None:
|
||||
return None
|
||||
execution_id = node_execution.execution_id
|
||||
if not execution_id:
|
||||
return None
|
||||
return str(execution_id)
|
||||
|
||||
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
|
||||
return cast(NodeDataT, self._node_data_type.model_validate(data))
|
||||
|
||||
@@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
|
||||
metadata=event.metadata,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormFilledEvent):
|
||||
return NodeRunHumanInputFormFilledEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: HumanInputFormTimeoutEvent):
|
||||
return NodeRunHumanInputFormTimeoutEvent(
|
||||
id=self.execution_id,
|
||||
node_id=self._node_id,
|
||||
node_type=self.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
|
||||
@_dispatch.register
|
||||
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
|
||||
return NodeRunLoopStartedEvent(
|
||||
|
||||
@@ -1,3 +1,3 @@
|
||||
from .human_input_node import HumanInputNode
|
||||
|
||||
__all__ = ["HumanInputNode"]
|
||||
"""
|
||||
Human Input node implementation.
|
||||
"""
|
||||
|
||||
@@ -1,10 +1,350 @@
|
||||
from pydantic import Field
|
||||
"""
|
||||
Human Input node entities.
|
||||
"""
|
||||
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime, timedelta
|
||||
from typing import Annotated, Any, ClassVar, Literal, Self
|
||||
|
||||
from pydantic import BaseModel, Field, field_validator, model_validator
|
||||
|
||||
from core.variables.consts import SELECTORS_LENGTH
|
||||
from core.workflow.nodes.base import BaseNodeData
|
||||
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
|
||||
from core.workflow.runtime import VariablePool
|
||||
|
||||
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
|
||||
|
||||
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
|
||||
|
||||
|
||||
class _WebAppDeliveryConfig(BaseModel):
|
||||
"""Configuration for webapp delivery method."""
|
||||
|
||||
pass # Empty for webapp delivery
|
||||
|
||||
|
||||
class MemberRecipient(BaseModel):
|
||||
"""Member recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
|
||||
user_id: str
|
||||
|
||||
|
||||
class ExternalRecipient(BaseModel):
|
||||
"""External recipient for email delivery."""
|
||||
|
||||
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
|
||||
email: str
|
||||
|
||||
|
||||
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
|
||||
|
||||
|
||||
class EmailRecipients(BaseModel):
|
||||
"""Email recipients configuration."""
|
||||
|
||||
# When true, recipients are the union of all workspace members and external items.
|
||||
# Member items are ignored because they are already covered by the workspace scope.
|
||||
# De-duplication is applied by email, with member recipients taking precedence.
|
||||
whole_workspace: bool = False
|
||||
items: list[EmailRecipient] = Field(default_factory=list)
|
||||
|
||||
|
||||
class EmailDeliveryConfig(BaseModel):
|
||||
"""Configuration for email delivery method."""
|
||||
|
||||
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
|
||||
|
||||
recipients: EmailRecipients
|
||||
|
||||
# the subject of email
|
||||
subject: str
|
||||
|
||||
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
|
||||
# represent the url to submit the form.
|
||||
#
|
||||
# It may also reference the output variable of the previous node with the syntax
|
||||
# `{{#<node_id>.<field_name>#}}`.
|
||||
body: str
|
||||
debug_mode: bool = False
|
||||
|
||||
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
|
||||
if not user_id:
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
|
||||
return self.model_copy(update={"recipients": debug_recipients})
|
||||
|
||||
@classmethod
|
||||
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
|
||||
"""Replace the url placeholder with provided value."""
|
||||
return body.replace(cls.URL_PLACEHOLDER, url or "")
|
||||
|
||||
@classmethod
|
||||
def render_body_template(
|
||||
cls,
|
||||
*,
|
||||
body: str,
|
||||
url: str | None,
|
||||
variable_pool: VariablePool | None = None,
|
||||
) -> str:
|
||||
"""Render email body by replacing placeholders with runtime values."""
|
||||
templated_body = cls.replace_url_placeholder(body, url)
|
||||
if variable_pool is None:
|
||||
return templated_body
|
||||
return variable_pool.convert_template(templated_body).text
|
||||
|
||||
|
||||
class _DeliveryMethodBase(BaseModel):
|
||||
"""Base delivery method configuration."""
|
||||
|
||||
enabled: bool = True
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4)
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
return ()
|
||||
|
||||
|
||||
class WebAppDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Webapp delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
|
||||
# The config field is not used currently.
|
||||
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
|
||||
|
||||
|
||||
class EmailDeliveryMethod(_DeliveryMethodBase):
|
||||
"""Email delivery method configuration."""
|
||||
|
||||
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
|
||||
config: EmailDeliveryConfig
|
||||
|
||||
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
|
||||
variable_template_parser = VariableTemplateParser(template=self.config.body)
|
||||
selectors: list[Sequence[str]] = []
|
||||
for variable_selector in variable_template_parser.extract_variable_selectors():
|
||||
value_selector = list(variable_selector.value_selector)
|
||||
if len(value_selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
selectors.append(value_selector[:SELECTORS_LENGTH])
|
||||
return selectors
|
||||
|
||||
|
||||
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
|
||||
|
||||
|
||||
def apply_debug_email_recipient(
|
||||
method: DeliveryChannelConfig,
|
||||
*,
|
||||
enabled: bool,
|
||||
user_id: str,
|
||||
) -> DeliveryChannelConfig:
|
||||
if not enabled:
|
||||
return method
|
||||
if not isinstance(method, EmailDeliveryMethod):
|
||||
return method
|
||||
if not method.config.debug_mode:
|
||||
return method
|
||||
debug_config = method.config.with_debug_recipient(user_id or "")
|
||||
return method.model_copy(update={"config": debug_config})
|
||||
|
||||
|
||||
class FormInputDefault(BaseModel):
|
||||
"""Default configuration for form inputs."""
|
||||
|
||||
# NOTE: Ideally, a discriminated union would be used to model
|
||||
# FormInputDefault. However, the UI requires preserving the previous
|
||||
# value when switching between `VARIABLE` and `CONSTANT` types. This
|
||||
# necessitates retaining all fields, making a discriminated union unsuitable.
|
||||
|
||||
type: PlaceholderType
|
||||
|
||||
# The selector of default variable, used when `type` is `VARIABLE`.
|
||||
selector: Sequence[str] = Field(default_factory=tuple) #
|
||||
|
||||
# The value of the default, used when `type` is `CONSTANT`.
|
||||
# TODO: How should we express JSON values?
|
||||
value: str = ""
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _validate_selector(self) -> Self:
|
||||
if self.type == PlaceholderType.CONSTANT:
|
||||
return self
|
||||
if len(self.selector) < SELECTORS_LENGTH:
|
||||
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
|
||||
return self
|
||||
|
||||
|
||||
class FormInput(BaseModel):
|
||||
"""Form input definition."""
|
||||
|
||||
type: FormInputType
|
||||
output_variable_name: str
|
||||
default: FormInputDefault | None = None
|
||||
|
||||
|
||||
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
|
||||
|
||||
|
||||
class UserAction(BaseModel):
|
||||
"""User action configuration."""
|
||||
|
||||
# id is the identifier for this action.
|
||||
# It also serves as the identifiers of output handle.
|
||||
#
|
||||
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
|
||||
id: str = Field(max_length=20)
|
||||
title: str = Field(max_length=20)
|
||||
button_style: ButtonStyle = ButtonStyle.DEFAULT
|
||||
|
||||
@field_validator("id")
|
||||
@classmethod
|
||||
def _validate_id(cls, value: str) -> str:
|
||||
if not _IDENTIFIER_PATTERN.match(value):
|
||||
raise ValueError(
|
||||
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
|
||||
f"and contain only letters, numbers, or underscores."
|
||||
)
|
||||
return value
|
||||
|
||||
|
||||
class HumanInputNodeData(BaseNodeData):
|
||||
"""Configuration schema for the HumanInput node."""
|
||||
"""Human Input node data."""
|
||||
|
||||
required_variables: list[str] = Field(default_factory=list)
|
||||
pause_reason: str | None = Field(default=None)
|
||||
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
|
||||
form_content: str = ""
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
timeout: int = 36
|
||||
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
|
||||
|
||||
@field_validator("inputs")
|
||||
@classmethod
|
||||
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
|
||||
seen_names: set[str] = set()
|
||||
for form_input in inputs:
|
||||
name = form_input.output_variable_name
|
||||
if name in seen_names:
|
||||
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
|
||||
seen_names.add(name)
|
||||
return inputs
|
||||
|
||||
@field_validator("user_actions")
|
||||
@classmethod
|
||||
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
|
||||
seen_ids: set[str] = set()
|
||||
for action in user_actions:
|
||||
action_id = action.id
|
||||
if action_id in seen_ids:
|
||||
raise ValueError(f"duplicated user action id '{action_id}'")
|
||||
seen_ids.add(action_id)
|
||||
return user_actions
|
||||
|
||||
def is_webapp_enabled(self) -> bool:
|
||||
for dm in self.delivery_methods:
|
||||
if not dm.enabled:
|
||||
continue
|
||||
if dm.type == DeliveryMethodType.WEBAPP:
|
||||
return True
|
||||
return False
|
||||
|
||||
def expiration_time(self, start_time: datetime) -> datetime:
|
||||
if self.timeout_unit == TimeoutUnit.HOUR:
|
||||
return start_time + timedelta(hours=self.timeout)
|
||||
elif self.timeout_unit == TimeoutUnit.DAY:
|
||||
return start_time + timedelta(days=self.timeout)
|
||||
else:
|
||||
raise AssertionError("unknown timeout unit.")
|
||||
|
||||
def outputs_field_names(self) -> Sequence[str]:
|
||||
field_names = []
|
||||
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
|
||||
field_names.append(match.group("field_name"))
|
||||
return field_names
|
||||
|
||||
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
|
||||
variable_mappings: dict[str, Sequence[str]] = {}
|
||||
|
||||
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
|
||||
for selector in selectors:
|
||||
if len(selector) < SELECTORS_LENGTH:
|
||||
continue
|
||||
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
|
||||
|
||||
form_template_parser = VariableTemplateParser(template=self.form_content)
|
||||
_add_variable_selectors(
|
||||
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
|
||||
)
|
||||
for delivery_method in self.delivery_methods:
|
||||
if not delivery_method.enabled:
|
||||
continue
|
||||
_add_variable_selectors(delivery_method.extract_variable_selectors())
|
||||
|
||||
for input in self.inputs:
|
||||
default_value = input.default
|
||||
if default_value is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
default_value_key = ".".join(default_value.selector)
|
||||
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
|
||||
variable_mappings[qualified_variable_mapping_key] = default_value.selector
|
||||
|
||||
return variable_mappings
|
||||
|
||||
def find_action_text(self, action_id: str) -> str:
|
||||
"""
|
||||
Resolve action display text by id.
|
||||
"""
|
||||
for action in self.user_actions:
|
||||
if action.id == action_id:
|
||||
return action.title
|
||||
return action_id
|
||||
|
||||
|
||||
class FormDefinition(BaseModel):
|
||||
form_content: str
|
||||
inputs: list[FormInput] = Field(default_factory=list)
|
||||
user_actions: list[UserAction] = Field(default_factory=list)
|
||||
rendered_content: str
|
||||
expiration_time: datetime
|
||||
|
||||
# this is used to store the resolved default values
|
||||
default_values: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# node_title records the title of the HumanInput node.
|
||||
node_title: str | None = None
|
||||
|
||||
# display_in_ui controls whether the form should be displayed in UI surfaces.
|
||||
display_in_ui: bool | None = None
|
||||
|
||||
|
||||
class HumanInputSubmissionValidationError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
def validate_human_input_submission(
|
||||
*,
|
||||
inputs: Sequence[FormInput],
|
||||
user_actions: Sequence[UserAction],
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
) -> None:
|
||||
available_actions = {action.id for action in user_actions}
|
||||
if selected_action_id not in available_actions:
|
||||
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
|
||||
|
||||
provided_inputs = set(form_data.keys())
|
||||
missing_inputs = [
|
||||
form_input.output_variable_name
|
||||
for form_input in inputs
|
||||
if form_input.output_variable_name not in provided_inputs
|
||||
]
|
||||
|
||||
if missing_inputs:
|
||||
missing_list = ", ".join(missing_inputs)
|
||||
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")
|
||||
|
||||
72
api/core/workflow/nodes/human_input/enums.py
Normal file
72
api/core/workflow/nodes/human_input/enums.py
Normal file
@@ -0,0 +1,72 @@
|
||||
import enum
|
||||
|
||||
|
||||
class HumanInputFormStatus(enum.StrEnum):
|
||||
"""Status of a human input form."""
|
||||
|
||||
# Awaiting submission from any recipient. Forms stay in this state until
|
||||
# submitted or a timeout rule applies.
|
||||
WAITING = enum.auto()
|
||||
# Global timeout reached. The workflow run is stopped and will not resume.
|
||||
# This is distinct from node-level timeout.
|
||||
EXPIRED = enum.auto()
|
||||
# Submitted by a recipient; form data is available and execution resumes
|
||||
# along the selected action edge.
|
||||
SUBMITTED = enum.auto()
|
||||
# Node-level timeout reached. The human input node should emit a timeout
|
||||
# event and the workflow should resume along the timeout edge.
|
||||
TIMEOUT = enum.auto()
|
||||
|
||||
|
||||
class HumanInputFormKind(enum.StrEnum):
|
||||
"""Kind of a human input form."""
|
||||
|
||||
RUNTIME = enum.auto() # Form created during workflow execution.
|
||||
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
|
||||
|
||||
|
||||
class DeliveryMethodType(enum.StrEnum):
|
||||
"""Delivery method types for human input forms."""
|
||||
|
||||
# WEBAPP controls whether the form is delivered to the web app. It not only controls
|
||||
# the standalone web app, but also controls the installed apps in the console.
|
||||
WEBAPP = enum.auto()
|
||||
|
||||
EMAIL = enum.auto()
|
||||
|
||||
|
||||
class ButtonStyle(enum.StrEnum):
|
||||
"""Button styles for user actions."""
|
||||
|
||||
PRIMARY = enum.auto()
|
||||
DEFAULT = enum.auto()
|
||||
ACCENT = enum.auto()
|
||||
GHOST = enum.auto()
|
||||
|
||||
|
||||
class TimeoutUnit(enum.StrEnum):
|
||||
"""Timeout unit for form expiration."""
|
||||
|
||||
HOUR = enum.auto()
|
||||
DAY = enum.auto()
|
||||
|
||||
|
||||
class FormInputType(enum.StrEnum):
|
||||
"""Form input types."""
|
||||
|
||||
TEXT_INPUT = enum.auto()
|
||||
PARAGRAPH = enum.auto()
|
||||
|
||||
|
||||
class PlaceholderType(enum.StrEnum):
|
||||
"""Default value types for form inputs."""
|
||||
|
||||
VARIABLE = enum.auto()
|
||||
CONSTANT = enum.auto()
|
||||
|
||||
|
||||
class EmailRecipientType(enum.StrEnum):
|
||||
"""Email recipient types."""
|
||||
|
||||
MEMBER = enum.auto()
|
||||
EXTERNAL = enum.auto()
|
||||
@@ -1,12 +1,42 @@
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
|
||||
from core.workflow.node_events import (
|
||||
HumanInputFormFilledEvent,
|
||||
HumanInputFormTimeoutEvent,
|
||||
NodeRunResult,
|
||||
PauseRequestedEvent,
|
||||
)
|
||||
from core.workflow.node_events.base import NodeEventBase
|
||||
from core.workflow.node_events.node import StreamCompletedEvent
|
||||
from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRepository,
|
||||
)
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
|
||||
from .entities import HumanInputNodeData
|
||||
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
|
||||
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.graph_init_params import GraphInitParams
|
||||
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
|
||||
|
||||
|
||||
_SELECTED_BRANCH_KEY = "selected_branch"
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class HumanInputNode(Node[HumanInputNodeData]):
|
||||
@@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"edge_source_handle",
|
||||
"edgeSourceHandle",
|
||||
"source_handle",
|
||||
"selected_branch",
|
||||
_SELECTED_BRANCH_KEY,
|
||||
"selectedBranch",
|
||||
"branch",
|
||||
"branch_id",
|
||||
@@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
"handle",
|
||||
)
|
||||
|
||||
_node_data: HumanInputNodeData
|
||||
_form_repository: HumanInputFormRepository
|
||||
_OUTPUT_FIELD_ACTION_ID = "__action_id"
|
||||
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
|
||||
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
id: str,
|
||||
config: Mapping[str, Any],
|
||||
graph_init_params: "GraphInitParams",
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
form_repository: HumanInputFormRepository | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
config=config,
|
||||
graph_init_params=graph_init_params,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
if form_repository is None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self.tenant_id,
|
||||
)
|
||||
self._form_repository = form_repository
|
||||
|
||||
@classmethod
|
||||
def version(cls) -> str:
|
||||
return "1"
|
||||
|
||||
def _run(self): # type: ignore[override]
|
||||
if self._is_completion_ready():
|
||||
branch_handle = self._resolve_branch_selection()
|
||||
return NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={},
|
||||
edge_source_handle=branch_handle or "source",
|
||||
)
|
||||
|
||||
return self._pause_generator()
|
||||
|
||||
def _pause_generator(self):
|
||||
# TODO(QuantumGhost): yield a real form id.
|
||||
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
|
||||
|
||||
def _is_completion_ready(self) -> bool:
|
||||
"""Determine whether all required inputs are satisfied."""
|
||||
|
||||
if not self.node_data.required_variables:
|
||||
return False
|
||||
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
|
||||
for selector_str in self.node_data.required_variables:
|
||||
parts = selector_str.split(".")
|
||||
if len(parts) != 2:
|
||||
return False
|
||||
segment = variable_pool.get(parts)
|
||||
if segment is None:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def _resolve_branch_selection(self) -> str | None:
|
||||
"""Determine the branch handle selected by human input if available."""
|
||||
|
||||
@@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
|
||||
return candidate
|
||||
|
||||
return None
|
||||
|
||||
@property
|
||||
def _workflow_execution_id(self) -> str:
|
||||
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
|
||||
assert workflow_exec_id is not None
|
||||
return workflow_exec_id
|
||||
|
||||
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
|
||||
required_event = self._human_input_required_event(form_entity)
|
||||
pause_requested_event = PauseRequestedEvent(reason=required_event)
|
||||
return pause_requested_event
|
||||
|
||||
def resolve_default_values(self) -> Mapping[str, Any]:
|
||||
variable_pool = self.graph_runtime_state.variable_pool
|
||||
resolved_defaults: dict[str, Any] = {}
|
||||
for input in self._node_data.inputs:
|
||||
if (default_value := input.default) is None:
|
||||
continue
|
||||
if default_value.type == PlaceholderType.CONSTANT:
|
||||
continue
|
||||
resolved_value = variable_pool.get(default_value.selector)
|
||||
if resolved_value is None:
|
||||
# TODO: How should we handle this?
|
||||
continue
|
||||
resolved_defaults[input.output_variable_name] = (
|
||||
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
|
||||
)
|
||||
|
||||
return resolved_defaults
|
||||
|
||||
def _should_require_console_recipient(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
if self.invoke_from == InvokeFrom.EXPLORE:
|
||||
return self._node_data.is_webapp_enabled()
|
||||
return False
|
||||
|
||||
def _display_in_ui(self) -> bool:
|
||||
if self.invoke_from == InvokeFrom.DEBUGGER:
|
||||
return True
|
||||
return self._node_data.is_webapp_enabled()
|
||||
|
||||
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
|
||||
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
|
||||
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
|
||||
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
|
||||
return [
|
||||
apply_debug_email_recipient(
|
||||
method,
|
||||
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
|
||||
user_id=self.user_id or "",
|
||||
)
|
||||
for method in enabled_methods
|
||||
]
|
||||
|
||||
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
|
||||
node_data = self._node_data
|
||||
resolved_default_values = self.resolve_default_values()
|
||||
display_in_ui = self._display_in_ui()
|
||||
form_token = form_entity.web_app_token
|
||||
if display_in_ui and form_token is None:
|
||||
raise AssertionError("Form token should be available for UI execution.")
|
||||
return HumanInputRequired(
|
||||
form_id=form_entity.id,
|
||||
form_content=form_entity.rendered_content,
|
||||
inputs=node_data.inputs,
|
||||
actions=node_data.user_actions,
|
||||
display_in_ui=display_in_ui,
|
||||
node_id=self.id,
|
||||
node_title=node_data.title,
|
||||
form_token=form_token,
|
||||
resolved_default_values=resolved_default_values,
|
||||
)
|
||||
|
||||
def _run(self) -> Generator[NodeEventBase, None, None]:
|
||||
"""
|
||||
Execute the human input node.
|
||||
|
||||
This method will:
|
||||
1. Generate a unique form ID
|
||||
2. Create form content with variable substitution
|
||||
3. Create form in database
|
||||
4. Send form via configured delivery methods
|
||||
5. Suspend workflow execution
|
||||
6. Wait for form submission to resume
|
||||
"""
|
||||
repo = self._form_repository
|
||||
form = repo.get_form(self._workflow_execution_id, self.id)
|
||||
if form is None:
|
||||
display_in_ui = self._display_in_ui()
|
||||
params = FormCreateParams(
|
||||
app_id=self.app_id,
|
||||
workflow_execution_id=self._workflow_execution_id,
|
||||
node_id=self.id,
|
||||
form_config=self._node_data,
|
||||
rendered_content=self.render_form_content_before_submission(),
|
||||
delivery_methods=self._effective_delivery_methods(),
|
||||
display_in_ui=display_in_ui,
|
||||
resolved_default_values=self.resolve_default_values(),
|
||||
console_recipient_required=self._should_require_console_recipient(),
|
||||
console_creator_account_id=(
|
||||
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
|
||||
),
|
||||
backstage_recipient_required=True,
|
||||
)
|
||||
form_entity = self._form_repository.create_form(params)
|
||||
# Create human input required event
|
||||
|
||||
logger.info(
|
||||
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
|
||||
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
|
||||
self.id,
|
||||
form_entity.id,
|
||||
)
|
||||
yield self._form_to_pause_event(form_entity)
|
||||
return
|
||||
|
||||
if (
|
||||
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
|
||||
or form.expiration_time <= naive_utc_now()
|
||||
):
|
||||
yield HumanInputFormTimeoutEvent(
|
||||
node_title=self._node_data.title,
|
||||
expiration_time=form.expiration_time,
|
||||
)
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
|
||||
edge_source_handle=self._TIMEOUT_HANDLE,
|
||||
)
|
||||
)
|
||||
return
|
||||
|
||||
if not form.submitted:
|
||||
yield self._form_to_pause_event(form)
|
||||
return
|
||||
|
||||
selected_action_id = form.selected_action_id
|
||||
if selected_action_id is None:
|
||||
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
|
||||
submitted_data = form.submitted_data or {}
|
||||
outputs: dict[str, Any] = dict(submitted_data)
|
||||
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
|
||||
rendered_content = self.render_form_content_with_outputs(
|
||||
form.rendered_content,
|
||||
outputs,
|
||||
self._node_data.outputs_field_names(),
|
||||
)
|
||||
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
|
||||
|
||||
action_text = self._node_data.find_action_text(selected_action_id)
|
||||
|
||||
yield HumanInputFormFilledEvent(
|
||||
node_title=self._node_data.title,
|
||||
rendered_content=rendered_content,
|
||||
action_id=selected_action_id,
|
||||
action_text=action_text,
|
||||
)
|
||||
|
||||
yield StreamCompletedEvent(
|
||||
node_run_result=NodeRunResult(
|
||||
status=WorkflowNodeExecutionStatus.SUCCEEDED,
|
||||
outputs=outputs,
|
||||
edge_source_handle=selected_action_id,
|
||||
)
|
||||
)
|
||||
|
||||
def render_form_content_before_submission(self) -> str:
|
||||
"""
|
||||
Process form content by substituting variables.
|
||||
|
||||
This method should:
|
||||
1. Parse the form_content markdown
|
||||
2. Substitute {{#node_name.var_name#}} with actual values
|
||||
3. Keep {{#$output.field_name#}} placeholders for form inputs
|
||||
"""
|
||||
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
|
||||
self._node_data.form_content,
|
||||
)
|
||||
return rendered_form_content.markdown
|
||||
|
||||
@staticmethod
|
||||
def render_form_content_with_outputs(
|
||||
form_content: str,
|
||||
outputs: Mapping[str, Any],
|
||||
field_names: Sequence[str],
|
||||
) -> str:
|
||||
"""
|
||||
Replace {{#$output.xxx#}} placeholders with submitted values.
|
||||
"""
|
||||
rendered_content = form_content
|
||||
for field_name in field_names:
|
||||
placeholder = "{{#$output." + field_name + "#}}"
|
||||
value = outputs.get(field_name)
|
||||
if value is None:
|
||||
replacement = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
replacement = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
replacement = str(value)
|
||||
rendered_content = rendered_content.replace(placeholder, replacement)
|
||||
return rendered_content
|
||||
|
||||
@classmethod
|
||||
def _extract_variable_selector_to_variable_mapping(
|
||||
cls,
|
||||
*,
|
||||
graph_config: Mapping[str, Any],
|
||||
node_id: str,
|
||||
node_data: Mapping[str, Any],
|
||||
) -> Mapping[str, Sequence[str]]:
|
||||
"""
|
||||
Extract variable selectors referenced in form content and input default values.
|
||||
|
||||
This method should parse:
|
||||
1. Variables referenced in form_content ({{#node_name.var_name#}})
|
||||
2. Variables referenced in input default values
|
||||
"""
|
||||
validated_node_data = HumanInputNodeData.model_validate(node_data)
|
||||
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from configs import dify_config
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
|
||||
from core.workflow.node_events import NodeRunResult
|
||||
from core.workflow.nodes.base.node import Node
|
||||
@@ -16,12 +15,13 @@ if TYPE_CHECKING:
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
|
||||
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
|
||||
|
||||
|
||||
class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
node_type = NodeType.TEMPLATE_TRANSFORM
|
||||
_template_renderer: Jinja2TemplateRenderer
|
||||
_max_output_length: int
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -31,6 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
graph_runtime_state: "GraphRuntimeState",
|
||||
*,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
max_output_length: int | None = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
id=id,
|
||||
@@ -40,6 +41,10 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
|
||||
if max_output_length is not None and max_output_length <= 0:
|
||||
raise ValueError("max_output_length must be a positive integer")
|
||||
self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH
|
||||
|
||||
@classmethod
|
||||
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
|
||||
"""
|
||||
@@ -69,11 +74,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
|
||||
except TemplateRenderError as e:
|
||||
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
|
||||
|
||||
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
|
||||
if len(rendered) > self._max_output_length:
|
||||
return NodeRunResult(
|
||||
inputs=variables,
|
||||
status=WorkflowNodeExecutionStatus.FAILED,
|
||||
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
|
||||
error=f"Output length exceeds {self._max_output_length} characters",
|
||||
)
|
||||
|
||||
return NodeRunResult(
|
||||
|
||||
152
api/core/workflow/repositories/human_input_form_repository.py
Normal file
152
api/core/workflow/repositories/human_input_form_repository.py
Normal file
@@ -0,0 +1,152 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any, Protocol
|
||||
|
||||
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
|
||||
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
|
||||
|
||||
|
||||
class HumanInputError(Exception):
|
||||
pass
|
||||
|
||||
|
||||
class FormNotFoundError(HumanInputError):
|
||||
pass
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class FormCreateParams:
|
||||
# app_id is the identifier for the app that the form belongs to.
|
||||
# It is a string with uuid format.
|
||||
app_id: str
|
||||
# None when creating a delivery test form; set for runtime forms.
|
||||
workflow_execution_id: str | None
|
||||
|
||||
# node_id is the identifier for a specific
|
||||
# node in the graph.
|
||||
#
|
||||
# TODO: for node inside loop / iteration, this would
|
||||
# cause problems, as a single node may be executed multiple times.
|
||||
node_id: str
|
||||
|
||||
form_config: HumanInputNodeData
|
||||
rendered_content: str
|
||||
# Delivery methods already filtered by runtime context (invoke_from).
|
||||
delivery_methods: Sequence[DeliveryChannelConfig]
|
||||
# UI display flag computed by runtime context.
|
||||
display_in_ui: bool
|
||||
|
||||
# resolved_default_values saves the values for defaults with
|
||||
# type = VARIABLE.
|
||||
#
|
||||
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
|
||||
resolved_default_values: Mapping[str, Any]
|
||||
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
|
||||
|
||||
# Force creating a console-only recipient for submission in Console.
|
||||
console_recipient_required: bool = False
|
||||
console_creator_account_id: str | None = None
|
||||
# Force creating a backstage recipient for submission in Console.
|
||||
backstage_recipient_required: bool = False
|
||||
|
||||
|
||||
class HumanInputFormEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of the form."""
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def web_app_token(self) -> str | None:
|
||||
"""web_app_token returns the token for submission inside webapp.
|
||||
|
||||
For console/debug execution, this may point to the console submission token
|
||||
if the form is configured to require console delivery.
|
||||
"""
|
||||
|
||||
# TODO: what if the users are allowed to add multiple
|
||||
# webapp delivery?
|
||||
pass
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def rendered_content(self) -> str:
|
||||
"""Rendered markdown content associated with the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def selected_action_id(self) -> str | None:
|
||||
"""Identifier of the selected user action if the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
"""Submitted form data if available."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def submitted(self) -> bool:
|
||||
"""Whether the form has been submitted."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
"""Current status of the form."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def expiration_time(self) -> datetime:
|
||||
"""When the form expires."""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRecipientEntity(abc.ABC):
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def id(self) -> str:
|
||||
"""id returns the identifer of this recipient."""
|
||||
...
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def token(self) -> str:
|
||||
"""token returns a random string used to submit form"""
|
||||
...
|
||||
|
||||
|
||||
class HumanInputFormRepository(Protocol):
|
||||
"""
|
||||
Repository interface for HumanInputForm.
|
||||
|
||||
This interface defines the contract for accessing and manipulating
|
||||
HumanInputForm data, regardless of the underlying storage mechanism.
|
||||
|
||||
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
|
||||
and other implementation details should be handled at the implementation level, not in
|
||||
the core interface. This keeps the core domain model clean and independent of specific
|
||||
application domains or deployment scenarios.
|
||||
"""
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
"""Get the form created for a given human input node in a workflow execution. Returns
|
||||
`None` if the form has not been created yet."""
|
||||
...
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
"""
|
||||
Create a human input form from form definition.
|
||||
"""
|
||||
...
|
||||
@@ -6,15 +6,18 @@ import threading
|
||||
from collections.abc import Mapping, Sequence
|
||||
from copy import deepcopy
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, ClassVar, Protocol
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic.json import pydantic_encoder
|
||||
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
|
||||
from core.workflow.runtime.variable_pool import VariablePool
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
|
||||
|
||||
class ReadyQueueProtocol(Protocol):
|
||||
"""Structural interface required from ready queue implementations."""
|
||||
@@ -61,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
|
||||
aborted: bool
|
||||
error: Exception | None
|
||||
exceptions_count: int
|
||||
pause_reasons: list[PauseReason]
|
||||
pause_reasons: Sequence[PauseReason]
|
||||
|
||||
def start(self) -> None:
|
||||
"""Transition execution into the running state."""
|
||||
@@ -133,6 +136,13 @@ class GraphProtocol(Protocol):
|
||||
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
|
||||
|
||||
|
||||
class _GraphStateSnapshot(BaseModel):
|
||||
"""Serializable graph state snapshot for node/edge states."""
|
||||
|
||||
nodes: dict[str, NodeState] = Field(default_factory=dict)
|
||||
edges: dict[str, NodeState] = Field(default_factory=dict)
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _GraphRuntimeStateSnapshot:
|
||||
"""Immutable view of a serialized runtime state snapshot."""
|
||||
@@ -148,10 +158,20 @@ class _GraphRuntimeStateSnapshot:
|
||||
graph_execution_dump: str | None
|
||||
response_coordinator_dump: str | None
|
||||
paused_nodes: tuple[str, ...]
|
||||
deferred_nodes: tuple[str, ...]
|
||||
graph_node_states: dict[str, NodeState]
|
||||
graph_edge_states: dict[str, NodeState]
|
||||
|
||||
|
||||
class GraphRuntimeState:
|
||||
"""Mutable runtime state shared across graph execution components."""
|
||||
"""Mutable runtime state shared across graph execution components.
|
||||
|
||||
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
|
||||
including scheduling details, variable values, and timing information.
|
||||
|
||||
Values that are initialized prior to workflow execution and remain constant
|
||||
throughout the execution should be part of `GraphInitParams` instead.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -189,6 +209,16 @@ class GraphRuntimeState:
|
||||
self._pending_response_coordinator_dump: str | None = None
|
||||
self._pending_graph_execution_workflow_id: str | None = None
|
||||
self._paused_nodes: set[str] = set()
|
||||
self._deferred_nodes: set[str] = set()
|
||||
|
||||
# Node and edges states needed to be restored into
|
||||
# graph object.
|
||||
#
|
||||
# These two fields are non-None only when resuming from a snapshot.
|
||||
# Once the graph is attached, these two fields will be set to None.
|
||||
self._pending_graph_node_states: dict[str, NodeState] | None = None
|
||||
self._pending_graph_edge_states: dict[str, NodeState] | None = None
|
||||
|
||||
self.stop_event: threading.Event = threading.Event()
|
||||
|
||||
if graph is not None:
|
||||
@@ -210,6 +240,7 @@ class GraphRuntimeState:
|
||||
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
|
||||
self._response_coordinator.loads(self._pending_response_coordinator_dump)
|
||||
self._pending_response_coordinator_dump = None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def configure(self, *, graph: GraphProtocol | None = None) -> None:
|
||||
"""Ensure core collaborators are initialized with the provided context."""
|
||||
@@ -331,8 +362,13 @@ class GraphRuntimeState:
|
||||
"ready_queue": self.ready_queue.dumps(),
|
||||
"graph_execution": self.graph_execution.dumps(),
|
||||
"paused_nodes": list(self._paused_nodes),
|
||||
"deferred_nodes": list(self._deferred_nodes),
|
||||
}
|
||||
|
||||
graph_state = self._snapshot_graph_state()
|
||||
if graph_state is not None:
|
||||
snapshot["graph_state"] = graph_state
|
||||
|
||||
if self._response_coordinator is not None and self._graph is not None:
|
||||
snapshot["response_coordinator"] = self._response_coordinator.dumps()
|
||||
|
||||
@@ -366,6 +402,11 @@ class GraphRuntimeState:
|
||||
|
||||
self._paused_nodes.add(node_id)
|
||||
|
||||
def get_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve the list of paused nodes without mutating internal state."""
|
||||
|
||||
return list(self._paused_nodes)
|
||||
|
||||
def consume_paused_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear the list of paused nodes awaiting resume."""
|
||||
|
||||
@@ -373,6 +414,23 @@ class GraphRuntimeState:
|
||||
self._paused_nodes.clear()
|
||||
return nodes
|
||||
|
||||
def register_deferred_node(self, node_id: str) -> None:
|
||||
"""Record a node that became ready during pause and should resume later."""
|
||||
|
||||
self._deferred_nodes.add(node_id)
|
||||
|
||||
def get_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve deferred nodes without mutating internal state."""
|
||||
|
||||
return list(self._deferred_nodes)
|
||||
|
||||
def consume_deferred_nodes(self) -> list[str]:
|
||||
"""Retrieve and clear deferred nodes awaiting resume."""
|
||||
|
||||
nodes = list(self._deferred_nodes)
|
||||
self._deferred_nodes.clear()
|
||||
return nodes
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Builders
|
||||
# ------------------------------------------------------------------
|
||||
@@ -434,6 +492,10 @@ class GraphRuntimeState:
|
||||
graph_execution_payload = payload.get("graph_execution")
|
||||
response_payload = payload.get("response_coordinator")
|
||||
paused_nodes_payload = payload.get("paused_nodes", [])
|
||||
deferred_nodes_payload = payload.get("deferred_nodes", [])
|
||||
graph_state_payload = payload.get("graph_state", {}) or {}
|
||||
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
|
||||
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
|
||||
|
||||
return _GraphRuntimeStateSnapshot(
|
||||
start_at=start_at,
|
||||
@@ -447,6 +509,9 @@ class GraphRuntimeState:
|
||||
graph_execution_dump=graph_execution_payload,
|
||||
response_coordinator_dump=response_payload,
|
||||
paused_nodes=tuple(map(str, paused_nodes_payload)),
|
||||
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
|
||||
graph_node_states=graph_node_states,
|
||||
graph_edge_states=graph_edge_states,
|
||||
)
|
||||
|
||||
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
|
||||
@@ -462,6 +527,10 @@ class GraphRuntimeState:
|
||||
self._restore_graph_execution(snapshot.graph_execution_dump)
|
||||
self._restore_response_coordinator(snapshot.response_coordinator_dump)
|
||||
self._paused_nodes = set(snapshot.paused_nodes)
|
||||
self._deferred_nodes = set(snapshot.deferred_nodes)
|
||||
self._pending_graph_node_states = snapshot.graph_node_states or None
|
||||
self._pending_graph_edge_states = snapshot.graph_edge_states or None
|
||||
self._apply_pending_graph_state()
|
||||
|
||||
def _restore_ready_queue(self, payload: str | None) -> None:
|
||||
if payload is not None:
|
||||
@@ -498,3 +567,68 @@ class GraphRuntimeState:
|
||||
|
||||
self._pending_response_coordinator_dump = payload
|
||||
self._response_coordinator = None
|
||||
|
||||
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
|
||||
graph = self._graph
|
||||
if graph is None:
|
||||
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
|
||||
return _GraphStateSnapshot()
|
||||
return _GraphStateSnapshot(
|
||||
nodes=self._pending_graph_node_states or {},
|
||||
edges=self._pending_graph_edge_states or {},
|
||||
)
|
||||
|
||||
nodes = graph.nodes
|
||||
edges = graph.edges
|
||||
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
|
||||
return _GraphStateSnapshot()
|
||||
|
||||
node_states = {}
|
||||
for node_id, node in nodes.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
node_states[node_id] = node.state
|
||||
|
||||
edge_states = {}
|
||||
for edge_id, edge in edges.items():
|
||||
if not isinstance(edge_id, str):
|
||||
continue
|
||||
edge_states[edge_id] = edge.state
|
||||
|
||||
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
|
||||
|
||||
def _apply_pending_graph_state(self) -> None:
|
||||
if self._graph is None:
|
||||
return
|
||||
if self._pending_graph_node_states:
|
||||
for node_id, state in self._pending_graph_node_states.items():
|
||||
node = self._graph.nodes.get(node_id)
|
||||
if node is None:
|
||||
continue
|
||||
node.state = state
|
||||
if self._pending_graph_edge_states:
|
||||
for edge_id, state in self._pending_graph_edge_states.items():
|
||||
edge = self._graph.edges.get(edge_id)
|
||||
if edge is None:
|
||||
continue
|
||||
edge.state = state
|
||||
|
||||
self._pending_graph_node_states = None
|
||||
self._pending_graph_edge_states = None
|
||||
|
||||
|
||||
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
|
||||
if not isinstance(payload, Mapping):
|
||||
return {}
|
||||
raw_map = payload.get(key, {})
|
||||
if not isinstance(raw_map, Mapping):
|
||||
return {}
|
||||
result: dict[str, NodeState] = {}
|
||||
for node_id, raw_state in raw_map.items():
|
||||
if not isinstance(node_id, str):
|
||||
continue
|
||||
try:
|
||||
result[node_id] = NodeState(str(raw_state))
|
||||
except ValueError:
|
||||
continue
|
||||
return result
|
||||
|
||||
@@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
|
||||
def to_json_encodable(self, value: None) -> None: ...
|
||||
|
||||
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
|
||||
result = self._to_json_encodable_recursive(value)
|
||||
"""Convert runtime values to JSON-serializable structures."""
|
||||
|
||||
result = self.value_to_json_encodable_recursive(value)
|
||||
if isinstance(result, Mapping) or result is None:
|
||||
return result
|
||||
return {}
|
||||
|
||||
def _to_json_encodable_recursive(self, value: Any):
|
||||
def value_to_json_encodable_recursive(self, value: Any):
|
||||
if value is None:
|
||||
return value
|
||||
if isinstance(value, (bool, int, str, float)):
|
||||
@@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
|
||||
# Convert Decimal to float for JSON serialization
|
||||
return float(value)
|
||||
if isinstance(value, Segment):
|
||||
return self._to_json_encodable_recursive(value.value)
|
||||
return self.value_to_json_encodable_recursive(value.value)
|
||||
if isinstance(value, File):
|
||||
return value.to_dict()
|
||||
if isinstance(value, BaseModel):
|
||||
@@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
|
||||
if isinstance(value, dict):
|
||||
res = {}
|
||||
for k, v in value.items():
|
||||
res[k] = self._to_json_encodable_recursive(v)
|
||||
res[k] = self.value_to_json_encodable_recursive(v)
|
||||
return res
|
||||
if isinstance(value, list):
|
||||
res_list = []
|
||||
for item in value:
|
||||
res_list.append(self._to_json_encodable_recursive(item))
|
||||
res_list.append(self.value_to_json_encodable_recursive(item))
|
||||
return res_list
|
||||
return value
|
||||
|
||||
@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
|
||||
if [[ -z "${CELERY_QUEUES}" ]]; then
|
||||
if [[ "${EDITION}" == "CLOUD" ]]; then
|
||||
# Cloud edition: separate queues for dataset and trigger tasks
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
else
|
||||
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
|
||||
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
|
||||
fi
|
||||
else
|
||||
DEFAULT_QUEUES="${CELERY_QUEUES}"
|
||||
@@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
|
||||
fi
|
||||
|
||||
echo "Running Flask job command: flask $*"
|
||||
|
||||
|
||||
# Temporarily disable exit on error to capture exit code
|
||||
set +e
|
||||
flask "$@"
|
||||
|
||||
@@ -151,6 +151,12 @@ def init_app(app: DifyApp) -> Celery:
|
||||
"task": "schedule.queue_monitor_task.queue_monitor_task",
|
||||
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
|
||||
}
|
||||
if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
|
||||
imports.append("tasks.human_input_timeout_tasks")
|
||||
beat_schedule["human_input_form_timeout"] = {
|
||||
"task": "human_input_form_timeout.check_and_resume",
|
||||
"schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
|
||||
}
|
||||
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
|
||||
imports.append("schedule.check_upgradable_plugin_task")
|
||||
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")
|
||||
|
||||
@@ -8,12 +8,16 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
|
||||
import redis
|
||||
from redis import RedisError
|
||||
from redis.cache import CacheConfig
|
||||
from redis.client import PubSub
|
||||
from redis.cluster import ClusterNode, RedisCluster
|
||||
from redis.connection import Connection, SSLConnection
|
||||
from redis.sentinel import Sentinel
|
||||
|
||||
from configs import dify_config
|
||||
from dify_app import DifyApp
|
||||
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
|
||||
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
|
||||
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from redis.lock import Lock
|
||||
@@ -106,6 +110,7 @@ class RedisClientWrapper:
|
||||
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
|
||||
def zcard(self, name: str | bytes) -> Any: ...
|
||||
def getdel(self, name: str | bytes) -> Any: ...
|
||||
def pubsub(self) -> PubSub: ...
|
||||
|
||||
def __getattr__(self, item: str) -> Any:
|
||||
if self._client is None:
|
||||
@@ -114,6 +119,7 @@ class RedisClientWrapper:
|
||||
|
||||
|
||||
redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
|
||||
|
||||
|
||||
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
|
||||
@@ -226,6 +232,12 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
|
||||
return client
|
||||
|
||||
|
||||
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
|
||||
if use_clusters:
|
||||
return RedisCluster.from_url(pubsub_url)
|
||||
return redis.Redis.from_url(pubsub_url)
|
||||
|
||||
|
||||
def init_app(app: DifyApp):
|
||||
"""Initialize Redis client and attach it to the app."""
|
||||
global redis_client
|
||||
@@ -244,6 +256,24 @@ def init_app(app: DifyApp):
|
||||
redis_client.initialize(client)
|
||||
app.extensions["redis"] = redis_client
|
||||
|
||||
pubsub_client = client
|
||||
if dify_config.normalized_pubsub_redis_url:
|
||||
pubsub_client = _create_pubsub_client(
|
||||
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
|
||||
)
|
||||
pubsub_redis_client.initialize(pubsub_client)
|
||||
|
||||
|
||||
def get_pubsub_redis_client() -> RedisClientWrapper:
|
||||
return pubsub_redis_client
|
||||
|
||||
|
||||
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
|
||||
redis_conn = get_pubsub_redis_client()
|
||||
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
|
||||
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
||||
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R")
|
||||
|
||||
@@ -13,6 +13,7 @@ from typing import Any
|
||||
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.workflow.enums import WorkflowNodeExecutionStatus
|
||||
from extensions.logstore.aliyun_logstore import AliyunLogStore
|
||||
from extensions.logstore.repositories import safe_float, safe_int
|
||||
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
|
||||
@@ -207,8 +208,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
||||
reverse=True,
|
||||
)
|
||||
|
||||
if deduplicated_results:
|
||||
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
|
||||
for row in deduplicated_results:
|
||||
model = _dict_to_workflow_node_execution_model(row)
|
||||
if model.status != WorkflowNodeExecutionStatus.PAUSED:
|
||||
return model
|
||||
|
||||
return None
|
||||
|
||||
@@ -309,6 +312,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
|
||||
if model and model.id: # Ensure model is valid
|
||||
models.append(model)
|
||||
|
||||
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
|
||||
|
||||
# Sort by index DESC for trace visualization
|
||||
models.sort(key=lambda x: x.index, reverse=True)
|
||||
|
||||
|
||||
@@ -192,6 +192,7 @@ class StatusCount(ResponseModel):
|
||||
success: int
|
||||
failed: int
|
||||
partial_success: int
|
||||
paused: int
|
||||
|
||||
|
||||
class ModelConfig(ResponseModel):
|
||||
|
||||
@@ -6,6 +6,7 @@ from uuid import uuid4
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field, field_validator
|
||||
|
||||
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
|
||||
from core.file import File
|
||||
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
|
||||
|
||||
@@ -61,6 +62,7 @@ class MessageListItem(ResponseModel):
|
||||
message_files: list[MessageFile]
|
||||
status: str
|
||||
error: str | None = None
|
||||
extra_contents: list[ExecutionExtraContentDomainModel]
|
||||
|
||||
@field_validator("inputs", mode="before")
|
||||
@classmethod
|
||||
|
||||
@@ -7,6 +7,7 @@ from typing import Self
|
||||
|
||||
from libs.broadcast_channel.channel import Subscription
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
from redis import Redis, RedisCluster
|
||||
from redis.client import PubSub
|
||||
|
||||
_logger = logging.getLogger(__name__)
|
||||
@@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: Redis | RedisCluster,
|
||||
pubsub: PubSub,
|
||||
topic: str,
|
||||
):
|
||||
# The _pubsub is None only if the subscription is closed.
|
||||
self._client = client
|
||||
self._pubsub: PubSub | None = pubsub
|
||||
self._topic = topic
|
||||
self._closed = threading.Event()
|
||||
@@ -162,7 +165,7 @@ class RedisSubscriptionBase(Subscription):
|
||||
self._start_if_needed()
|
||||
return iter(self._message_iterator())
|
||||
|
||||
def receive(self, timeout: float | None = None) -> bytes | None:
|
||||
def receive(self, timeout: float | None = 0.1) -> bytes | None:
|
||||
"""Receive the next message from the subscription."""
|
||||
if self._closed.is_set():
|
||||
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")
|
||||
|
||||
@@ -42,6 +42,7 @@ class Topic:
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
@@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
|
||||
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
|
||||
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "message"
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
|
||||
from redis import Redis
|
||||
from redis import Redis, RedisCluster
|
||||
|
||||
from ._subscription import RedisSubscriptionBase
|
||||
|
||||
@@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
redis_client: Redis,
|
||||
redis_client: Redis | RedisCluster,
|
||||
):
|
||||
self._client = redis_client
|
||||
|
||||
@@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
|
||||
|
||||
|
||||
class ShardedTopic:
|
||||
def __init__(self, redis_client: Redis, topic: str):
|
||||
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
|
||||
self._client = redis_client
|
||||
self._topic = topic
|
||||
|
||||
@@ -40,6 +40,7 @@ class ShardedTopic:
|
||||
|
||||
def subscribe(self) -> Subscription:
|
||||
return _RedisShardedSubscription(
|
||||
client=self._client,
|
||||
pubsub=self._client.pubsub(),
|
||||
topic=self._topic,
|
||||
)
|
||||
@@ -61,7 +62,26 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
|
||||
|
||||
def _get_message(self) -> dict | None:
|
||||
assert self._pubsub is not None
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
|
||||
# NOTE(QuantumGhost): this is an issue in
|
||||
# upstream code. If Sharded PubSub is used with Cluster, the
|
||||
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
|
||||
# message['type'].
|
||||
#
|
||||
# Since we have already filtered at the caller's site, we can safely set
|
||||
# `ignore_subscribe_messages=False`.
|
||||
if isinstance(self._client, RedisCluster):
|
||||
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
|
||||
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
|
||||
#
|
||||
# Here we specify the `target_node` to mitigate this problem.
|
||||
node = self._client.get_node_from_key(self._topic)
|
||||
return self._pubsub.get_sharded_message(
|
||||
ignore_subscribe_messages=False,
|
||||
timeout=1,
|
||||
target_node=node,
|
||||
)
|
||||
else:
|
||||
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
|
||||
|
||||
def _get_message_type(self) -> str:
|
||||
return "smessage"
|
||||
|
||||
49
api/libs/email_template_renderer.py
Normal file
49
api/libs/email_template_renderer.py
Normal file
@@ -0,0 +1,49 @@
|
||||
"""
|
||||
Email template rendering helpers with configurable safety modes.
|
||||
"""
|
||||
|
||||
import time
|
||||
from collections.abc import Mapping
|
||||
from typing import Any
|
||||
|
||||
from flask import render_template_string
|
||||
from jinja2.runtime import Context
|
||||
from jinja2.sandbox import ImmutableSandboxedEnvironment
|
||||
|
||||
from configs import dify_config
|
||||
from configs.feature import TemplateMode
|
||||
|
||||
|
||||
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
|
||||
"""Sandboxed environment with execution timeout."""
|
||||
|
||||
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
|
||||
self._deadline = time.time() + timeout if timeout else None
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
|
||||
if self._deadline is not None and time.time() > self._deadline:
|
||||
raise TimeoutError("Template rendering timeout")
|
||||
return super().call(context, obj, *args, **kwargs)
|
||||
|
||||
|
||||
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
|
||||
"""
|
||||
Render email template content according to the configured template mode.
|
||||
|
||||
In unsafe mode, Jinja expressions are evaluated directly.
|
||||
In sandbox mode, a sandboxed environment with timeout is used.
|
||||
In disabled mode, the template is returned without rendering.
|
||||
"""
|
||||
mode = dify_config.MAIL_TEMPLATING_MODE
|
||||
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
|
||||
|
||||
if mode == TemplateMode.UNSAFE:
|
||||
return render_template_string(template, **substitutions)
|
||||
if mode == TemplateMode.SANDBOX:
|
||||
env = SandboxedEnvironment(timeout=timeout)
|
||||
tmpl = env.from_string(template)
|
||||
return tmpl.render(substitutions)
|
||||
if mode == TemplateMode.DISABLED:
|
||||
return template
|
||||
raise ValueError(f"Unsupported mail templating mode: {mode}")
|
||||
@@ -1,12 +1,15 @@
|
||||
import contextvars
|
||||
from collections.abc import Iterator
|
||||
from contextlib import contextmanager
|
||||
from typing import TypeVar
|
||||
from typing import TYPE_CHECKING, TypeVar
|
||||
|
||||
from flask import Flask, g
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models import Account, EndUser
|
||||
|
||||
|
||||
@contextmanager
|
||||
def preserve_flask_contexts(
|
||||
@@ -64,3 +67,7 @@ def preserve_flask_contexts(
|
||||
finally:
|
||||
# Any cleanup can be added here if needed
|
||||
pass
|
||||
|
||||
|
||||
def set_login_user(user: "Account | EndUser"):
|
||||
g._login_user = user
|
||||
|
||||
@@ -7,10 +7,10 @@ import struct
|
||||
import subprocess
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from datetime import datetime
|
||||
from hashlib import sha256
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
|
||||
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
|
||||
from uuid import UUID
|
||||
from zoneinfo import available_timezones
|
||||
|
||||
@@ -126,6 +126,13 @@ class TimestampField(fields.Raw):
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
class OptionalTimestampField(fields.Raw):
|
||||
def format(self, value) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def email(email):
|
||||
# Define a regex pattern for email addresses
|
||||
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
|
||||
@@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
|
||||
|
||||
|
||||
def generate_string(n):
|
||||
"""
|
||||
Generates a cryptographically secure random string of the specified length.
|
||||
|
||||
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
|
||||
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
|
||||
|
||||
Each character in the generated string provides approximately 5.95 bits of entropy
|
||||
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
|
||||
length of the string (`n`) should be at least 22 characters.
|
||||
|
||||
Args:
|
||||
n (int): The length of the random string to generate. For secure usage,
|
||||
`n` should be 22 or greater.
|
||||
|
||||
Returns:
|
||||
str: A random string of length `n` composed of ASCII letters and digits.
|
||||
|
||||
Note:
|
||||
This function is suitable for generating credentials or other secure tokens.
|
||||
"""
|
||||
letters_digits = string.ascii_letters + string.digits
|
||||
result = ""
|
||||
for _ in range(n):
|
||||
@@ -405,11 +432,35 @@ class TokenManager:
|
||||
return f"{token_type}:account:{account_id}"
|
||||
|
||||
|
||||
class _RateLimiterRedisClient(Protocol):
|
||||
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
|
||||
|
||||
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
|
||||
|
||||
def zcard(self, name: str | bytes) -> int: ...
|
||||
|
||||
def expire(self, name: str | bytes, time: int) -> bool: ...
|
||||
|
||||
|
||||
def _default_rate_limit_member_factory() -> str:
|
||||
current_time = int(time.time())
|
||||
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
|
||||
|
||||
|
||||
class RateLimiter:
|
||||
def __init__(self, prefix: str, max_attempts: int, time_window: int):
|
||||
def __init__(
|
||||
self,
|
||||
prefix: str,
|
||||
max_attempts: int,
|
||||
time_window: int,
|
||||
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
|
||||
redis_client: _RateLimiterRedisClient = redis_client,
|
||||
):
|
||||
self.prefix = prefix
|
||||
self.max_attempts = max_attempts
|
||||
self.time_window = time_window
|
||||
self._member_factory = member_factory
|
||||
self._redis_client = redis_client
|
||||
|
||||
def _get_key(self, email: str) -> str:
|
||||
return f"{self.prefix}:{email}"
|
||||
@@ -419,8 +470,8 @@ class RateLimiter:
|
||||
current_time = int(time.time())
|
||||
window_start_time = current_time - self.time_window
|
||||
|
||||
redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = redis_client.zcard(key)
|
||||
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
|
||||
attempts = self._redis_client.zcard(key)
|
||||
|
||||
if attempts and int(attempts) >= self.max_attempts:
|
||||
return True
|
||||
@@ -428,7 +479,8 @@ class RateLimiter:
|
||||
|
||||
def increment_rate_limit(self, email: str):
|
||||
key = self._get_key(email)
|
||||
member = self._member_factory()
|
||||
current_time = int(time.time())
|
||||
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.expire(key, self.time_window * 2)
|
||||
self._redis_client.zadd(key, {member: current_time})
|
||||
self._redis_client.expire(key, self.time_window * 2)
|
||||
|
||||
@@ -10,6 +10,10 @@ import models as models
|
||||
import sqlalchemy as sa
|
||||
from sqlalchemy.dialects import postgresql
|
||||
|
||||
|
||||
def _is_pg(conn):
|
||||
return conn.dialect.name == "postgresql"
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '7df29de0f6be'
|
||||
down_revision = '03ea244985ce'
|
||||
@@ -19,16 +23,31 @@ depends_on = None
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
conn = op.get_bind()
|
||||
|
||||
if _is_pg(conn):
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
else:
|
||||
# For MySQL and other databases, UUID should be generated at application level
|
||||
op.create_table('tenant_credit_pools',
|
||||
sa.Column('id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
|
||||
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
|
||||
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
|
||||
sa.Column('quota_used', sa.BigInteger(), nullable=False),
|
||||
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
|
||||
)
|
||||
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
|
||||
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
|
||||
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user