Compare commits

..

39 Commits

Author SHA1 Message Date
QuantumGhost
3d0ff9463f Merge branch 'fix/redis-pubsub-perf' into feat/hitl 2026-02-06 14:42:39 +08:00
QuantumGhost
b893d2df82 docs(api): add a short note about the target_node argument 2026-02-06 14:42:04 +08:00
QuantumGhost
79b6117d80 fixup! fix(api): fix performance issue in ShardedRedisBroadcastChannel 2026-02-06 14:35:19 +08:00
WTW0313
d2ef434dec Merge branch 'main' into feat/hitl 2026-02-06 13:58:24 +08:00
longbingljw
d9530f7bb7 fix: make flask upgrade-db fail on error (#32024) 2026-02-06 12:01:31 +08:00
wangxiaolei
b24e6edada fix: fix agent node tool type is not right (#32008)
Infer real tool type via querying relevant database tables.

The root cause for incorrect `type` field is still not clear.
2026-02-06 11:24:39 +08:00
Ryan
59a9cbbf78 chore: remove .codex/skills directory (#32022)
Co-authored-by: Longwei Liu <longweiliu@LongweideMacBook-Air.local>
2026-02-06 10:46:50 +08:00
99
45164ce33e refactor: strip external imports in workflow template transform (#32017) 2026-02-06 10:37:26 +08:00
99
095b3ee234 chore: Remove redundant double space in variable type description (core/variables/variables.py) (#32002)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-02-05 21:44:31 +08:00
QuantumGhost
cb970e54da perf(api): Optimize the response time of AppListApi endpoint (#31999) 2026-02-05 19:05:09 +08:00
Stream
e04f2a0786 feat: use static manifest for pre-caching all plugin manifests before checking updates (#31942)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: Junyan Qin <rockchinq@gmail.com>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-05 18:58:17 +08:00
Stephen Zhou
7202a24bcf chore: migrate to eslint-better-tailwind (#31969) 2026-02-05 18:36:08 +09:00
wangxiaolei
be8f265e43 fix: fix uuid_generate_v4 only used in postgresql (#31304) 2026-02-05 17:32:33 +08:00
QuantumGhost
aaf83c2b4c chore(api): fix linting issue 2026-02-05 16:15:32 +08:00
lif
9e54f086dc fix(web): add rewrite rule to fix Serwist precaching 404 errors (#31770)
Signed-off-by: majiayu000 <1835304752@qq.com>
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-02-05 15:42:18 +08:00
QuantumGhost
d898bcff90 feat(api): adjust timeout for get_message to 1s 2026-02-05 15:22:09 +08:00
twwu
b4cf146c85 Merge branch 'main' into feat/hitl 2026-02-05 14:56:02 +08:00
Joel
8c31b69c8e chore: sticky the applist header in explore page (#31967)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
2026-02-05 14:44:51 +08:00
wangxiaolei
b886b3f6c8 fix: fix miss use db.session (#31971) 2026-02-05 14:42:34 +08:00
Stephen Zhou
ef0d18bb61 test: fix test (#31975) 2026-02-05 14:31:21 +08:00
QuantumGhost
f21782a9a3 fix(api): fix performance issue in ShardedRedisBroadcastChannel 2026-02-05 13:28:39 +08:00
JzoNg
e4455987e7 fix: do not stop when workflow paused event recieved 2026-02-05 11:16:14 +08:00
twwu
b2ceb41dd6 Merge branch 'main' into feat/hitl 2026-02-05 11:13:40 +08:00
Xiyuan Chen
c56ad8e323 feat: account delete cleanup (#31519)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 17:59:41 -08:00
yyh
365f749ed5 fix: remove staleTime/gcTime overrides from trigger query hooks and use orpc contract (#31863)
Some checks failed
autofix.ci / autofix (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Main CI Pipeline / Check Changed Files (push) Has been cancelled
Main CI Pipeline / API Tests (push) Has been cancelled
Main CI Pipeline / Web Tests (push) Has been cancelled
Main CI Pipeline / Style Check (push) Has been cancelled
Main CI Pipeline / VDB Tests (push) Has been cancelled
Main CI Pipeline / DB Migration Test (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
2026-02-04 19:33:32 +08:00
wangxiaolei
f686197589 feat: use latest hash to sync draft (#31924) 2026-02-04 19:32:36 +08:00
Coding On Star
f584be9cf0 chore: update CODEOWNERS to specify test file patterns for base components (#31941)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-04 19:29:57 +08:00
QuantumGhost
3bd228ddb7 chore: bump version in docker-compose and package manager to 1.12.1 (#31947) 2026-02-04 19:29:28 +08:00
wangxiaolei
0dfa59b1db fix: fix delete_draft_variables_batch cycle forever (#31934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 19:10:27 +08:00
Coding On Star
1e344f773b refactor(web): extract complex components into modular structure with comprehensive tests (#31729)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 18:35:31 +08:00
-LAN-
bba2040a05 chore: assign code owners for test directories (#31940) 2026-02-04 18:22:14 +08:00
Coding On Star
ad3be1e4d0 fix: include locale in appList query key for localization support inuseExploreAppList (#31921)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
2026-02-04 18:12:30 +08:00
Coding On Star
297dd832aa refactor(datasets): extract hooks and components with comprehensive tests (#31707)
Co-authored-by: CodingOnStar <hanxujiang@dify.com>
Co-authored-by: Claude Opus 4.5 <noreply@anthropic.com>
2026-02-04 18:12:17 +08:00
zxhlyh
cc5705cb71 fix: auto summary env (#31930) 2026-02-04 17:47:38 +08:00
wangxiaolei
74b027c41a fix: fix mcp output schema is union type frontend crash (#31779)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-02-04 17:33:41 +08:00
Stephen Zhou
5f69470ebf test: try fix test, clear test log in CI (#31912) 2026-02-04 17:05:15 +08:00
wangxiaolei
ec7ccd800c fix: fix mcp server status is not right (#31826)
Co-authored-by: Stephen Zhou <38493346+hyoban@users.noreply.github.com>
2026-02-04 16:55:12 +08:00
QuantumGhost
f614153f30 chore(api): fix circular import 2026-02-02 16:52:43 +08:00
QuantumGhost
8ca020e179 Revert "revert: revert human input relevant code (#31766)"
This reverts commit 90fe9abab7.
2026-02-01 16:21:14 +08:00
600 changed files with 49997 additions and 5350 deletions

View File

@@ -1 +0,0 @@
../../.agents/skills/component-refactoring

View File

@@ -1 +0,0 @@
../../.agents/skills/frontend-code-review

View File

@@ -1 +0,0 @@
../../.agents/skills/frontend-testing

View File

@@ -1 +0,0 @@
../../.agents/skills/orpc-contract-first

7
.github/CODEOWNERS vendored
View File

@@ -24,6 +24,10 @@
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
/api/controllers/mcp/ @Nov1c444
/api/controllers/console/app/mcp_server.py @Nov1c444
# Backend - Tests
/api/tests/ @laipz8200 @QuantumGhost
/api/tests/**/*mcp* @Nov1c444
# Backend - Workflow - Engine (Core graph execution engine)
@@ -234,6 +238,9 @@
# Frontend - Base Components
/web/app/components/base/ @iamjoel @zxhlyh
# Frontend - Base Components Tests
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
# Frontend - Utils and Hooks
/web/utils/classnames.ts @iamjoel @zxhlyh
/web/utils/time.ts @iamjoel @zxhlyh

View File

@@ -79,29 +79,6 @@ jobs:
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
find . -name "*.py.bak" -type f -delete
- name: Install pnpm
uses: pnpm/action-setup@v4
with:
package_json_file: web/package.json
run_install: false
- name: Setup Node.js
uses: actions/setup-node@v6
with:
node-version: 24
cache: pnpm
cache-dependency-path: ./web/pnpm-lock.yaml
- name: Install web dependencies
run: |
cd web
pnpm install --frozen-lockfile
- name: ESLint autofix
run: |
cd web
pnpm lint:fix || true
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
- name: mdformat
run: |

View File

@@ -8,6 +8,7 @@ on:
- "build/**"
- "release/e-*"
- "hotfix/**"
- "feat/hitl-backend"
tags:
- "*"

View File

@@ -39,7 +39,7 @@ jobs:
run: pnpm install --frozen-lockfile
- name: Run tests
run: pnpm test:coverage
run: pnpm test:ci
- name: Coverage Summary
if: always()

View File

@@ -717,3 +717,28 @@ SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
# Redis URL used for PubSub between API and
# celery worker
# defaults to url constructed from `REDIS_*`
# configurations
PUBSUB_REDIS_URL=
# Pub/sub channel type for streaming events.
# valid options are:
#
# - pubsub: for normal Pub/Sub
# - sharded: for sharded Pub/Sub
#
# It's highly recommended to use sharded Pub/Sub AND redis cluster
# for large deployments.
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
# Whether to use Redis cluster mode while running
# PubSub.
# It's highly recommended to enable this for large deployments.
PUBSUB_REDIS_USE_CLUSTERS=false
# Whether to Enable human input timeout check task
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
# Human input timeout check interval in minutes
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1

View File

@@ -36,6 +36,8 @@ ignore_imports =
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
core.workflow.nodes.loop.loop_node -> core.workflow.graph
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
# TODO(QuantumGhost): fix the import violation later
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
[importlinter:contract:workflow-infrastructure-dependencies]
name = Workflow Infrastructure Dependencies
@@ -58,6 +60,8 @@ ignore_imports =
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.graph_engine.manager -> extensions.ext_redis
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
# TODO(QuantumGhost): use DI to avoid depending on global DB.
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
[importlinter:contract:workflow-external-imports]
name = Workflow External Imports
@@ -102,6 +106,8 @@ forbidden_modules =
core.trigger
core.variables
ignore_imports =
core.workflow.nodes.agent.agent_node -> core.db.session_factory
core.workflow.nodes.agent.agent_node -> models.tools
core.workflow.nodes.loop.loop_node -> core.app.workflow.node_factory
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
core.workflow.workflow_entry -> core.app.workflow.layers.observability
@@ -136,7 +142,6 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> models.provider
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
core.workflow.nodes.llm.node -> core.tools.signature
core.workflow.nodes.template_transform.template_transform_node -> configs
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
@@ -145,6 +150,7 @@ ignore_imports =
core.workflow.nodes.agent.agent_node -> core.agent.entities
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
@@ -248,6 +254,7 @@ ignore_imports =
core.workflow.nodes.document_extractor.node -> core.variables.segments
core.workflow.nodes.http_request.executor -> core.variables.segments
core.workflow.nodes.http_request.node -> core.variables.segments
core.workflow.nodes.human_input.entities -> core.variables.consts
core.workflow.nodes.iteration.iteration_node -> core.variables
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
@@ -294,6 +301,8 @@ ignore_imports =
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
core.workflow.nodes.llm.node -> extensions.ext_database
core.workflow.nodes.tool.tool_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
core.workflow.workflow_entry -> extensions.otel.runtime
core.workflow.nodes.agent.agent_node -> models
core.workflow.nodes.base.node -> models.enums

View File

@@ -739,8 +739,10 @@ def upgrade_db():
click.echo(click.style("Database migration successful!", fg="green"))
except Exception:
except Exception as e:
logger.exception("Failed to execute database migration")
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
lock.release()
else:

View File

@@ -1,3 +1,4 @@
from datetime import timedelta
from enum import StrEnum
from typing import Literal
@@ -48,6 +49,16 @@ class SecurityConfig(BaseSettings):
default=5,
)
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
description="Maximum number of web form submissions allowed per IP within the rate limit window",
default=30,
)
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
description="Time window in seconds for web form submission rate limiting",
default=60,
)
LOGIN_DISABLED: bool = Field(
description="Whether to disable login checks",
default=False,
@@ -82,6 +93,12 @@ class AppExecutionConfig(BaseSettings):
default=0,
)
HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
default=int(timedelta(days=7).total_seconds()),
ge=1,
)
class CodeExecutionSandboxConfig(BaseSettings):
"""
@@ -1134,6 +1151,14 @@ class CeleryScheduleTasksConfig(BaseSettings):
description="Enable queue monitor task",
default=False,
)
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
description="Enable human input timeout check task",
default=True,
)
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
description="Human input timeout check interval in minutes",
default=1,
)
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
description="Enable check upgradable plugin task",
default=True,

View File

@@ -6,6 +6,7 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
from pydantic_settings import BaseSettings
from .cache.redis_config import RedisConfig
from .cache.redis_pubsub_config import RedisPubSubConfig
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
from .storage.amazon_s3_storage_config import S3StorageConfig
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
@@ -317,6 +318,7 @@ class MiddlewareConfig(
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
KeywordStoreConfig,
RedisConfig,
RedisPubSubConfig,
# configs of storage and storage providers
StorageConfig,
AliyunOSSStorageConfig,

View File

@@ -0,0 +1,96 @@
from typing import Literal, Protocol
from urllib.parse import quote_plus, urlunparse
from pydantic import Field
from pydantic_settings import BaseSettings
class RedisConfigDefaults(Protocol):
REDIS_HOST: str
REDIS_PORT: int
REDIS_USERNAME: str | None
REDIS_PASSWORD: str | None
REDIS_DB: int
REDIS_USE_SSL: bool
REDIS_USE_SENTINEL: bool | None
REDIS_USE_CLUSTERS: bool
class RedisConfigDefaultsMixin:
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
return self
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
"""
Configuration settings for Redis pub/sub streaming.
"""
PUBSUB_REDIS_URL: str | None = Field(
alias="PUBSUB_REDIS_URL",
description=(
"Redis connection URL for pub/sub streaming events between API "
"and celery worker, defaults to url constructed from "
"`REDIS_*` configurations"
),
default=None,
)
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
description=(
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
"recommended to enable this for large deployments."
),
default=False,
)
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
description=(
"Pub/sub channel type for streaming events. "
"Valid options are:\n"
"\n"
" - pubsub: for normal Pub/Sub\n"
" - sharded: for sharded Pub/Sub\n"
"\n"
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
"for large deployments."
),
default="pubsub",
)
def _build_default_pubsub_url(self) -> str:
defaults = self._redis_defaults()
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
username = defaults.REDIS_USERNAME or None
password = defaults.REDIS_PASSWORD or None
userinfo = ""
if username:
userinfo = quote_plus(username)
if password:
password_part = quote_plus(password)
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
if userinfo:
userinfo = f"{userinfo}@"
host = defaults.REDIS_HOST
port = defaults.REDIS_PORT
db = defaults.REDIS_DB
netloc = f"{userinfo}{host}:{port}"
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
@property
def normalized_pubsub_redis_url(self) -> str:
pubsub_redis_url = self.PUBSUB_REDIS_URL
if pubsub_redis_url:
cleaned = pubsub_redis_url.strip()
pubsub_redis_url = cleaned or None
if pubsub_redis_url:
return pubsub_redis_url
return self._build_default_pubsub_url()

View File

@@ -37,6 +37,7 @@ from . import (
apikey,
extension,
feature,
human_input_form,
init_validate,
ping,
setup,
@@ -171,6 +172,7 @@ __all__ = [
"forgot_password",
"generator",
"hit_testing",
"human_input_form",
"init_validate",
"installed_app",
"load_balancing_config",

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:
for _, node_data in workflow.walk_nodes():
for node_id, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue
for app in app_pagination.items:

View File

@@ -89,6 +89,7 @@ status_count_model = console_ns.model(
"success": fields.Integer,
"failed": fields.Integer,
"partial_success": fields.Integer,
"paused": fields.Integer,
},
)

View File

@@ -33,7 +33,7 @@ from libs.login import current_account_with_tenant, login_required
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
from services.message_service import MessageService
from services.message_service import MessageService, attach_message_extra_contents
logger = logging.getLogger(__name__)
@@ -207,6 +207,7 @@ message_detail_model = console_ns.model(
"created_at": TimestampField,
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
"message_files": fields.List(fields.Nested(message_file_model)),
"extra_contents": fields.List(fields.Raw),
"metadata": fields.Raw(attribute="message_metadata_dict"),
"status": fields.String,
"error": fields.String,
@@ -299,6 +300,7 @@ class ChatMessageListApi(Resource):
has_more = False
history_messages = list(reversed(history_messages))
attach_message_extra_contents(history_messages)
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
@@ -481,4 +483,5 @@ class MessageApi(Resource):
if not message:
raise NotFound("Message Not Exists.")
attach_message_extra_contents([message])
return message

View File

@@ -507,6 +507,179 @@ class WorkflowDraftRunLoopNodeApi(Resource):
raise InternalServerError()
class HumanInputFormPreviewPayload(BaseModel):
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
class HumanInputFormSubmitPayload(BaseModel):
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
inputs: dict[str, Any] = Field(
...,
description="Values used to fill missing upstream variables referenced in form_content",
)
action: str = Field(..., description="Selected action ID")
class HumanInputDeliveryTestPayload(BaseModel):
delivery_method_id: str = Field(..., description="Delivery method ID")
inputs: dict[str, Any] = Field(
default_factory=dict,
description="Values used to fill missing upstream variables referenced in form_content",
)
reg(HumanInputFormPreviewPayload)
reg(HumanInputFormSubmitPayload)
reg(HumanInputDeliveryTestPayload)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class AdvancedChatDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
workflow_service = WorkflowService()
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
class WorkflowDraftHumanInputFormPreviewApi(Resource):
@console_ns.doc("get_workflow_draft_human_input_form")
@console_ns.doc(description="Get human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Preview human input form content and placeholders
"""
current_user, _ = current_account_with_tenant()
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
inputs = args.inputs
workflow_service = WorkflowService()
preview = workflow_service.get_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
inputs=inputs,
)
return jsonable_encoder(preview)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
class WorkflowDraftHumanInputFormRunApi(Resource):
@console_ns.doc("submit_workflow_draft_human_input_form")
@console_ns.doc(description="Submit human input form preview for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Submit human input form preview
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
result = workflow_service.submit_human_input_form_preview(
app_model=app_model,
account=current_user,
node_id=node_id,
form_inputs=args.form_inputs,
inputs=args.inputs,
action=args.action,
)
return jsonable_encoder(result)
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
@console_ns.doc("test_workflow_draft_human_input_delivery")
@console_ns.doc(description="Test human input delivery for workflow")
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
@setup_required
@login_required
@account_initialization_required
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
@edit_permission_required
def post(self, app_model: App, node_id: str):
"""
Test human input delivery
"""
current_user, _ = current_account_with_tenant()
workflow_service = WorkflowService()
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
workflow_service.test_human_input_delivery(
app_model=app_model,
account=current_user,
node_id=node_id,
delivery_method_id=args.delivery_method_id,
inputs=args.inputs,
)
return jsonable_encoder({})
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
class DraftWorkflowRunApi(Resource):
@console_ns.doc("run_draft_workflow")

View File

@@ -5,10 +5,15 @@ from flask import request
from flask_restx import Resource, fields, marshal_with
from pydantic import BaseModel, Field, field_validator
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from configs import dify_config
from controllers.console import console_ns
from controllers.console.app.wraps import get_app_model
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import NotFoundError
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from extensions.ext_database import db
from fields.end_user_fields import simple_end_user_fields
from fields.member_fields import simple_account_fields
@@ -27,9 +32,21 @@ from libs.custom_inputs import time_duration
from libs.helper import uuid_value
from libs.login import current_user, login_required
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
from services.workflow_run_service import WorkflowRunService
def _build_backstage_input_url(form_token: str | None) -> str | None:
if not form_token:
return None
base_url = dify_config.APP_WEB_URL
if not base_url:
return None
return f"{base_url.rstrip('/')}/form/{form_token}"
# Workflow run status choices for filtering
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
@@ -440,3 +457,63 @@ class WorkflowRunNodeExecutionListApi(Resource):
)
return {"data": node_executions}
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
class ConsoleWorkflowPauseDetailsApi(Resource):
"""Console API for getting workflow pause details."""
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow pause details.
GET /console/api/workflow/<workflow_run_id>/pause-details
Returns information about why and where the workflow is paused.
"""
# Query WorkflowRun to determine if workflow is suspended
session_maker = sessionmaker(bind=db.engine)
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
if not workflow_run:
raise NotFoundError("Workflow run not found")
# Check if workflow is suspended
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
if not is_paused:
return {
"paused_at": None,
"paused_nodes": [],
}, 200
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
# Build response
paused_at = pause_entity.paused_at if pause_entity else None
paused_nodes = []
response = {
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
"paused_nodes": paused_nodes,
}
for reason in pause_reasons:
if isinstance(reason, HumanInputRequired):
paused_nodes.append(
{
"node_id": reason.node_id,
"node_title": reason.node_title,
"pause_type": {
"type": "human_input",
"form_id": reason.form_id,
"backstage_input_url": _build_backstage_input_url(reason.form_token),
},
}
)
else:
raise AssertionError("unimplemented.")
return response, 200

View File

@@ -0,0 +1,217 @@
"""
Console/Studio Human Input Form APIs.
"""
import json
import logging
from collections.abc import Generator
from flask import Response, jsonify, request
from flask_restx import Resource, reqparse
from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from controllers.console import console_ns
from controllers.console.wraps import account_initialization_required, setup_required
from controllers.web.error import InvalidArgumentError, NotFoundError
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from extensions.ext_database import db
from libs.login import current_account_with_tenant, login_required
from models import App
from models.enums import CreatorUserRole
from models.human_input import RecipientType
from models.model import AppMode
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.human_input_service import Form, HumanInputService
from services.workflow_event_snapshot_service import build_workflow_event_stream
logger = logging.getLogger(__name__)
def _jsonify_form_definition(form: Form) -> Response:
payload = form.get_definition().model_dump()
payload["expiration_time"] = int(form.expiration_time.timestamp())
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
@console_ns.route("/form/human_input/<string:form_token>")
class ConsoleHumanInputFormApi(Resource):
"""Console API for getting human input form definition."""
@staticmethod
def _ensure_console_access(form: Form):
_, current_tenant_id = current_account_with_tenant()
if form.tenant_id != current_tenant_id:
raise NotFoundError("App not found")
@setup_required
@login_required
@account_initialization_required
def get(self, form_token: str):
"""
Get human input form definition by form token.
GET /console/api/form/human_input/<form_token>
"""
service = HumanInputService(db.engine)
form = service.get_form_definition_by_token_for_console(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
return _jsonify_form_definition(form)
@account_initialization_required
@login_required
def post(self, form_token: str):
"""
Submit human input form by form token.
POST /console/api/form/human_input/<form_token>
Request body:
{
"inputs": {
"content": "User input content"
},
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
current_user, _ = current_account_with_tenant()
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError(f"form not found, token={form_token}")
self._ensure_console_access(form)
recipient_type = form.recipient_type
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
raise NotFoundError(f"form not found, token={form_token}")
# The type checker is not smart enought to validate the following invariant.
# So we need to assert it manually.
assert recipient_type is not None, "recipient_type cannot be None here."
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_user_id=current_user.id,
)
return jsonify({})
@console_ns.route("/workflow/<string:workflow_run_id>/events")
class ConsoleWorkflowEventsApi(Resource):
"""Console API for getting workflow execution events after resume."""
@account_initialization_required
@login_required
def get(self, workflow_run_id: str):
"""
Get workflow execution events stream after resume.
GET /console/api/workflow/<workflow_run_id>/events
Returns Server-Sent Events stream.
"""
user, tenant_id = current_account_with_tenant()
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
if workflow_run.created_by != user.id:
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
with Session(expire_on_commit=False, bind=db.engine) as session:
app = _retrieve_app_for_workflow_run(session, workflow_run)
if workflow_run.finished_at is not None:
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
msg_generator = MessageGenerator()
if app.mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app.mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=AppMode(app.mode),
workflow_run=workflow_run,
tenant_id=workflow_run.tenant_id,
app_id=workflow_run.app_id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
query = select(App).where(
App.id == workflow_run.app_id,
App.tenant_id == workflow_run.tenant_id,
)
app = session.scalars(query).first()
if app is None:
raise AssertionError(
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
)
return app

View File

@@ -33,8 +33,9 @@ from core.workflow.graph_engine.manager import GraphEngineManager
from extensions.ext_database import db
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
from libs import helper
from libs.helper import TimestampField
from libs.helper import OptionalTimestampField, TimestampField
from models.model import App, AppMode, EndUser
from models.workflow import WorkflowRun
from repositories.factory import DifyAPIRepositoryFactory
from services.app_generate_service import AppGenerateService
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
@@ -63,17 +64,32 @@ class WorkflowLogQuery(BaseModel):
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
class WorkflowRunStatusField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
return obj.status.value
class WorkflowRunOutputsField(fields.Raw):
def output(self, key, obj: WorkflowRun, **kwargs):
if obj.status == WorkflowExecutionStatus.PAUSED:
return {}
outputs = obj.outputs_dict
return outputs or {}
workflow_run_fields = {
"id": fields.String,
"workflow_id": fields.String,
"status": fields.String,
"status": WorkflowRunStatusField,
"inputs": fields.Raw,
"outputs": fields.Raw,
"outputs": WorkflowRunOutputsField,
"error": fields.String,
"total_steps": fields.Integer,
"total_tokens": fields.Integer,
"created_at": TimestampField,
"finished_at": TimestampField,
"finished_at": OptionalTimestampField,
"elapsed_time": fields.Float,
}

View File

@@ -23,6 +23,7 @@ from . import (
feature,
files,
forgot_password,
human_input_form,
login,
message,
passport,
@@ -30,6 +31,7 @@ from . import (
saved_message,
site,
workflow,
workflow_events,
)
api.add_namespace(web_ns)
@@ -44,6 +46,7 @@ __all__ = [
"feature",
"files",
"forgot_password",
"human_input_form",
"login",
"message",
"passport",
@@ -52,4 +55,5 @@ __all__ = [
"site",
"web_ns",
"workflow",
"workflow_events",
]

View File

@@ -117,6 +117,12 @@ class InvokeRateLimitError(BaseHTTPException):
code = 429
class WebFormRateLimitExceededError(BaseHTTPException):
error_code = "web_form_rate_limit_exceeded"
description = "Too many form requests. Please try again later."
code = 429
class NotFoundError(BaseHTTPException):
error_code = "not_found"
code = 404

View File

@@ -0,0 +1,164 @@
"""
Web App Human Input Form APIs.
"""
import json
import logging
from datetime import datetime
from flask import Response, request
from flask_restx import Resource, reqparse
from werkzeug.exceptions import Forbidden
from configs import dify_config
from controllers.web import web_ns
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
from controllers.web.site import serialize_app_site_payload
from extensions.ext_database import db
from libs.helper import RateLimiter, extract_remote_ip
from models.account import TenantStatus
from models.model import App, Site
from services.human_input_service import Form, FormNotFoundError, HumanInputService
logger = logging.getLogger(__name__)
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
prefix="web_form_submit_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
prefix="web_form_access_rate_limit",
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
)
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
result: dict[str, str] = {}
for key, value in values.items():
if value is None:
result[key] = ""
elif isinstance(value, (dict, list)):
result[key] = json.dumps(value, ensure_ascii=False)
else:
result[key] = str(value)
return result
def _to_timestamp(value: datetime) -> int:
return int(value.timestamp())
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
"""Return the form payload (optionally with site) as a JSON response."""
definition_payload = form.get_definition().model_dump()
payload = {
"form_content": definition_payload["rendered_content"],
"inputs": definition_payload["inputs"],
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
"user_actions": definition_payload["user_actions"],
"expiration_time": _to_timestamp(form.expiration_time),
}
if site_payload is not None:
payload["site"] = site_payload
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
# TODO(QuantumGhost): disable authorization for web app
# form api temporarily
@web_ns.route("/form/human_input/<string:form_token>")
# class HumanInputFormApi(WebApiResource):
class HumanInputFormApi(Resource):
"""API for getting and submitting human input forms via the web app."""
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
def get(self, form_token: str):
"""
Get human input form definition by token.
GET /api/form/human_input/<form_token>
"""
ip_address = extract_remote_ip(request)
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
# TODO(QuantumGhost): forbid submision for form tokens
# that are only for console.
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError("Form not found")
service.ensure_form_active(form)
app_model, site = _get_app_site_from_form(form)
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
def post(self, form_token: str):
"""
Submit human input form by token.
POST /api/form/human_input/<form_token>
Request body:
{
"inputs": {
"content": "User input content"
},
"action": "Approve"
}
"""
parser = reqparse.RequestParser()
parser.add_argument("inputs", type=dict, required=True, location="json")
parser.add_argument("action", type=str, required=True, location="json")
args = parser.parse_args()
ip_address = extract_remote_ip(request)
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
raise WebFormRateLimitExceededError()
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
service = HumanInputService(db.engine)
form = service.get_form_by_token(form_token)
if form is None:
raise NotFoundError("Form not found")
if (recipient_type := form.recipient_type) is None:
logger.warning("Recipient type is None for form, form_id=%", form.id)
raise AssertionError("Recipient type is None")
try:
service.submit_form_by_token(
recipient_type=recipient_type,
form_token=form_token,
selected_action_id=args["action"],
form_data=args["inputs"],
submission_end_user_id=None,
# submission_end_user_id=_end_user.id,
)
except FormNotFoundError:
raise NotFoundError("Form not found")
return {}, 200
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
"""Resolve App/Site for the form's app and validate tenant status."""
app_model = db.session.query(App).where(App.id == form.app_id).first()
if app_model is None or app_model.tenant_id != form.tenant_id:
raise NotFoundError("Form not found")
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
if site is None:
raise Forbidden()
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
raise Forbidden()
return app_model, site

View File

@@ -1,4 +1,6 @@
from flask_restx import fields, marshal_with
from typing import cast
from flask_restx import fields, marshal, marshal_with
from werkzeug.exceptions import Forbidden
from configs import dify_config
@@ -7,7 +9,7 @@ from controllers.web.wraps import WebApiResource
from extensions.ext_database import db
from libs.helper import AppIconUrlField
from models.account import TenantStatus
from models.model import Site
from models.model import App, Site
from services.feature_service import FeatureService
@@ -108,3 +110,14 @@ class AppSiteInfo:
"remove_webapp_brand": remove_webapp_brand,
"replace_webapp_logo": replace_webapp_logo,
}
def serialize_site(site: Site) -> dict:
"""Serialize Site model using the same schema as AppSiteApi."""
return cast(dict, marshal(site, AppSiteApi.site_fields))
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))

View File

@@ -0,0 +1,112 @@
"""
Web App Workflow Resume APIs.
"""
import json
from collections.abc import Generator
from flask import Response, request
from sqlalchemy.orm import sessionmaker
from controllers.web import api
from controllers.web.error import InvalidArgumentError, NotFoundError
from controllers.web.wraps import WebApiResource
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
from core.app.apps.message_generator import MessageGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.model import App, AppMode, EndUser
from repositories.factory import DifyAPIRepositoryFactory
from services.workflow_event_snapshot_service import build_workflow_event_stream
class WorkflowEventsApi(WebApiResource):
"""API for getting workflow execution events after resume."""
def get(self, app_model: App, end_user: EndUser, task_id: str):
"""
Get workflow execution events stream after resume.
GET /api/workflow/<task_id>/events
Returns Server-Sent Events stream.
"""
workflow_run_id = task_id
session_maker = sessionmaker(db.engine)
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
tenant_id=app_model.tenant_id,
run_id=workflow_run_id,
)
if workflow_run is None:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.app_id != app_model.id:
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
if workflow_run.created_by_role != CreatorUserRole.END_USER:
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
if workflow_run.created_by != end_user.id:
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
if workflow_run.finished_at is not None:
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
task_id=workflow_run.id,
workflow_run=workflow_run,
creator_user=end_user,
)
payload = response.model_dump(mode="json")
payload["event"] = response.event.value
def _generate_finished_events() -> Generator[str, None, None]:
yield f"data: {json.dumps(payload)}\n\n"
event_generator = _generate_finished_events
else:
app_mode = AppMode.value_of(app_model.mode)
msg_generator = MessageGenerator()
generator: BaseAppGenerator
if app_mode == AppMode.ADVANCED_CHAT:
generator = AdvancedChatAppGenerator()
elif app_mode == AppMode.WORKFLOW:
generator = WorkflowAppGenerator()
else:
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
def _generate_stream_events():
if include_state_snapshot:
return generator.convert_to_event_stream(
build_workflow_event_stream(
app_mode=app_mode,
workflow_run=workflow_run,
tenant_id=app_model.tenant_id,
app_id=app_model.id,
session_maker=session_maker,
)
)
return generator.convert_to_event_stream(
msg_generator.retrieve_events(app_mode, workflow_run.id),
)
event_generator = _generate_stream_events
return Response(
event_generator(),
mimetype="text/event-stream",
headers={
"Cache-Control": "no-cache",
"Connection": "keep-alive",
},
)
# Register the APIs
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")

View File

@@ -4,8 +4,8 @@ import contextvars
import logging
import threading
import uuid
from collections.abc import Generator, Mapping
from typing import TYPE_CHECKING, Any, Literal, Union, overload
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
from flask import Flask, current_app
from pydantic import ValidationError
@@ -29,21 +29,25 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import (
DraftVariableSaverFactory,
)
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.base import Base
from models.enums import WorkflowRunTriggeredFrom
from services.conversation_service import ConversationService
from services.workflow_draft_variable_service import (
@@ -65,7 +69,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
user: Union[Account, EndUser],
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[False],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@@ -74,9 +80,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: Literal[True],
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping | str, None, None]: ...
@overload
@@ -85,9 +93,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
def generate(
@@ -95,9 +105,11 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
args: Mapping,
args: Mapping[str, Any],
invoke_from: InvokeFrom,
workflow_run_id: str,
streaming: bool = True,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
"""
Generate App response.
@@ -161,7 +173,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
# always enable retriever resource in debugger mode
app_config.additional_features.show_retrieve_source = True # type: ignore
workflow_run_id = str(uuid.uuid4())
# init application generate entity
application_generate_entity = AdvancedChatAppGenerateEntity(
task_id=str(uuid.uuid4()),
@@ -179,7 +190,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
invoke_from=invoke_from,
extras=extras,
trace_manager=trace_manager,
workflow_run_id=workflow_run_id,
workflow_run_id=str(workflow_run_id),
)
contexts.plugin_tool_providers.set({})
contexts.plugin_tool_providers_lock.set(threading.Lock())
@@ -216,6 +227,38 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
stream=streaming,
pause_state_config=pause_state_config,
)
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
conversation: Conversation,
message: Message,
application_generate_entity: AdvancedChatAppGenerateEntity,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_runtime_state: GraphRuntimeState,
pause_state_config: PauseStateLayerConfig | None = None,
):
"""
Resume a paused advanced chat execution.
"""
return self._generate(
workflow=workflow,
user=user,
invoke_from=application_generate_entity.invoke_from,
application_generate_entity=application_generate_entity,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
conversation=conversation,
message=message,
stream=application_generate_entity.stream,
pause_state_config=pause_state_config,
graph_runtime_state=graph_runtime_state,
)
def single_iteration_generate(
@@ -396,8 +439,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
conversation: Conversation | None = None,
message: Message | None = None,
stream: bool = True,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
pause_state_config: PauseStateLayerConfig | None = None,
graph_runtime_state: GraphRuntimeState | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
"""
Generate App response.
@@ -411,12 +458,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
:param conversation: conversation
:param stream: is stream
"""
is_first_conversation = False
if not conversation:
is_first_conversation = True
is_first_conversation = conversation is None
# init generate records
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
if conversation is not None and message is not None:
pass
else:
conversation, message = self._init_generate_records(application_generate_entity, conversation)
if is_first_conversation:
# update conversation features
@@ -439,6 +486,16 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
message_id=message.id,
)
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@@ -454,14 +511,25 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
"variable_loader": variable_loader,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
worker_thread.start()
# release database connection, because the following new thread operations may take a long time
db.session.refresh(workflow)
db.session.refresh(message)
with Session(bind=db.engine, expire_on_commit=False) as session:
workflow = _refresh_model(session, workflow)
message = _refresh_model(session, message)
# workflow_ = session.get(Workflow, workflow.id)
# assert workflow_ is not None
# workflow = workflow_
# message_ = session.get(Message, message.id)
# assert message_ is not None
# message = message_
# db.session.refresh(workflow)
# db.session.refresh(message)
# db.session.refresh(user)
db.session.close()
@@ -490,6 +558,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
variable_loader: VariableLoader,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
"""
Generate worker in a new thread.
@@ -547,6 +617,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
app=app,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:
@@ -614,3 +686,13 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
else:
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
raise e
_T = TypeVar("_T", bound=Base)
def _refresh_model(session, model: _T) -> _T:
with Session(bind=db.engine, expire_on_commit=False) as session:
detach_model = session.get(type(model), model.id)
assert detach_model is not None
return detach_model

View File

@@ -66,6 +66,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@@ -82,6 +83,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
self._app = app
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@@ -110,7 +112,21 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
invoke_from=invoke_from,
user_from=user_from,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
# Handle single iteration or single loop run
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,

View File

@@ -24,6 +24,8 @@ from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueAnnotationReplyEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -42,6 +44,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@@ -63,6 +66,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -71,7 +76,8 @@ from core.workflow.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
logger = logging.getLogger(__name__)
@@ -128,6 +134,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
self._task_state = WorkflowTaskState()
self._seed_task_state_from_message(message)
self._message_cycle_manager = MessageCycleManager(
application_generate_entity=application_generate_entity, task_state=self._task_state
)
@@ -135,6 +142,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._application_generate_entity = application_generate_entity
self._workflow_id = workflow.id
self._workflow_features_dict = workflow.features_dict
self._workflow_tenant_id = workflow.tenant_id
self._conversation_id = conversation.id
self._conversation_mode = conversation.mode
self._message_id = message.id
@@ -144,8 +152,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
self._workflow_run_id: str = ""
self._draft_var_saver_factory = draft_var_saver_factory
self._graph_runtime_state: GraphRuntimeState | None = None
self._message_saved_on_pause = False
self._seed_graph_runtime_state_from_queue_manager()
def _seed_task_state_from_message(self, message: Message) -> None:
if message.status == MessageStatus.PAUSED and message.answer:
self._task_state.answer = message.answer
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
"""
Process generate task pipeline.
@@ -308,6 +321,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow_id,
reason=event.reason,
)
yield workflow_start_resp
@@ -525,6 +539,35 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
yield from responses
resolved_state: GraphRuntimeState | None = None
try:
resolved_state = self._ensure_graph_runtime_initialized()
except ValueError:
resolved_state = None
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
message = self._get_message(session=session)
if message is not None:
message.status = MessageStatus.PAUSED
self._message_saved_on_pause = True
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
def _handle_workflow_failed_event(
@@ -614,9 +657,10 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
)
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
# Save message unless it has already been persisted on pause.
if not self._message_saved_on_pause:
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
yield self._message_end_to_stream_response()
@@ -642,6 +686,65 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
"""Handle message replace events."""
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
self._persist_human_input_extra_content(node_id=event.node_id)
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
if not self._workflow_run_id or not self._message_id:
return
if form_id is None:
if node_id is None:
return
form_id = self._load_human_input_form_id(node_id=node_id)
if form_id is None:
logger.warning(
"HumanInput form not found for workflow run %s node %s",
self._workflow_run_id,
node_id,
)
return
with self._database_session() as session:
exists_stmt = select(HumanInputContent).where(
HumanInputContent.workflow_run_id == self._workflow_run_id,
HumanInputContent.message_id == self._message_id,
HumanInputContent.form_id == form_id,
)
if session.scalar(exists_stmt) is not None:
return
content = HumanInputContent(
workflow_run_id=self._workflow_run_id,
message_id=self._message_id,
form_id=form_id,
)
session.add(content)
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self._workflow_tenant_id,
)
form = form_repository.get_form(self._workflow_run_id, node_id)
if form is None:
return None
return form.id
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
"""Handle agent log events."""
yield self._workflow_response_converter.handle_agent_log(
@@ -659,6 +762,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
@@ -680,6 +784,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueMessageReplaceEvent: self._handle_message_replace_event,
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@@ -747,6 +853,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
@@ -772,6 +881,11 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
message = self._get_message(session=session)
if message is None:
return
if message.status == MessageStatus.PAUSED:
message.status = MessageStatus.NORMAL
# If there are assistant files, remove markdown image links from answer
answer_text = self._task_state.answer

View File

@@ -5,9 +5,14 @@ from dataclasses import dataclass
from datetime import datetime
from typing import Any, NewType, Union
from sqlalchemy import select
from sqlalchemy.orm import Session
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.queue_entities import (
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -19,9 +24,13 @@ from core.app.entities.queue_entities import (
QueueNodeRetryEvent,
QueueNodeStartedEvent,
QueueNodeSucceededEvent,
QueueWorkflowPausedEvent,
)
from core.app.entities.task_entities import (
AgentLogStreamResponse,
HumanInputFormFilledResponse,
HumanInputFormTimeoutResponse,
HumanInputRequiredResponse,
IterationNodeCompletedStreamResponse,
IterationNodeNextStreamResponse,
IterationNodeStartStreamResponse,
@@ -31,7 +40,9 @@ from core.app.entities.task_entities import (
NodeFinishStreamResponse,
NodeRetryStreamResponse,
NodeStartStreamResponse,
StreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.file import FILE_MODEL_IDENTITY, File
@@ -40,6 +51,8 @@ from core.tools.entities.tool_entities import ToolProviderType
from core.tools.tool_manager import ToolManager
from core.trigger.trigger_manager import TriggerManager
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import (
NodeType,
SystemVariableKey,
@@ -51,8 +64,11 @@ from core.workflow.runtime import GraphRuntimeState
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, EndUser
from models.human_input import HumanInputForm
from models.workflow import WorkflowRun
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
NodeExecutionId = NewType("NodeExecutionId", str)
@@ -191,6 +207,7 @@ class WorkflowResponseConverter:
task_id: str,
workflow_run_id: str,
workflow_id: str,
reason: WorkflowStartReason,
) -> WorkflowStartStreamResponse:
run_id = self._ensure_workflow_run_id(workflow_run_id)
started_at = naive_utc_now()
@@ -204,6 +221,7 @@ class WorkflowResponseConverter:
workflow_id=workflow_id,
inputs=self._workflow_inputs,
created_at=int(started_at.timestamp()),
reason=reason,
),
)
@@ -264,6 +282,160 @@ class WorkflowResponseConverter:
),
)
def workflow_pause_to_stream_response(
self,
*,
event: QueueWorkflowPausedEvent,
task_id: str,
graph_runtime_state: GraphRuntimeState,
) -> list[StreamResponse]:
run_id = self._ensure_workflow_run_id()
started_at = self._workflow_started_at
if started_at is None:
raise ValueError(
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
)
paused_at = naive_utc_now()
elapsed_time = (paused_at - started_at).total_seconds()
encoded_outputs = self._encode_outputs(event.outputs) or {}
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
encoded_outputs = {}
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
expiration_times_by_form_id: dict[str, datetime] = {}
if human_input_form_ids:
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
HumanInputForm.id.in_(human_input_form_ids)
)
with Session(bind=db.engine) as session:
for form_id, expiration_time in session.execute(stmt):
expiration_times_by_form_id[str(form_id)] = expiration_time
responses: list[StreamResponse] = []
for reason in event.reasons:
if isinstance(reason, HumanInputRequired):
expiration_time = expiration_times_by_form_id.get(reason.form_id)
if expiration_time is None:
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
responses.append(
HumanInputRequiredResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputRequiredResponse.Data(
form_id=reason.form_id,
node_id=reason.node_id,
node_title=reason.node_title,
form_content=reason.form_content,
inputs=reason.inputs,
actions=reason.actions,
display_in_ui=reason.display_in_ui,
form_token=reason.form_token,
resolved_default_values=reason.resolved_default_values,
expiration_time=int(expiration_time.timestamp()),
),
)
)
responses.append(
WorkflowPauseStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowPauseStreamResponse.Data(
workflow_run_id=run_id,
paused_nodes=list(event.paused_nodes),
outputs=encoded_outputs,
reasons=pause_reasons,
status=WorkflowExecutionStatus.PAUSED.value,
created_at=int(started_at.timestamp()),
elapsed_time=elapsed_time,
total_tokens=graph_runtime_state.total_tokens,
total_steps=graph_runtime_state.node_run_steps,
),
)
)
return responses
def human_input_form_filled_to_stream_response(
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
) -> HumanInputFormFilledResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormFilledResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormFilledResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
),
)
def human_input_form_timeout_to_stream_response(
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
) -> HumanInputFormTimeoutResponse:
run_id = self._ensure_workflow_run_id()
return HumanInputFormTimeoutResponse(
task_id=task_id,
workflow_run_id=run_id,
data=HumanInputFormTimeoutResponse.Data(
node_id=event.node_id,
node_title=event.node_title,
expiration_time=int(event.expiration_time.timestamp()),
),
)
@classmethod
def workflow_run_result_to_finish_response(
cls,
*,
task_id: str,
workflow_run: WorkflowRun,
creator_user: Account | EndUser,
) -> WorkflowFinishStreamResponse:
run_id = workflow_run.id
elapsed_time = workflow_run.elapsed_time
encoded_outputs = workflow_run.outputs_dict
finished_at = workflow_run.finished_at
assert finished_at is not None
created_by: Mapping[str, object]
user = creator_user
if isinstance(user, Account):
created_by = {
"id": user.id,
"name": user.name,
"email": user.email,
}
else:
created_by = {
"id": user.id,
"user": user.session_id,
}
return WorkflowFinishStreamResponse(
task_id=task_id,
workflow_run_id=run_id,
data=WorkflowFinishStreamResponse.Data(
id=run_id,
workflow_id=workflow_run.workflow_id,
status=workflow_run.status.value,
outputs=encoded_outputs,
error=workflow_run.error,
elapsed_time=elapsed_time,
total_tokens=workflow_run.total_tokens,
total_steps=workflow_run.total_steps,
created_by=created_by,
created_at=int(workflow_run.created_at.timestamp()),
finished_at=int(finished_at.timestamp()),
files=cls.fetch_files_from_node_outputs(encoded_outputs),
exceptions_count=workflow_run.exceptions_count,
),
)
def workflow_node_start_to_stream_response(
self,
*,
@@ -592,7 +764,8 @@ class WorkflowResponseConverter:
),
)
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
@classmethod
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
"""
Fetch files from node outputs
:param outputs_dict: node outputs dict
@@ -601,7 +774,7 @@ class WorkflowResponseConverter:
if not outputs_dict:
return []
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
# Remove None
files = [file for file in files if file]
# Flatten list

View File

@@ -1,6 +1,6 @@
import json
import logging
from collections.abc import Generator
from collections.abc import Callable, Generator, Mapping
from typing import Union, cast
from sqlalchemy import select
@@ -10,12 +10,14 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
from core.app.apps.base_app_generator import BaseAppGenerator
from core.app.apps.base_app_queue_manager import AppQueueManager
from core.app.apps.exc import GenerateTaskStoppedError
from core.app.apps.streaming_utils import stream_topic_events
from core.app.entities.app_invoke_entities import (
AdvancedChatAppGenerateEntity,
AgentChatAppGenerateEntity,
AppGenerateEntity,
ChatAppGenerateEntity,
CompletionAppGenerateEntity,
ConversationAppGenerateEntity,
InvokeFrom,
)
from core.app.entities.task_entities import (
@@ -27,6 +29,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from extensions.ext_database import db
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
@@ -156,6 +160,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
query = application_generate_entity.query or "New conversation"
conversation_name = (query[:20] + "") if len(query) > 20 else query
created_new_conversation = conversation is None
try:
if not conversation:
conversation = Conversation(
@@ -232,6 +237,10 @@ class MessageBasedAppGenerator(BaseAppGenerator):
db.session.add_all(message_files)
db.session.commit()
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
application_generate_entity.conversation_id = conversation.id
application_generate_entity.is_new_conversation = created_new_conversation
return conversation, message
except Exception:
db.session.rollback()
@@ -284,3 +293,29 @@ class MessageBasedAppGenerator(BaseAppGenerator):
raise MessageNotExistsError("Message not exists")
return message
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{workflow_run_id}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
on_subscribe=on_subscribe,
)

View File

@@ -0,0 +1,36 @@
from collections.abc import Callable, Generator, Mapping
from core.app.apps.streaming_utils import stream_topic_events
from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from models.model import AppMode
class MessageGenerator:
@staticmethod
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
return f"channel:{app_mode}:{str(workflow_run_id)}"
@classmethod
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
key = cls._make_channel_key(app_mode, workflow_run_id)
channel = get_pubsub_broadcast_channel()
topic = channel.topic(key)
return topic
@classmethod
def retrieve_events(
cls,
app_mode: AppMode,
workflow_run_id: str,
idle_timeout=300,
ping_interval: float = 10.0,
on_subscribe: Callable[[], None] | None = None,
) -> Generator[Mapping | str, None, None]:
topic = cls.get_response_topic(app_mode, workflow_run_id)
return stream_topic_events(
topic=topic,
idle_timeout=idle_timeout,
ping_interval=ping_interval,
on_subscribe=on_subscribe,
)

View File

@@ -0,0 +1,70 @@
from __future__ import annotations
import json
import time
from collections.abc import Callable, Generator, Iterable, Mapping
from typing import Any
from core.app.entities.task_entities import StreamEvent
from libs.broadcast_channel.channel import Topic
from libs.broadcast_channel.exc import SubscriptionClosedError
def stream_topic_events(
*,
topic: Topic,
idle_timeout: float,
ping_interval: float | None = None,
on_subscribe: Callable[[], None] | None = None,
terminal_events: Iterable[str | StreamEvent] | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]:
# send a PING event immediately to prevent the connection staying in pending state for a long time.
#
# This simplify the debugging process as the DevTools in Chrome does not
# provide complete curl command for pending connections.
yield StreamEvent.PING.value
terminal_values = _normalize_terminal_events(terminal_events)
last_msg_time = time.time()
last_ping_time = last_msg_time
with topic.subscribe() as sub:
# on_subscribe fires only after the Redis subscription is active.
# This is used to gate task start and reduce pub/sub race for the first event.
if on_subscribe is not None:
on_subscribe()
while True:
try:
msg = sub.receive(timeout=0.1)
except SubscriptionClosedError:
return
if msg is None:
current_time = time.time()
if current_time - last_msg_time > idle_timeout:
return
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
yield StreamEvent.PING.value
last_ping_time = current_time
continue
last_msg_time = time.time()
last_ping_time = last_msg_time
event = json.loads(msg)
yield event
if not isinstance(event, dict):
continue
event_type = event.get("event")
if event_type in terminal_values:
return
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
if not terminal_events:
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
values: set[str] = set()
for item in terminal_events:
if isinstance(item, StreamEvent):
values.add(item.value)
else:
values.add(str(item))
return values

View File

@@ -25,6 +25,7 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
from core.db.session_factory import session_factory
from core.helper.trace_id_helper import extract_external_trace_id_from_args
from core.model_runtime.errors.invoke import InvokeAuthorizationError
@@ -34,12 +35,15 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
from core.workflow.runtime import GraphRuntimeState
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
from extensions.ext_database import db
from factories import file_factory
from libs.flask_utils import preserve_flask_contexts
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
from models.account import Account
from models.enums import WorkflowRunTriggeredFrom
from models.model import App, EndUser
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
if TYPE_CHECKING:
@@ -66,9 +70,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[True],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Generator[Mapping[str, Any] | str, None, None]: ...
@overload
@@ -82,9 +88,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: Literal[False],
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Mapping[str, Any]: ...
@overload
@@ -98,9 +106,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool,
call_depth: int,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
def generate(
@@ -113,9 +123,11 @@ class WorkflowAppGenerator(BaseAppGenerator):
invoke_from: InvokeFrom,
streaming: bool = True,
call_depth: int = 0,
workflow_run_id: str | uuid.UUID | None = None,
triggered_from: WorkflowRunTriggeredFrom | None = None,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
files: Sequence[Mapping[str, Any]] = args.get("files") or []
@@ -150,7 +162,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
extras = {
**extract_external_trace_id_from_args(args),
}
workflow_run_id = str(uuid.uuid4())
workflow_run_id = str(workflow_run_id or uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs
if self._should_prepare_user_inputs(args):
@@ -216,13 +228,40 @@ class WorkflowAppGenerator(BaseAppGenerator):
streaming=streaming,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
pause_state_config=pause_state_config,
)
def resume(self, *, workflow_run_id: str) -> None:
def resume(
self,
*,
app_model: App,
workflow: Workflow,
user: Union[Account, EndUser],
application_generate_entity: WorkflowAppGenerateEntity,
graph_runtime_state: GraphRuntimeState,
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
pause_state_config: PauseStateLayerConfig | None = None,
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
@TBD
Resume a paused workflow execution using the persisted runtime state.
"""
pass
return self._generate(
app_model=app_model,
workflow=workflow,
user=user,
application_generate_entity=application_generate_entity,
invoke_from=application_generate_entity.invoke_from,
workflow_execution_repository=workflow_execution_repository,
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=application_generate_entity.stream,
variable_loader=variable_loader,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
pause_state_config=pause_state_config,
)
def _generate(
self,
@@ -238,6 +277,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
pause_state_config: PauseStateLayerConfig | None = None,
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
"""
Generate App response.
@@ -251,6 +292,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
:param workflow_node_execution_repository: repository for workflow node execution
:param streaming: is stream
"""
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
# init queue manager
queue_manager = WorkflowAppQueueManager(
task_id=application_generate_entity.task_id,
@@ -259,6 +302,15 @@ class WorkflowAppGenerator(BaseAppGenerator):
app_mode=app_model.mode,
)
if pause_state_config is not None:
graph_layers.append(
PauseStatePersistenceLayer(
session_factory=pause_state_config.session_factory,
generate_entity=application_generate_entity,
state_owner_user_id=pause_state_config.state_owner_user_id,
)
)
# new thread with request context and contextvars
context = contextvars.copy_context()
@@ -276,7 +328,8 @@ class WorkflowAppGenerator(BaseAppGenerator):
"root_node_id": root_node_id,
"workflow_execution_repository": workflow_execution_repository,
"workflow_node_execution_repository": workflow_node_execution_repository,
"graph_engine_layers": graph_engine_layers,
"graph_engine_layers": tuple(graph_layers),
"graph_runtime_state": graph_runtime_state,
},
)
@@ -378,6 +431,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def single_loop_generate(
@@ -459,6 +513,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
streaming=streaming,
variable_loader=var_loader,
pause_state_config=None,
)
def _generate_worker(
@@ -472,6 +527,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
root_node_id: str | None = None,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
) -> None:
"""
Generate worker in a new thread.
@@ -517,6 +573,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
workflow_node_execution_repository=workflow_node_execution_repository,
root_node_id=root_node_id,
graph_engine_layers=graph_engine_layers,
graph_runtime_state=graph_runtime_state,
)
try:

View File

@@ -42,6 +42,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow_execution_repository: WorkflowExecutionRepository,
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
graph_engine_layers: Sequence[GraphEngineLayer] = (),
graph_runtime_state: GraphRuntimeState | None = None,
):
super().__init__(
queue_manager=queue_manager,
@@ -55,6 +56,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
self._root_node_id = root_node_id
self._workflow_execution_repository = workflow_execution_repository
self._workflow_node_execution_repository = workflow_node_execution_repository
self._resume_graph_runtime_state = graph_runtime_state
@trace_span(WorkflowAppRunnerHandler)
def run(self):
@@ -63,23 +65,28 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
"""
app_config = self.application_generate_entity.app_config
app_config = cast(WorkflowAppConfig, app_config)
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
invoke_from = self.application_generate_entity.invoke_from
# if only single iteration or single loop run is requested
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
invoke_from = InvokeFrom.DEBUGGER
user_from = self._resolve_user_from(invoke_from)
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
resume_state = self._resume_graph_runtime_state
if resume_state is not None:
graph_runtime_state = resume_state
variable_pool = graph_runtime_state.variable_pool
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,
workflow_id=self._workflow.id,
tenant_id=self._workflow.tenant_id,
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
root_node_id=self._root_node_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
@@ -89,7 +96,14 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
inputs = self.application_generate_entity.inputs
# Create a variable pool.
system_inputs = SystemVariable(
files=self.application_generate_entity.files,
user_id=self._sys_user_id,
app_id=app_config.app_id,
timestamp=int(naive_utc_now().timestamp()),
workflow_id=app_config.workflow_id,
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
)
variable_pool = VariablePool(
system_variables=system_inputs,
user_inputs=inputs,
@@ -98,8 +112,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
)
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
# init graph
graph = self._init_graph(
graph_config=self._workflow.graph_dict,
graph_runtime_state=graph_runtime_state,

View File

@@ -0,0 +1,7 @@
from libs.exception import BaseHTTPException
class WorkflowPausedInBlockingModeError(BaseHTTPException):
error_code = "workflow_paused_in_blocking_mode"
description = "Workflow execution paused for human input; blocking response mode is not supported."
code = 400

View File

@@ -16,6 +16,8 @@ from core.app.entities.queue_entities import (
MessageQueueMessage,
QueueAgentLogEvent,
QueueErrorEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -32,6 +34,7 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
WorkflowQueueMessage,
@@ -46,11 +49,13 @@ from core.app.entities.task_entities import (
WorkflowAppBlockingResponse,
WorkflowAppStreamResponse,
WorkflowFinishStreamResponse,
WorkflowPauseStreamResponse,
WorkflowStartStreamResponse,
)
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
from core.workflow.runtime import GraphRuntimeState
@@ -132,6 +137,25 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
for stream_response in generator:
if isinstance(stream_response, ErrorStreamResponse):
raise stream_response.err
elif isinstance(stream_response, WorkflowPauseStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
workflow_run_id=stream_response.data.workflow_run_id,
data=WorkflowAppBlockingResponse.Data(
id=stream_response.data.workflow_run_id,
workflow_id=self._workflow.id,
status=stream_response.data.status,
outputs=stream_response.data.outputs or {},
error=None,
elapsed_time=stream_response.data.elapsed_time,
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=stream_response.data.created_at,
finished_at=None,
),
)
return response
elif isinstance(stream_response, WorkflowFinishStreamResponse):
response = WorkflowAppBlockingResponse(
task_id=self._application_generate_entity.task_id,
@@ -146,7 +170,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
total_tokens=stream_response.data.total_tokens,
total_steps=stream_response.data.total_steps,
created_at=int(stream_response.data.created_at),
finished_at=int(stream_response.data.finished_at),
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
),
)
@@ -259,13 +283,15 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
run_id = self._extract_workflow_run_id(runtime_state)
self._workflow_execution_id = run_id
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
if event.reason == WorkflowStartReason.INITIAL:
with self._database_session() as session:
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
task_id=self._application_generate_entity.task_id,
workflow_run_id=run_id,
workflow_id=self._workflow.id,
reason=event.reason,
)
yield start_resp
@@ -440,6 +466,21 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
yield workflow_finish_resp
def _handle_workflow_paused_event(
self,
event: QueueWorkflowPausedEvent,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle workflow paused events."""
self._ensure_workflow_initialized()
validated_state = self._ensure_graph_runtime_initialized()
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
event=event,
task_id=self._application_generate_entity.task_id,
graph_runtime_state=validated_state,
)
yield from responses
def _handle_workflow_failed_and_stop_events(
self,
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
@@ -495,6 +536,22 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
task_id=self._application_generate_entity.task_id, event=event
)
def _handle_human_input_form_filled_event(
self, event: QueueHumanInputFormFilledEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form filled events."""
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _handle_human_input_form_timeout_event(
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
) -> Generator[StreamResponse, None, None]:
"""Handle human input form timeout events."""
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
event=event, task_id=self._application_generate_entity.task_id
)
def _get_event_handlers(self) -> dict[type, Callable]:
"""Get mapping of event types to their handlers using fluent pattern."""
return {
@@ -506,6 +563,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
# Node events
QueueNodeRetryEvent: self._handle_node_retry_event,
QueueNodeStartedEvent: self._handle_node_started_event,
@@ -520,6 +578,8 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
QueueLoopCompletedEvent: self._handle_loop_completed_event,
# Agent events
QueueAgentLogEvent: self._handle_agent_log_event,
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
}
def _dispatch_event(
@@ -602,6 +662,9 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
case QueueWorkflowFailedEvent():
yield from self._handle_workflow_failed_and_stop_events(event)
break
case QueueWorkflowPausedEvent():
yield from self._handle_workflow_paused_event(event)
break
case QueueStopEvent():
yield from self._handle_workflow_failed_and_stop_events(event)

View File

@@ -1,3 +1,4 @@
import logging
import time
from collections.abc import Mapping, Sequence
from typing import Any, cast
@@ -7,6 +8,8 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.entities.queue_entities import (
AppQueueEvent,
QueueAgentLogEvent,
QueueHumanInputFormFilledEvent,
QueueHumanInputFormTimeoutEvent,
QueueIterationCompletedEvent,
QueueIterationNextEvent,
QueueIterationStartEvent,
@@ -22,22 +25,27 @@ from core.app.entities.queue_entities import (
QueueTextChunkEvent,
QueueWorkflowFailedEvent,
QueueWorkflowPartialSuccessEvent,
QueueWorkflowPausedEvent,
QueueWorkflowStartedEvent,
QueueWorkflowSucceededEvent,
)
from core.app.workflow.node_factory import DifyNodeFactory
from core.workflow.entities import GraphInitParams
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.graph import Graph
from core.workflow.graph_engine.layers.base import GraphEngineLayer
from core.workflow.graph_events import (
GraphEngineEvent,
GraphRunFailedEvent,
GraphRunPartialSucceededEvent,
GraphRunPausedEvent,
GraphRunStartedEvent,
GraphRunSucceededEvent,
NodeRunAgentLogEvent,
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@@ -61,6 +69,9 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
from core.workflow.workflow_entry import WorkflowEntry
from models.enums import UserFrom
from models.workflow import Workflow
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
logger = logging.getLogger(__name__)
class WorkflowBasedAppRunner:
@@ -327,7 +338,7 @@ class WorkflowBasedAppRunner:
:param event: event
"""
if isinstance(event, GraphRunStartedEvent):
self._publish_event(QueueWorkflowStartedEvent())
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
elif isinstance(event, GraphRunSucceededEvent):
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
elif isinstance(event, GraphRunPartialSucceededEvent):
@@ -338,6 +349,38 @@ class WorkflowBasedAppRunner:
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
elif isinstance(event, GraphRunAbortedEvent):
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
elif isinstance(event, GraphRunPausedEvent):
runtime_state = workflow_entry.graph_engine.graph_runtime_state
paused_nodes = runtime_state.get_paused_nodes()
self._enqueue_human_input_notifications(event.reasons)
self._publish_event(
QueueWorkflowPausedEvent(
reasons=event.reasons,
outputs=event.outputs,
paused_nodes=paused_nodes,
)
)
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
self._publish_event(
QueueHumanInputFormFilledEvent(
node_execution_id=event.id,
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
)
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
self._publish_event(
QueueHumanInputFormTimeoutEvent(
node_id=event.node_id,
node_type=event.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
)
elif isinstance(event, NodeRunRetryEvent):
node_run_result = event.node_run_result
inputs = node_run_result.inputs
@@ -544,5 +587,19 @@ class WorkflowBasedAppRunner:
)
)
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
for reason in reasons:
if not isinstance(reason, HumanInputRequired):
continue
if not reason.form_id:
continue
try:
dispatch_human_input_email_task.apply_async(
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
queue="mail",
)
except Exception: # pragma: no cover - defensive logging
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
def _publish_event(self, event: AppQueueEvent):
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)

View File

@@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
extras: dict[str, Any] = Field(default_factory=dict)
# tracing instance
trace_manager: Optional["TraceQueueManager"] = None
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
@@ -156,6 +156,7 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
"""
conversation_id: str | None = None
is_new_conversation: bool = False
parent_message_id: str | None = Field(
default=None,
description=(

View File

@@ -8,6 +8,8 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.nodes import NodeType
@@ -46,6 +48,9 @@ class QueueEvent(StrEnum):
PING = "ping"
STOP = "stop"
RETRY = "retry"
PAUSE = "pause"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class AppQueueEvent(BaseModel):
@@ -261,6 +266,8 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
"""QueueWorkflowStartedEvent entity."""
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
class QueueWorkflowSucceededEvent(AppQueueEvent):
@@ -484,6 +491,35 @@ class QueueStopEvent(AppQueueEvent):
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
class QueueHumanInputFormFilledEvent(AppQueueEvent):
"""
QueueHumanInputFormFilledEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
node_execution_id: str
node_id: str
node_type: NodeType
node_title: str
rendered_content: str
action_id: str
action_text: str
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
"""
QueueHumanInputFormTimeoutEvent entity
"""
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
node_id: str
node_type: NodeType
node_title: str
expiration_time: datetime
class QueueMessage(BaseModel):
"""
QueueMessage abstract entity
@@ -509,3 +545,14 @@ class WorkflowQueueMessage(QueueMessage):
"""
pass
class QueueWorkflowPausedEvent(AppQueueEvent):
"""
QueueWorkflowPausedEvent entity
"""
event: QueueEvent = QueueEvent.PAUSE
reasons: Sequence[PauseReason] = Field(default_factory=list)
outputs: Mapping[str, object] = Field(default_factory=dict)
paused_nodes: Sequence[str] = Field(default_factory=list)

View File

@@ -7,7 +7,9 @@ from pydantic import BaseModel, ConfigDict, Field
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
from core.workflow.entities import AgentNodeStrategyInit
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class AnnotationReplyAccount(BaseModel):
@@ -69,6 +71,7 @@ class StreamEvent(StrEnum):
AGENT_THOUGHT = "agent_thought"
AGENT_MESSAGE = "agent_message"
WORKFLOW_STARTED = "workflow_started"
WORKFLOW_PAUSED = "workflow_paused"
WORKFLOW_FINISHED = "workflow_finished"
NODE_STARTED = "node_started"
NODE_FINISHED = "node_finished"
@@ -82,6 +85,9 @@ class StreamEvent(StrEnum):
TEXT_CHUNK = "text_chunk"
TEXT_REPLACE = "text_replace"
AGENT_LOG = "agent_log"
HUMAN_INPUT_REQUIRED = "human_input_required"
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
class StreamResponse(BaseModel):
@@ -205,6 +211,8 @@ class WorkflowStartStreamResponse(StreamResponse):
workflow_id: str
inputs: Mapping[str, Any]
created_at: int
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
workflow_run_id: str
@@ -231,7 +239,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
total_steps: int
created_by: Mapping[str, object] = Field(default_factory=dict)
created_at: int
finished_at: int
finished_at: int | None
exceptions_count: int | None = 0
files: Sequence[Mapping[str, Any]] | None = []
@@ -240,6 +248,85 @@ class WorkflowFinishStreamResponse(StreamResponse):
data: Data
class WorkflowPauseStreamResponse(StreamResponse):
"""
WorkflowPauseStreamResponse entity
"""
class Data(BaseModel):
"""
Data entity
"""
workflow_run_id: str
paused_nodes: Sequence[str] = Field(default_factory=list)
outputs: Mapping[str, Any] = Field(default_factory=dict)
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
status: str
created_at: int
elapsed_time: float
total_tokens: int
total_steps: int
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
workflow_run_id: str
data: Data
class HumanInputRequiredResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int = Field(..., description="Unix timestamp in seconds")
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
workflow_run_id: str
data: Data
class HumanInputFormFilledResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
workflow_run_id: str
data: Data
class HumanInputFormTimeoutResponse(StreamResponse):
class Data(BaseModel):
"""
Data entity
"""
node_id: str
node_title: str
expiration_time: int
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
workflow_run_id: str
data: Data
class NodeStartStreamResponse(StreamResponse):
"""
NodeStartStreamResponse entity
@@ -726,7 +813,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
total_tokens: int
total_steps: int
created_at: int
finished_at: int
finished_at: int | None
workflow_run_id: str
data: Data

View File

@@ -1,3 +1,4 @@
import contextlib
import logging
import time
import uuid
@@ -103,6 +104,14 @@ class RateLimit:
)
@contextlib.contextmanager
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
request_id = rate_limit.enter(request_id)
yield
if request_id is not None:
rate_limit.exit(request_id)
class RateLimitGenerator:
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
self.rate_limit = rate_limit

View File

@@ -1,3 +1,4 @@
from dataclasses import dataclass
from typing import Annotated, Literal, Self, TypeAlias
from pydantic import BaseModel, Field
@@ -52,6 +53,14 @@ class WorkflowResumptionContext(BaseModel):
return self.generate_entity.entity
@dataclass(frozen=True)
class PauseStateLayerConfig:
"""Configuration container for instantiating pause persistence layers."""
session_factory: Engine | sessionmaker[Session]
state_owner_user_id: str
class PauseStatePersistenceLayer(GraphEngineLayer):
def __init__(
self,

View File

@@ -82,10 +82,11 @@ class MessageCycleManager:
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
return None
is_first_message = self._application_generate_entity.conversation_id is None
is_first_message = self._application_generate_entity.is_new_conversation
extras = self._application_generate_entity.extras
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
thread: Thread | None = None
if auto_generate_conversation_name and is_first_message:
# start generate thread
# time.sleep not block other logic
@@ -101,9 +102,10 @@ class MessageCycleManager:
thread.daemon = True
thread.start()
return thread
if is_first_message:
self._application_generate_entity.is_new_conversation = False
return None
return thread
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
with flask_app.app_context():

View File

@@ -47,6 +47,7 @@ class DifyNodeFactory(NodeFactory):
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
code_limits: CodeNodeLimits | None = None,
template_renderer: Jinja2TemplateRenderer | None = None,
template_transform_max_output_length: int | None = None,
http_request_http_client: HttpClientProtocol | None = None,
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
http_request_file_manager: FileManagerProtocol | None = None,
@@ -68,6 +69,9 @@ class DifyNodeFactory(NodeFactory):
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
self._template_transform_max_output_length = (
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
)
self._http_request_http_client = http_request_http_client or ssrf_proxy
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
self._http_request_file_manager = http_request_file_manager or file_manager
@@ -122,6 +126,7 @@ class DifyNodeFactory(NodeFactory):
graph_init_params=self.graph_init_params,
graph_runtime_state=self.graph_runtime_state,
template_renderer=self._template_renderer,
max_output_length=self._template_transform_max_output_length,
)
if node_type == NodeType.HTTP_REQUEST:

View File

@@ -0,0 +1,54 @@
from __future__ import annotations
from collections.abc import Mapping, Sequence
from typing import Any, TypeAlias
from pydantic import BaseModel, ConfigDict, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
from models.execution_extra_content import ExecutionContentType
class HumanInputFormDefinition(BaseModel):
model_config = ConfigDict(frozen=True)
form_id: str
node_id: str
node_title: str
form_content: str
inputs: Sequence[FormInput] = Field(default_factory=list)
actions: Sequence[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
form_token: str | None = None
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
expiration_time: int
class HumanInputFormSubmissionData(BaseModel):
model_config = ConfigDict(frozen=True)
node_id: str
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputContent(BaseModel):
model_config = ConfigDict(frozen=True)
workflow_run_id: str
submitted: bool
form_definition: HumanInputFormDefinition | None = None
form_submission_data: HumanInputFormSubmissionData | None = None
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
__all__ = [
"ExecutionExtraContentDomainModel",
"HumanInputContent",
"HumanInputFormDefinition",
"HumanInputFormSubmissionData",
]

View File

@@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
)
from core.model_runtime.model_providers.__base.ai_model import AIModel
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.engine import db
from models.provider import (
LoadBalancingModelConfig,
Provider,

View File

@@ -6,7 +6,8 @@ from yarl import URL
from configs import dify_config
from core.helper.download import download_with_size_limit
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
from extensions.ext_redis import redis_client
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
logger = logging.getLogger(__name__)
@@ -43,28 +44,37 @@ def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
return data.get("data", {}).get("plugins", [])
def batch_fetch_plugin_manifests_ignore_deserialization_error(
plugin_ids: list[str],
) -> Sequence[MarketplacePluginDeclaration]:
if len(plugin_ids) == 0:
return []
url = str(marketplace_api_url / "api/v1/plugins/batch")
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
response.raise_for_status()
result: list[MarketplacePluginDeclaration] = []
for plugin in response.json()["data"]["plugins"]:
try:
result.append(MarketplacePluginDeclaration.model_validate(plugin))
except Exception:
logger.exception(
"Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown")
)
return result
def record_install_plugin_event(plugin_unique_identifier: str):
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
response.raise_for_status()
def fetch_global_plugin_manifest(cache_key_prefix: str, cache_ttl: int) -> None:
"""
Fetch all plugin manifests from marketplace and cache them in Redis.
This should be called once per check cycle to populate the instance-level cache.
Args:
cache_key_prefix: Redis key prefix for caching plugin manifests
cache_ttl: Cache TTL in seconds
Raises:
httpx.HTTPError: If the HTTP request fails
Exception: If any other error occurs during fetching or caching
"""
url = str(marketplace_api_url / "api/v1/dist/plugins/manifest.json")
response = httpx.get(url, headers={"X-Dify-Version": dify_config.project.version}, timeout=30)
response.raise_for_status()
raw_json = response.json()
plugins_data = raw_json.get("plugins", [])
# Parse and cache all plugin snapshots
for plugin_data in plugins_data:
plugin_snapshot = MarketplacePluginSnapshot.model_validate(plugin_data)
redis_client.setex(
name=f"{cache_key_prefix}{plugin_snapshot.plugin_id}",
time=cache_ttl,
value=plugin_snapshot.model_dump_json(),
)

View File

@@ -15,10 +15,7 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
GenerateNameTraceInfo,
@@ -31,8 +28,8 @@ from core.ops.entities.trace_entity import (
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -469,6 +466,8 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:

View File

@@ -5,7 +5,7 @@ from urllib.parse import urlparse
from sqlalchemy import select
from extensions.ext_database import db
from models.engine import db
from models.model import Message

View File

@@ -1,3 +1,4 @@
import uuid
from collections.abc import Generator, Mapping
from typing import Union
@@ -11,6 +12,7 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
from core.app.apps.completion.app_generator import CompletionAppGenerator
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
from core.app.entities.app_invoke_entities import InvokeFrom
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
from extensions.ext_database import db
from models import Account
@@ -101,6 +103,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return AdvancedChatAppGenerator().generate(
app_model=app,
workflow=workflow,
@@ -112,7 +119,9 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
"conversation_id": conversation_id,
},
invoke_from=InvokeFrom.SERVICE_API,
workflow_run_id=str(uuid.uuid4()),
streaming=stream,
pause_state_config=pause_config,
)
elif app.mode == AppMode.AGENT_CHAT:
return AgentChatAppGenerator().generate(
@@ -159,6 +168,11 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
if not workflow:
raise ValueError("unexpected app type")
pause_config = PauseStateLayerConfig(
session_factory=db.engine,
state_owner_user_id=workflow.created_by,
)
return WorkflowAppGenerator().generate(
app_model=app,
workflow=workflow,
@@ -167,6 +181,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
invoke_from=InvokeFrom.SERVICE_API,
streaming=stream,
call_depth=1,
pause_state_config=pause_config,
)
@classmethod

View File

@@ -1,4 +1,4 @@
from pydantic import BaseModel, Field, model_validator
from pydantic import BaseModel, Field, computed_field, model_validator
from core.model_runtime.entities.provider_entities import ProviderEntity
from core.plugin.entities.endpoint import EndpointProviderDeclaration
@@ -48,3 +48,15 @@ class MarketplacePluginDeclaration(BaseModel):
if "tool" in data and not data["tool"]:
del data["tool"]
return data
class MarketplacePluginSnapshot(BaseModel):
org: str
name: str
latest_version: str
latest_package_identifier: str
latest_package_url: str
@computed_field
def plugin_id(self) -> str:
return f"{self.org}/{self.name}"

View File

@@ -1,19 +1,18 @@
"""
Repository implementations for data access.
"""Repository implementations for data access."""
This package contains concrete implementations of the repository interfaces
defined in the core.workflow.repository package.
"""
from __future__ import annotations
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
__all__ = [
"CeleryWorkflowExecutionRepository",
"CeleryWorkflowNodeExecutionRepository",
"DifyCoreRepositoryFactory",
"RepositoryImportError",
"SQLAlchemyWorkflowExecutionRepository",
"SQLAlchemyWorkflowNodeExecutionRepository",
]

View File

@@ -0,0 +1,553 @@
import dataclasses
import json
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any
from sqlalchemy import Engine, select
from sqlalchemy.orm import Session, selectinload, sessionmaker
from core.workflow.nodes.human_input.entities import (
DeliveryChannelConfig,
EmailDeliveryMethod,
EmailRecipients,
ExternalRecipient,
FormDefinition,
HumanInputNodeData,
MemberRecipient,
WebAppDeliveryMethod,
)
from core.workflow.nodes.human_input.enums import (
DeliveryMethodType,
HumanInputFormKind,
HumanInputFormStatus,
)
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
FormNotFoundError,
HumanInputFormEntity,
HumanInputFormRecipientEntity,
)
from libs.datetime_utils import naive_utc_now
from libs.uuid_utils import uuidv7
from models.account import Account, TenantAccountJoin
from models.human_input import (
BackstageRecipientPayload,
ConsoleDeliveryPayload,
ConsoleRecipientPayload,
EmailExternalRecipientPayload,
EmailMemberRecipientPayload,
HumanInputDelivery,
HumanInputForm,
HumanInputFormRecipient,
RecipientType,
StandaloneWebAppRecipientPayload,
)
@dataclasses.dataclass(frozen=True)
class _DeliveryAndRecipients:
delivery: HumanInputDelivery
recipients: Sequence[HumanInputFormRecipient]
@dataclasses.dataclass(frozen=True)
class _WorkspaceMemberInfo:
user_id: str
email: str
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
def __init__(self, recipient_model: HumanInputFormRecipient):
self._recipient_model = recipient_model
@property
def id(self) -> str:
return self._recipient_model.id
@property
def token(self) -> str:
if self._recipient_model.access_token is None:
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
return self._recipient_model.access_token
class _HumanInputFormEntityImpl(HumanInputFormEntity):
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
self._form_model = form_model
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
self._web_app_recipient = next(
(
recipient
for recipient in recipient_models
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
),
None,
)
self._console_recipient = next(
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
None,
)
self._submitted_data: Mapping[str, Any] | None = (
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
)
@property
def id(self) -> str:
return self._form_model.id
@property
def web_app_token(self):
if self._console_recipient is not None:
return self._console_recipient.access_token
if self._web_app_recipient is None:
return None
return self._web_app_recipient.access_token
@property
def recipients(self) -> list[HumanInputFormRecipientEntity]:
return list(self._recipients)
@property
def rendered_content(self) -> str:
return self._form_model.rendered_content
@property
def selected_action_id(self) -> str | None:
return self._form_model.selected_action_id
@property
def submitted_data(self) -> Mapping[str, Any] | None:
return self._submitted_data
@property
def submitted(self) -> bool:
return self._form_model.submitted_at is not None
@property
def status(self) -> HumanInputFormStatus:
return self._form_model.status
@property
def expiration_time(self) -> datetime:
return self._form_model.expiration_time
@dataclasses.dataclass(frozen=True)
class HumanInputFormRecord:
form_id: str
workflow_run_id: str | None
node_id: str
tenant_id: str
app_id: str
form_kind: HumanInputFormKind
definition: FormDefinition
rendered_content: str
created_at: datetime
expiration_time: datetime
status: HumanInputFormStatus
selected_action_id: str | None
submitted_data: Mapping[str, Any] | None
submitted_at: datetime | None
submission_user_id: str | None
submission_end_user_id: str | None
completed_by_recipient_id: str | None
recipient_id: str | None
recipient_type: RecipientType | None
access_token: str | None
@property
def submitted(self) -> bool:
return self.submitted_at is not None
@classmethod
def from_models(
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
) -> "HumanInputFormRecord":
definition_payload = json.loads(form_model.form_definition)
if "expiration_time" not in definition_payload:
definition_payload["expiration_time"] = form_model.expiration_time
return cls(
form_id=form_model.id,
workflow_run_id=form_model.workflow_run_id,
node_id=form_model.node_id,
tenant_id=form_model.tenant_id,
app_id=form_model.app_id,
form_kind=form_model.form_kind,
definition=FormDefinition.model_validate(definition_payload),
rendered_content=form_model.rendered_content,
created_at=form_model.created_at,
expiration_time=form_model.expiration_time,
status=form_model.status,
selected_action_id=form_model.selected_action_id,
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
submitted_at=form_model.submitted_at,
submission_user_id=form_model.submission_user_id,
submission_end_user_id=form_model.submission_end_user_id,
completed_by_recipient_id=form_model.completed_by_recipient_id,
recipient_id=recipient_model.id if recipient_model else None,
recipient_type=recipient_model.recipient_type if recipient_model else None,
access_token=recipient_model.access_token if recipient_model else None,
)
class _InvalidTimeoutStatusError(ValueError):
pass
class HumanInputFormRepositoryImpl:
def __init__(
self,
session_factory: sessionmaker | Engine,
tenant_id: str,
):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
self._tenant_id = tenant_id
def _delivery_method_to_model(
self,
session: Session,
form_id: str,
delivery_method: DeliveryChannelConfig,
) -> _DeliveryAndRecipients:
delivery_id = str(uuidv7())
delivery_model = HumanInputDelivery(
id=delivery_id,
form_id=form_id,
delivery_method_type=delivery_method.type,
delivery_config_id=delivery_method.id,
channel_payload=delivery_method.model_dump_json(),
)
recipients: list[HumanInputFormRecipient] = []
if isinstance(delivery_method, WebAppDeliveryMethod):
recipient_model = HumanInputFormRecipient(
form_id=form_id,
delivery_id=delivery_id,
recipient_type=RecipientType.STANDALONE_WEB_APP,
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
)
recipients.append(recipient_model)
elif isinstance(delivery_method, EmailDeliveryMethod):
email_recipients_config = delivery_method.config.recipients
recipients.extend(
self._build_email_recipients(
session=session,
form_id=form_id,
delivery_id=delivery_id,
recipients_config=email_recipients_config,
)
)
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
def _build_email_recipients(
self,
session: Session,
form_id: str,
delivery_id: str,
recipients_config: EmailRecipients,
) -> list[HumanInputFormRecipient]:
member_user_ids = [
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
]
external_emails = [
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
]
if recipients_config.whole_workspace:
members = self._query_all_workspace_members(session=session)
else:
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
return self._create_email_recipients_from_resolved(
form_id=form_id,
delivery_id=delivery_id,
members=members,
external_emails=external_emails,
)
@staticmethod
def _create_email_recipients_from_resolved(
*,
form_id: str,
delivery_id: str,
members: Sequence[_WorkspaceMemberInfo],
external_emails: Sequence[str],
) -> list[HumanInputFormRecipient]:
recipient_models: list[HumanInputFormRecipient] = []
seen_emails: set[str] = set()
for member in members:
if not member.email:
continue
if member.email in seen_emails:
continue
seen_emails.add(member.email)
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=payload,
)
)
for email in external_emails:
if not email:
continue
if email in seen_emails:
continue
seen_emails.add(email)
recipient_models.append(
HumanInputFormRecipient.new(
form_id=form_id,
delivery_id=delivery_id,
payload=EmailExternalRecipientPayload(email=email),
)
)
return recipient_models
def _query_all_workspace_members(
self,
session: Session,
) -> list[_WorkspaceMemberInfo]:
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def _query_workspace_members_by_ids(
self,
session: Session,
restrict_to_user_ids: Sequence[str],
) -> list[_WorkspaceMemberInfo]:
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
if not unique_ids:
return []
stmt = (
select(Account.id, Account.email)
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
.where(TenantAccountJoin.tenant_id == self._tenant_id)
)
stmt = stmt.where(Account.id.in_(unique_ids))
rows = session.execute(stmt).all()
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
form_config: HumanInputNodeData = params.form_config
with self._session_factory(expire_on_commit=False) as session, session.begin():
# Generate unique form ID
form_id = str(uuidv7())
start_time = naive_utc_now()
node_expiration = form_config.expiration_time(start_time)
form_definition = FormDefinition(
form_content=form_config.form_content,
inputs=form_config.inputs,
user_actions=form_config.user_actions,
rendered_content=params.rendered_content,
expiration_time=node_expiration,
default_values=dict(params.resolved_default_values),
display_in_ui=params.display_in_ui,
node_title=form_config.title,
)
form_model = HumanInputForm(
id=form_id,
tenant_id=self._tenant_id,
app_id=params.app_id,
workflow_run_id=params.workflow_execution_id,
form_kind=params.form_kind,
node_id=params.node_id,
form_definition=form_definition.model_dump_json(),
rendered_content=params.rendered_content,
expiration_time=node_expiration,
created_at=start_time,
)
session.add(form_model)
recipient_models: list[HumanInputFormRecipient] = []
for delivery in params.delivery_methods:
delivery_and_recipients = self._delivery_method_to_model(
session=session,
form_id=form_id,
delivery_method=delivery,
)
session.add(delivery_and_recipients.delivery)
session.add_all(delivery_and_recipients.recipients)
recipient_models.extend(delivery_and_recipients.recipients)
if params.console_recipient_required and not any(
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
):
console_delivery_id = str(uuidv7())
console_delivery = HumanInputDelivery(
id=console_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
console_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=console_delivery_id,
recipient_type=RecipientType.CONSOLE,
recipient_payload=ConsoleRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(console_delivery)
session.add(console_recipient)
recipient_models.append(console_recipient)
if params.backstage_recipient_required and not any(
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
):
backstage_delivery_id = str(uuidv7())
backstage_delivery = HumanInputDelivery(
id=backstage_delivery_id,
form_id=form_id,
delivery_method_type=DeliveryMethodType.WEBAPP,
delivery_config_id=None,
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
)
backstage_recipient = HumanInputFormRecipient(
form_id=form_id,
delivery_id=backstage_delivery_id,
recipient_type=RecipientType.BACKSTAGE,
recipient_payload=BackstageRecipientPayload(
account_id=params.console_creator_account_id,
).model_dump_json(),
)
session.add(backstage_delivery)
session.add(backstage_recipient)
recipient_models.append(backstage_recipient)
session.flush()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
form_query = select(HumanInputForm).where(
HumanInputForm.workflow_run_id == workflow_execution_id,
HumanInputForm.node_id == node_id,
HumanInputForm.tenant_id == self._tenant_id,
)
with self._session_factory(expire_on_commit=False) as session:
form_model: HumanInputForm | None = session.scalars(form_query).first()
if form_model is None:
return None
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
recipient_models = session.scalars(recipient_query).all()
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
class HumanInputFormSubmissionRepository:
"""Repository for fetching and submitting human input forms."""
def __init__(self, session_factory: sessionmaker | Engine):
if isinstance(session_factory, Engine):
session_factory = sessionmaker(bind=session_factory)
self._session_factory = session_factory
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(HumanInputFormRecipient.access_token == form_token)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def get_by_form_id_and_recipient_type(
self,
form_id: str,
recipient_type: RecipientType,
) -> HumanInputFormRecord | None:
query = (
select(HumanInputFormRecipient)
.options(selectinload(HumanInputFormRecipient.form))
.where(
HumanInputFormRecipient.form_id == form_id,
HumanInputFormRecipient.recipient_type == recipient_type,
)
)
with self._session_factory(expire_on_commit=False) as session:
recipient_model = session.scalars(query).first()
if recipient_model is None or recipient_model.form is None:
return None
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
def mark_submitted(
self,
*,
form_id: str,
recipient_id: str | None,
selected_action_id: str,
form_data: Mapping[str, Any],
submission_user_id: str | None,
submission_end_user_id: str | None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
form_model.selected_action_id = selected_action_id
form_model.submitted_data = json.dumps(form_data)
form_model.submitted_at = naive_utc_now()
form_model.status = HumanInputFormStatus.SUBMITTED
form_model.submission_user_id = submission_user_id
form_model.submission_end_user_id = submission_end_user_id
form_model.completed_by_recipient_id = recipient_id
session.add(form_model)
session.flush()
session.refresh(form_model)
if recipient_model is not None:
session.refresh(recipient_model)
return HumanInputFormRecord.from_models(form_model, recipient_model)
def mark_timeout(
self,
*,
form_id: str,
timeout_status: HumanInputFormStatus,
reason: str | None = None,
) -> HumanInputFormRecord:
with self._session_factory(expire_on_commit=False) as session, session.begin():
form_model = session.get(HumanInputForm, form_id)
if form_model is None:
raise FormNotFoundError(f"form not found, id={form_id}")
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
# already handled or submitted
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
return HumanInputFormRecord.from_models(form_model, None)
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
raise FormNotFoundError(f"form already submitted, id={form_id}")
form_model.status = timeout_status
form_model.selected_action_id = None
form_model.submitted_data = None
form_model.submission_user_id = None
form_model.submission_end_user_id = None
form_model.completed_by_recipient_id = None
# Reason is recorded in status/error downstream; not stored on form.
session.add(form_model)
session.flush()
session.refresh(form_model)
return HumanInputFormRecord.from_models(form_model, None)

View File

@@ -488,6 +488,7 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
WorkflowNodeExecutionModel.triggered_from == triggered_from,
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
)
if self._app_id:

View File

@@ -1,4 +1,5 @@
from core.tools.entities.tool_entities import ToolInvokeMeta
from libs.exception import BaseHTTPException
class ToolProviderNotFoundError(ValueError):
@@ -37,6 +38,12 @@ class ToolCredentialPolicyViolationError(ValueError):
pass
class WorkflowToolHumanInputNotSupportedError(BaseHTTPException):
error_code = "workflow_tool_human_input_not_supported"
description = "Workflow with Human Input nodes cannot be published as a workflow tool."
code = 400
class ToolEngineInvokeError(Exception):
meta: ToolInvokeMeta

View File

@@ -3,6 +3,8 @@ from typing import Any
from core.app.app_config.entities import VariableEntity
from core.tools.entities.tool_entities import WorkflowToolParameterConfiguration
from core.tools.errors import WorkflowToolHumanInputNotSupportedError
from core.workflow.enums import NodeType
from core.workflow.nodes.base.entities import OutputVariableEntity
@@ -45,6 +47,13 @@ class WorkflowToolConfigurationUtils:
return [outputs_by_variable[variable] for variable in variable_order]
@classmethod
def ensure_no_human_input_nodes(cls, graph: Mapping[str, Any]) -> None:
nodes = graph.get("nodes", [])
for node in nodes:
if node.get("data", {}).get("type") == NodeType.HUMAN_INPUT:
raise WorkflowToolHumanInputNotSupportedError()
@classmethod
def check_is_synced(
cls, variables: list[VariableEntity], tool_configurations: list[WorkflowToolParameterConfiguration]

View File

@@ -98,6 +98,10 @@ class WorkflowTool(Tool):
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,
# NOTE(QuantumGhost): We explicitly set `pause_state_config` to `None`
# because workflow pausing mechanisms (such as HumanInput) are not
# supported within WorkflowTool execution context.
pause_state_config=None,
)
assert isinstance(result, dict)
data = result.get("data", {})

View File

@@ -112,7 +112,7 @@ class ArrayBooleanVariable(ArrayBooleanSegment, ArrayVariable):
class RAGPipelineVariable(BaseModel):
belong_to_node_id: str = Field(description="belong to which node id, shared means public")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
type: str = Field(description="variable type, text-input, paragraph, select, number, file, file-list")
label: str = Field(description="label")
description: str | None = Field(description="description", default="")
variable: str = Field(description="variable key", default="")

View File

@@ -2,10 +2,12 @@ from .agent import AgentNodeStrategyInit
from .graph_init_params import GraphInitParams
from .workflow_execution import WorkflowExecution
from .workflow_node_execution import WorkflowNodeExecution
from .workflow_start_reason import WorkflowStartReason
__all__ = [
"AgentNodeStrategyInit",
"GraphInitParams",
"WorkflowExecution",
"WorkflowNodeExecution",
"WorkflowStartReason",
]

View File

@@ -5,6 +5,16 @@ from pydantic import BaseModel, Field
class GraphInitParams(BaseModel):
"""GraphInitParams encapsulates the configurations and contextual information
that remain constant throughout a single execution of the graph engine.
A single execution is defined as follows: as long as the execution has not reached
its conclusion, it is considered one execution. For instance, if a workflow is suspended
and later resumed, it is still regarded as a single execution, not two.
For the state diagram of workflow execution, refer to `WorkflowExecutionStatus`.
"""
# init params
tenant_id: str = Field(..., description="tenant / workspace id")
app_id: str = Field(..., description="app id")

View File

@@ -1,8 +1,11 @@
from collections.abc import Mapping
from enum import StrEnum, auto
from typing import Annotated, Literal, TypeAlias
from typing import Annotated, Any, Literal, TypeAlias
from pydantic import BaseModel, Field
from core.workflow.nodes.human_input.entities import FormInput, UserAction
class PauseReasonType(StrEnum):
HUMAN_INPUT_REQUIRED = auto()
@@ -11,10 +14,31 @@ class PauseReasonType(StrEnum):
class HumanInputRequired(BaseModel):
TYPE: Literal[PauseReasonType.HUMAN_INPUT_REQUIRED] = PauseReasonType.HUMAN_INPUT_REQUIRED
form_id: str
# The identifier of the human input node causing the pause.
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
actions: list[UserAction] = Field(default_factory=list)
display_in_ui: bool = False
node_id: str
node_title: str
# The `resolved_default_values` stores the resolved values of variable defaults. It's a mapping from
# `output_variable_name` to their resolved values.
#
# For example, The form contains a input with output variable name `name` and placeholder type `VARIABLE`, its
# selector is ["start", "name"]. While the HumanInputNode is executed, the correspond value of variable
# `start.name` in variable pool is `John`. Thus, the resolved value of the output variable `name` is `John`. The
# `resolved_default_values` is `{"name": "John"}`.
#
# Only form inputs with default value type `VARIABLE` will be resolved and stored in `resolved_default_values`.
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
# The `form_token` is the token used to submit the form via UI surfaces. It corresponds to
# `HumanInputFormRecipient.access_token`.
#
# This field is `None` if webapp delivery is not set and not
# in orchestrating mode.
form_token: str | None = None
class SchedulingPause(BaseModel):

View File

@@ -0,0 +1,8 @@
from enum import StrEnum
class WorkflowStartReason(StrEnum):
"""Reason for workflow start events across graph/queue/SSE layers."""
INITIAL = "initial" # First start of a workflow run.
RESUMPTION = "resumption" # Start triggered after resuming a paused run.

View File

@@ -0,0 +1,15 @@
import time
def get_timestamp() -> float:
"""Retrieve a timestamp as a float point numer representing the number of seconds
since the Unix epoch.
This function is primarily used to measure the execution time of the workflow engine.
Since workflow execution may be paused and resumed on a different machine,
`time.perf_counter` cannot be used as it is inconsistent across machines.
To address this, the function uses the wall clock as the time source.
However, it assumes that the clocks of all servers are properly synchronized.
"""
return round(time.time())

View File

@@ -2,12 +2,14 @@
GraphEngine configuration models.
"""
from pydantic import BaseModel
from pydantic import BaseModel, ConfigDict
class GraphEngineConfig(BaseModel):
"""Configuration for GraphEngine worker pool scaling."""
model_config = ConfigDict(frozen=True)
min_workers: int = 1
max_workers: int = 5
scale_up_threshold: int = 3

View File

@@ -192,9 +192,13 @@ class EventHandler:
self._event_collector.collect(edge_event)
# Enqueue ready nodes
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
if self._graph_execution.is_paused:
for node_id in ready_nodes:
self._graph_runtime_state.register_deferred_node(node_id)
else:
for node_id in ready_nodes:
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
# Update execution tracking
self._state_manager.finish_execution(event.node_id)

View File

@@ -14,6 +14,7 @@ from collections.abc import Generator
from typing import TYPE_CHECKING, cast, final
from core.workflow.context import capture_current_context
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.enums import NodeExecutionType
from core.workflow.graph import Graph
from core.workflow.graph_events import (
@@ -55,6 +56,9 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
_DEFAULT_CONFIG = GraphEngineConfig()
@final
class GraphEngine:
"""
@@ -70,7 +74,7 @@ class GraphEngine:
graph: Graph,
graph_runtime_state: GraphRuntimeState,
command_channel: CommandChannel,
config: GraphEngineConfig,
config: GraphEngineConfig = _DEFAULT_CONFIG,
) -> None:
"""Initialize the graph engine with all subsystems and dependencies."""
# stop event
@@ -234,7 +238,9 @@ class GraphEngine:
self._graph_execution.paused = False
self._graph_execution.pause_reasons = []
start_event = GraphRunStartedEvent()
start_event = GraphRunStartedEvent(
reason=WorkflowStartReason.RESUMPTION if is_resume else WorkflowStartReason.INITIAL,
)
self._event_manager.notify_layers(start_event)
yield start_event
@@ -303,15 +309,17 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_start()
except Exception as e:
logger.warning("Layer %s failed on_graph_start: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_start", layer.__class__.__name__)
def _start_execution(self, *, resume: bool = False) -> None:
"""Start execution subsystems."""
self._stop_event.clear()
paused_nodes: list[str] = []
deferred_nodes: list[str] = []
if resume:
paused_nodes = self._graph_runtime_state.consume_paused_nodes()
deferred_nodes = self._graph_runtime_state.consume_deferred_nodes()
# Start worker pool (it calculates initial workers internally)
self._worker_pool.start()
@@ -327,7 +335,11 @@ class GraphEngine:
self._state_manager.enqueue_node(root_node.id)
self._state_manager.start_execution(root_node.id)
else:
for node_id in paused_nodes:
seen_nodes: set[str] = set()
for node_id in paused_nodes + deferred_nodes:
if node_id in seen_nodes:
continue
seen_nodes.add(node_id)
self._state_manager.enqueue_node(node_id)
self._state_manager.start_execution(node_id)
@@ -345,8 +357,8 @@ class GraphEngine:
for layer in self._layers:
try:
layer.on_graph_end(self._graph_execution.error)
except Exception as e:
logger.warning("Layer %s failed on_graph_end: %s", layer.__class__.__name__, e)
except Exception:
logger.exception("Layer %s failed on_graph_end", layer.__class__.__name__)
# Public property accessors for attributes that need external access
@property

View File

@@ -224,6 +224,8 @@ class GraphStateManager:
Returns:
Number of executing nodes
"""
# This count is a best-effort snapshot and can change concurrently.
# Only use it for pause-drain checks where scheduling is already frozen.
with self._lock:
return len(self._executing_nodes)

View File

@@ -83,12 +83,12 @@ class Dispatcher:
"""Main dispatcher loop."""
try:
self._process_commands()
paused = False
while not self._stop_event.is_set():
if (
self._execution_coordinator.aborted
or self._execution_coordinator.paused
or self._execution_coordinator.execution_complete
):
if self._execution_coordinator.aborted or self._execution_coordinator.execution_complete:
break
if self._execution_coordinator.paused:
paused = True
break
self._execution_coordinator.check_scaling()
@@ -101,13 +101,10 @@ class Dispatcher:
time.sleep(0.1)
self._process_commands()
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
if paused:
self._drain_events_until_idle()
else:
self._drain_event_queue()
except Exception as e:
logger.exception("Dispatcher error")
@@ -122,3 +119,24 @@ class Dispatcher:
def _process_commands(self, event: GraphNodeEventBase | None = None):
if event is None or isinstance(event, self._COMMAND_TRIGGER_EVENTS):
self._execution_coordinator.process_commands()
def _drain_event_queue(self) -> None:
while True:
try:
event = self._event_queue.get(block=False)
self._event_handler.dispatch(event)
self._event_queue.task_done()
except queue.Empty:
break
def _drain_events_until_idle(self) -> None:
while not self._stop_event.is_set():
try:
event = self._event_queue.get(timeout=0.1)
self._event_handler.dispatch(event)
self._event_queue.task_done()
self._process_commands(event)
except queue.Empty:
if not self._execution_coordinator.has_executing_nodes():
break
self._drain_event_queue()

View File

@@ -94,3 +94,11 @@ class ExecutionCoordinator:
self._worker_pool.stop()
self._state_manager.clear_executing()
def has_executing_nodes(self) -> bool:
"""Return True if any nodes are currently marked as executing."""
# This check is only safe once execution has already paused.
# Before pause, executing state can change concurrently, which makes the result unreliable.
if not self._graph_execution.is_paused:
raise AssertionError("has_executing_nodes should only be called after execution is paused")
return self._state_manager.get_executing_count() > 0

View File

@@ -38,6 +38,8 @@ from .loop import (
from .node import (
NodeRunExceptionEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunPauseRequestedEvent,
NodeRunRetrieverResourceEvent,
NodeRunRetryEvent,
@@ -60,6 +62,8 @@ __all__ = [
"NodeRunAgentLogEvent",
"NodeRunExceptionEvent",
"NodeRunFailedEvent",
"NodeRunHumanInputFormFilledEvent",
"NodeRunHumanInputFormTimeoutEvent",
"NodeRunIterationFailedEvent",
"NodeRunIterationNextEvent",
"NodeRunIterationStartedEvent",

View File

@@ -1,11 +1,16 @@
from pydantic import Field
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
from core.workflow.graph_events import BaseGraphEvent
class GraphRunStartedEvent(BaseGraphEvent):
pass
# Reason is emitted for workflow start events and is always set.
reason: WorkflowStartReason = Field(
default=WorkflowStartReason.INITIAL,
description="reason for workflow start",
)
class GraphRunSucceededEvent(BaseGraphEvent):

View File

@@ -54,6 +54,22 @@ class NodeRunRetryEvent(NodeRunStartedEvent):
retry_index: int = Field(..., description="which retry attempt is about to be performed")
class NodeRunHumanInputFormFilledEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form is submitted and before the node finishes."""
node_title: str = Field(..., description="HumanInput node title")
rendered_content: str = Field(..., description="Markdown content rendered with user inputs.")
action_id: str = Field(..., description="User action identifier chosen in the form.")
action_text: str = Field(..., description="Display text of the chosen action button.")
class NodeRunHumanInputFormTimeoutEvent(GraphNodeEventBase):
"""Emitted when a HumanInput form times out."""
node_title: str = Field(..., description="HumanInput node title")
expiration_time: datetime = Field(..., description="Form expiration time")
class NodeRunPauseRequestedEvent(GraphNodeEventBase):
reason: PauseReason = Field(..., description="pause reason")

View File

@@ -13,6 +13,8 @@ from .loop import (
LoopSucceededEvent,
)
from .node import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
ModelInvokeCompletedEvent,
PauseRequestedEvent,
RunRetrieverResourceEvent,
@@ -23,6 +25,8 @@ from .node import (
__all__ = [
"AgentLogEvent",
"HumanInputFormFilledEvent",
"HumanInputFormTimeoutEvent",
"IterationFailedEvent",
"IterationNextEvent",
"IterationStartedEvent",

View File

@@ -47,3 +47,19 @@ class StreamCompletedEvent(NodeEventBase):
class PauseRequestedEvent(NodeEventBase):
reason: PauseReason = Field(..., description="pause reason")
class HumanInputFormFilledEvent(NodeEventBase):
"""Event emitted when a human input form is submitted."""
node_title: str
rendered_content: str
action_id: str
action_text: str
class HumanInputFormTimeoutEvent(NodeEventBase):
"""Event emitted when a human input form times out."""
node_title: str
expiration_time: datetime

View File

@@ -2,7 +2,7 @@ from __future__ import annotations
import json
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Union, cast
from packaging.version import Version
from pydantic import ValidationError
@@ -11,6 +11,7 @@ from sqlalchemy.orm import Session
from core.agent.entities import AgentToolEntity
from core.agent.plugin_entities import AgentStrategyParameter
from core.db.session_factory import session_factory
from core.file import File, FileTransferMethod
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelInstance, ModelManager
@@ -49,6 +50,12 @@ from factories import file_factory
from factories.agent_factory import get_plugin_agent_strategy
from models import ToolFile
from models.model import Conversation
from models.tools import (
ApiToolProvider,
BuiltinToolProvider,
MCPToolProvider,
WorkflowToolProvider,
)
from services.tools.builtin_tools_manage_service import BuiltinToolManageService
from .exc import (
@@ -259,7 +266,7 @@ class AgentNode(Node[AgentNodeData]):
value = cast(list[dict[str, Any]], value)
tool_value = []
for tool in value:
provider_type = ToolProviderType(tool.get("type", ToolProviderType.BUILT_IN))
provider_type = self._infer_tool_provider_type(tool, self.tenant_id)
setting_params = tool.get("settings", {})
parameters = tool.get("parameters", {})
manual_input_params = [key for key, value in parameters.items() if value is not None]
@@ -748,3 +755,34 @@ class AgentNode(Node[AgentNodeData]):
llm_usage=llm_usage,
)
)
@staticmethod
def _infer_tool_provider_type(tool_config: dict[str, Any], tenant_id: str) -> ToolProviderType:
provider_type_str = tool_config.get("type")
if provider_type_str:
return ToolProviderType(provider_type_str)
provider_id = tool_config.get("provider_name")
if not provider_id:
return ToolProviderType.BUILT_IN
with session_factory.create_session() as session:
provider_map: dict[
type[Union[WorkflowToolProvider, MCPToolProvider, ApiToolProvider, BuiltinToolProvider]],
ToolProviderType,
] = {
WorkflowToolProvider: ToolProviderType.WORKFLOW,
MCPToolProvider: ToolProviderType.MCP,
ApiToolProvider: ToolProviderType.API,
BuiltinToolProvider: ToolProviderType.BUILT_IN,
}
for provider_model, provider_type in provider_map.items():
stmt = select(provider_model).where(
provider_model.id == provider_id,
provider_model.tenant_id == tenant_id,
)
if session.scalar(stmt):
return provider_type
raise AgentNodeError(f"Tool provider with ID '{provider_id}' not found.")

View File

@@ -18,6 +18,8 @@ from core.workflow.graph_events import (
GraphNodeEventBase,
NodeRunAgentLogEvent,
NodeRunFailedEvent,
NodeRunHumanInputFormFilledEvent,
NodeRunHumanInputFormTimeoutEvent,
NodeRunIterationFailedEvent,
NodeRunIterationNextEvent,
NodeRunIterationStartedEvent,
@@ -34,6 +36,8 @@ from core.workflow.graph_events import (
)
from core.workflow.node_events import (
AgentLogEvent,
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
IterationFailedEvent,
IterationNextEvent,
IterationStartedEvent,
@@ -61,6 +65,15 @@ logger = logging.getLogger(__name__)
class Node(Generic[NodeDataT]):
"""BaseNode serves as the foundational class for all node implementations.
Nodes are allowed to maintain transient states (e.g., `LLMNode` uses the `_file_output`
attribute to track files generated by the LLM). However, these states are not persisted
when the workflow is suspended or resumed. If a node needs its state to be preserved
across workflow suspension and resumption, it should include the relevant state data
in its output.
"""
node_type: ClassVar[NodeType]
execution_type: NodeExecutionType = NodeExecutionType.EXECUTABLE
_node_data_type: ClassVar[type[BaseNodeData]] = BaseNodeData
@@ -251,10 +264,33 @@ class Node(Generic[NodeDataT]):
return self._node_execution_id
def ensure_execution_id(self) -> str:
if not self._node_execution_id:
self._node_execution_id = str(uuid4())
if self._node_execution_id:
return self._node_execution_id
resumed_execution_id = self._restore_execution_id_from_runtime_state()
if resumed_execution_id:
self._node_execution_id = resumed_execution_id
return self._node_execution_id
self._node_execution_id = str(uuid4())
return self._node_execution_id
def _restore_execution_id_from_runtime_state(self) -> str | None:
graph_execution = self.graph_runtime_state.graph_execution
try:
node_executions = graph_execution.node_executions
except AttributeError:
return None
if not isinstance(node_executions, dict):
return None
node_execution = node_executions.get(self._node_id)
if node_execution is None:
return None
execution_id = node_execution.execution_id
if not execution_id:
return None
return str(execution_id)
def _hydrate_node_data(self, data: Mapping[str, Any]) -> NodeDataT:
return cast(NodeDataT, self._node_data_type.model_validate(data))
@@ -620,6 +656,28 @@ class Node(Generic[NodeDataT]):
metadata=event.metadata,
)
@_dispatch.register
def _(self, event: HumanInputFormFilledEvent):
return NodeRunHumanInputFormFilledEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
rendered_content=event.rendered_content,
action_id=event.action_id,
action_text=event.action_text,
)
@_dispatch.register
def _(self, event: HumanInputFormTimeoutEvent):
return NodeRunHumanInputFormTimeoutEvent(
id=self.execution_id,
node_id=self._node_id,
node_type=self.node_type,
node_title=event.node_title,
expiration_time=event.expiration_time,
)
@_dispatch.register
def _(self, event: LoopStartedEvent) -> NodeRunLoopStartedEvent:
return NodeRunLoopStartedEvent(

View File

@@ -1,3 +1,3 @@
from .human_input_node import HumanInputNode
__all__ = ["HumanInputNode"]
"""
Human Input node implementation.
"""

View File

@@ -1,10 +1,350 @@
from pydantic import Field
"""
Human Input node entities.
"""
import re
import uuid
from collections.abc import Mapping, Sequence
from datetime import datetime, timedelta
from typing import Annotated, Any, ClassVar, Literal, Self
from pydantic import BaseModel, Field, field_validator, model_validator
from core.variables.consts import SELECTORS_LENGTH
from core.workflow.nodes.base import BaseNodeData
from core.workflow.nodes.base.variable_template_parser import VariableTemplateParser
from core.workflow.runtime import VariablePool
from .enums import ButtonStyle, DeliveryMethodType, EmailRecipientType, FormInputType, PlaceholderType, TimeoutUnit
_OUTPUT_VARIABLE_PATTERN = re.compile(r"\{\{#\$output\.(?P<field_name>[a-zA-Z_][a-zA-Z0-9_]{0,29})#\}\}")
class _WebAppDeliveryConfig(BaseModel):
"""Configuration for webapp delivery method."""
pass # Empty for webapp delivery
class MemberRecipient(BaseModel):
"""Member recipient for email delivery."""
type: Literal[EmailRecipientType.MEMBER] = EmailRecipientType.MEMBER
user_id: str
class ExternalRecipient(BaseModel):
"""External recipient for email delivery."""
type: Literal[EmailRecipientType.EXTERNAL] = EmailRecipientType.EXTERNAL
email: str
EmailRecipient = Annotated[MemberRecipient | ExternalRecipient, Field(discriminator="type")]
class EmailRecipients(BaseModel):
"""Email recipients configuration."""
# When true, recipients are the union of all workspace members and external items.
# Member items are ignored because they are already covered by the workspace scope.
# De-duplication is applied by email, with member recipients taking precedence.
whole_workspace: bool = False
items: list[EmailRecipient] = Field(default_factory=list)
class EmailDeliveryConfig(BaseModel):
"""Configuration for email delivery method."""
URL_PLACEHOLDER: ClassVar[str] = "{{#url#}}"
recipients: EmailRecipients
# the subject of email
subject: str
# Body is the content of email.It may contain the speical placeholder `{{#url#}}`, which
# represent the url to submit the form.
#
# It may also reference the output variable of the previous node with the syntax
# `{{#<node_id>.<field_name>#}}`.
body: str
debug_mode: bool = False
def with_debug_recipient(self, user_id: str) -> "EmailDeliveryConfig":
if not user_id:
debug_recipients = EmailRecipients(whole_workspace=False, items=[])
return self.model_copy(update={"recipients": debug_recipients})
debug_recipients = EmailRecipients(whole_workspace=False, items=[MemberRecipient(user_id=user_id)])
return self.model_copy(update={"recipients": debug_recipients})
@classmethod
def replace_url_placeholder(cls, body: str, url: str | None) -> str:
"""Replace the url placeholder with provided value."""
return body.replace(cls.URL_PLACEHOLDER, url or "")
@classmethod
def render_body_template(
cls,
*,
body: str,
url: str | None,
variable_pool: VariablePool | None = None,
) -> str:
"""Render email body by replacing placeholders with runtime values."""
templated_body = cls.replace_url_placeholder(body, url)
if variable_pool is None:
return templated_body
return variable_pool.convert_template(templated_body).text
class _DeliveryMethodBase(BaseModel):
"""Base delivery method configuration."""
enabled: bool = True
id: uuid.UUID = Field(default_factory=uuid.uuid4)
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
return ()
class WebAppDeliveryMethod(_DeliveryMethodBase):
"""Webapp delivery method configuration."""
type: Literal[DeliveryMethodType.WEBAPP] = DeliveryMethodType.WEBAPP
# The config field is not used currently.
config: _WebAppDeliveryConfig = Field(default_factory=_WebAppDeliveryConfig)
class EmailDeliveryMethod(_DeliveryMethodBase):
"""Email delivery method configuration."""
type: Literal[DeliveryMethodType.EMAIL] = DeliveryMethodType.EMAIL
config: EmailDeliveryConfig
def extract_variable_selectors(self) -> Sequence[Sequence[str]]:
variable_template_parser = VariableTemplateParser(template=self.config.body)
selectors: list[Sequence[str]] = []
for variable_selector in variable_template_parser.extract_variable_selectors():
value_selector = list(variable_selector.value_selector)
if len(value_selector) < SELECTORS_LENGTH:
continue
selectors.append(value_selector[:SELECTORS_LENGTH])
return selectors
DeliveryChannelConfig = Annotated[WebAppDeliveryMethod | EmailDeliveryMethod, Field(discriminator="type")]
def apply_debug_email_recipient(
method: DeliveryChannelConfig,
*,
enabled: bool,
user_id: str,
) -> DeliveryChannelConfig:
if not enabled:
return method
if not isinstance(method, EmailDeliveryMethod):
return method
if not method.config.debug_mode:
return method
debug_config = method.config.with_debug_recipient(user_id or "")
return method.model_copy(update={"config": debug_config})
class FormInputDefault(BaseModel):
"""Default configuration for form inputs."""
# NOTE: Ideally, a discriminated union would be used to model
# FormInputDefault. However, the UI requires preserving the previous
# value when switching between `VARIABLE` and `CONSTANT` types. This
# necessitates retaining all fields, making a discriminated union unsuitable.
type: PlaceholderType
# The selector of default variable, used when `type` is `VARIABLE`.
selector: Sequence[str] = Field(default_factory=tuple) #
# The value of the default, used when `type` is `CONSTANT`.
# TODO: How should we express JSON values?
value: str = ""
@model_validator(mode="after")
def _validate_selector(self) -> Self:
if self.type == PlaceholderType.CONSTANT:
return self
if len(self.selector) < SELECTORS_LENGTH:
raise ValueError(f"the length of selector should be at least {SELECTORS_LENGTH}, selector={self.selector}")
return self
class FormInput(BaseModel):
"""Form input definition."""
type: FormInputType
output_variable_name: str
default: FormInputDefault | None = None
_IDENTIFIER_PATTERN = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
class UserAction(BaseModel):
"""User action configuration."""
# id is the identifier for this action.
# It also serves as the identifiers of output handle.
#
# The id must be a valid identifier (satisfy the _IDENTIFIER_PATTERN above.)
id: str = Field(max_length=20)
title: str = Field(max_length=20)
button_style: ButtonStyle = ButtonStyle.DEFAULT
@field_validator("id")
@classmethod
def _validate_id(cls, value: str) -> str:
if not _IDENTIFIER_PATTERN.match(value):
raise ValueError(
f"'{value}' is not a valid identifier. It must start with a letter or underscore, "
f"and contain only letters, numbers, or underscores."
)
return value
class HumanInputNodeData(BaseNodeData):
"""Configuration schema for the HumanInput node."""
"""Human Input node data."""
required_variables: list[str] = Field(default_factory=list)
pause_reason: str | None = Field(default=None)
delivery_methods: list[DeliveryChannelConfig] = Field(default_factory=list)
form_content: str = ""
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
timeout: int = 36
timeout_unit: TimeoutUnit = TimeoutUnit.HOUR
@field_validator("inputs")
@classmethod
def _validate_inputs(cls, inputs: list[FormInput]) -> list[FormInput]:
seen_names: set[str] = set()
for form_input in inputs:
name = form_input.output_variable_name
if name in seen_names:
raise ValueError(f"duplicated output_variable_name '{name}' in inputs")
seen_names.add(name)
return inputs
@field_validator("user_actions")
@classmethod
def _validate_user_actions(cls, user_actions: list[UserAction]) -> list[UserAction]:
seen_ids: set[str] = set()
for action in user_actions:
action_id = action.id
if action_id in seen_ids:
raise ValueError(f"duplicated user action id '{action_id}'")
seen_ids.add(action_id)
return user_actions
def is_webapp_enabled(self) -> bool:
for dm in self.delivery_methods:
if not dm.enabled:
continue
if dm.type == DeliveryMethodType.WEBAPP:
return True
return False
def expiration_time(self, start_time: datetime) -> datetime:
if self.timeout_unit == TimeoutUnit.HOUR:
return start_time + timedelta(hours=self.timeout)
elif self.timeout_unit == TimeoutUnit.DAY:
return start_time + timedelta(days=self.timeout)
else:
raise AssertionError("unknown timeout unit.")
def outputs_field_names(self) -> Sequence[str]:
field_names = []
for match in _OUTPUT_VARIABLE_PATTERN.finditer(self.form_content):
field_names.append(match.group("field_name"))
return field_names
def extract_variable_selector_to_variable_mapping(self, node_id: str) -> Mapping[str, Sequence[str]]:
variable_mappings: dict[str, Sequence[str]] = {}
def _add_variable_selectors(selectors: Sequence[Sequence[str]]) -> None:
for selector in selectors:
if len(selector) < SELECTORS_LENGTH:
continue
qualified_variable_mapping_key = f"{node_id}.#{'.'.join(selector[:SELECTORS_LENGTH])}#"
variable_mappings[qualified_variable_mapping_key] = list(selector[:SELECTORS_LENGTH])
form_template_parser = VariableTemplateParser(template=self.form_content)
_add_variable_selectors(
[selector.value_selector for selector in form_template_parser.extract_variable_selectors()]
)
for delivery_method in self.delivery_methods:
if not delivery_method.enabled:
continue
_add_variable_selectors(delivery_method.extract_variable_selectors())
for input in self.inputs:
default_value = input.default
if default_value is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
default_value_key = ".".join(default_value.selector)
qualified_variable_mapping_key = f"{node_id}.#{default_value_key}#"
variable_mappings[qualified_variable_mapping_key] = default_value.selector
return variable_mappings
def find_action_text(self, action_id: str) -> str:
"""
Resolve action display text by id.
"""
for action in self.user_actions:
if action.id == action_id:
return action.title
return action_id
class FormDefinition(BaseModel):
form_content: str
inputs: list[FormInput] = Field(default_factory=list)
user_actions: list[UserAction] = Field(default_factory=list)
rendered_content: str
expiration_time: datetime
# this is used to store the resolved default values
default_values: dict[str, Any] = Field(default_factory=dict)
# node_title records the title of the HumanInput node.
node_title: str | None = None
# display_in_ui controls whether the form should be displayed in UI surfaces.
display_in_ui: bool | None = None
class HumanInputSubmissionValidationError(ValueError):
pass
def validate_human_input_submission(
*,
inputs: Sequence[FormInput],
user_actions: Sequence[UserAction],
selected_action_id: str,
form_data: Mapping[str, Any],
) -> None:
available_actions = {action.id for action in user_actions}
if selected_action_id not in available_actions:
raise HumanInputSubmissionValidationError(f"Invalid action: {selected_action_id}")
provided_inputs = set(form_data.keys())
missing_inputs = [
form_input.output_variable_name
for form_input in inputs
if form_input.output_variable_name not in provided_inputs
]
if missing_inputs:
missing_list = ", ".join(missing_inputs)
raise HumanInputSubmissionValidationError(f"Missing required inputs: {missing_list}")

View File

@@ -0,0 +1,72 @@
import enum
class HumanInputFormStatus(enum.StrEnum):
"""Status of a human input form."""
# Awaiting submission from any recipient. Forms stay in this state until
# submitted or a timeout rule applies.
WAITING = enum.auto()
# Global timeout reached. The workflow run is stopped and will not resume.
# This is distinct from node-level timeout.
EXPIRED = enum.auto()
# Submitted by a recipient; form data is available and execution resumes
# along the selected action edge.
SUBMITTED = enum.auto()
# Node-level timeout reached. The human input node should emit a timeout
# event and the workflow should resume along the timeout edge.
TIMEOUT = enum.auto()
class HumanInputFormKind(enum.StrEnum):
"""Kind of a human input form."""
RUNTIME = enum.auto() # Form created during workflow execution.
DELIVERY_TEST = enum.auto() # Form created for delivery tests.
class DeliveryMethodType(enum.StrEnum):
"""Delivery method types for human input forms."""
# WEBAPP controls whether the form is delivered to the web app. It not only controls
# the standalone web app, but also controls the installed apps in the console.
WEBAPP = enum.auto()
EMAIL = enum.auto()
class ButtonStyle(enum.StrEnum):
"""Button styles for user actions."""
PRIMARY = enum.auto()
DEFAULT = enum.auto()
ACCENT = enum.auto()
GHOST = enum.auto()
class TimeoutUnit(enum.StrEnum):
"""Timeout unit for form expiration."""
HOUR = enum.auto()
DAY = enum.auto()
class FormInputType(enum.StrEnum):
"""Form input types."""
TEXT_INPUT = enum.auto()
PARAGRAPH = enum.auto()
class PlaceholderType(enum.StrEnum):
"""Default value types for form inputs."""
VARIABLE = enum.auto()
CONSTANT = enum.auto()
class EmailRecipientType(enum.StrEnum):
"""Email recipient types."""
MEMBER = enum.auto()
EXTERNAL = enum.auto()

View File

@@ -1,12 +1,42 @@
from collections.abc import Mapping
from typing import Any
import json
import logging
from collections.abc import Generator, Mapping, Sequence
from typing import TYPE_CHECKING, Any
from core.app.entities.app_invoke_entities import InvokeFrom
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
from core.workflow.entities.pause_reason import HumanInputRequired
from core.workflow.enums import NodeExecutionType, NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult, PauseRequestedEvent
from core.workflow.node_events import (
HumanInputFormFilledEvent,
HumanInputFormTimeoutEvent,
NodeRunResult,
PauseRequestedEvent,
)
from core.workflow.node_events.base import NodeEventBase
from core.workflow.node_events.node import StreamCompletedEvent
from core.workflow.nodes.base.node import Node
from core.workflow.repositories.human_input_form_repository import (
FormCreateParams,
HumanInputFormEntity,
HumanInputFormRepository,
)
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from .entities import HumanInputNodeData
from .entities import DeliveryChannelConfig, HumanInputNodeData, apply_debug_email_recipient
from .enums import DeliveryMethodType, HumanInputFormStatus, PlaceholderType
if TYPE_CHECKING:
from core.workflow.entities.graph_init_params import GraphInitParams
from core.workflow.runtime.graph_runtime_state import GraphRuntimeState
_SELECTED_BRANCH_KEY = "selected_branch"
logger = logging.getLogger(__name__)
class HumanInputNode(Node[HumanInputNodeData]):
@@ -17,7 +47,7 @@ class HumanInputNode(Node[HumanInputNodeData]):
"edge_source_handle",
"edgeSourceHandle",
"source_handle",
"selected_branch",
_SELECTED_BRANCH_KEY,
"selectedBranch",
"branch",
"branch_id",
@@ -25,43 +55,37 @@ class HumanInputNode(Node[HumanInputNodeData]):
"handle",
)
_node_data: HumanInputNodeData
_form_repository: HumanInputFormRepository
_OUTPUT_FIELD_ACTION_ID = "__action_id"
_OUTPUT_FIELD_RENDERED_CONTENT = "__rendered_content"
_TIMEOUT_HANDLE = _TIMEOUT_ACTION_ID = "__timeout"
def __init__(
self,
id: str,
config: Mapping[str, Any],
graph_init_params: "GraphInitParams",
graph_runtime_state: "GraphRuntimeState",
form_repository: HumanInputFormRepository | None = None,
) -> None:
super().__init__(
id=id,
config=config,
graph_init_params=graph_init_params,
graph_runtime_state=graph_runtime_state,
)
if form_repository is None:
form_repository = HumanInputFormRepositoryImpl(
session_factory=db.engine,
tenant_id=self.tenant_id,
)
self._form_repository = form_repository
@classmethod
def version(cls) -> str:
return "1"
def _run(self): # type: ignore[override]
if self._is_completion_ready():
branch_handle = self._resolve_branch_selection()
return NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={},
edge_source_handle=branch_handle or "source",
)
return self._pause_generator()
def _pause_generator(self):
# TODO(QuantumGhost): yield a real form id.
yield PauseRequestedEvent(reason=HumanInputRequired(form_id="test_form_id", node_id=self.id))
def _is_completion_ready(self) -> bool:
"""Determine whether all required inputs are satisfied."""
if not self.node_data.required_variables:
return False
variable_pool = self.graph_runtime_state.variable_pool
for selector_str in self.node_data.required_variables:
parts = selector_str.split(".")
if len(parts) != 2:
return False
segment = variable_pool.get(parts)
if segment is None:
return False
return True
def _resolve_branch_selection(self) -> str | None:
"""Determine the branch handle selected by human input if available."""
@@ -108,3 +132,224 @@ class HumanInputNode(Node[HumanInputNodeData]):
return candidate
return None
@property
def _workflow_execution_id(self) -> str:
workflow_exec_id = self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id
assert workflow_exec_id is not None
return workflow_exec_id
def _form_to_pause_event(self, form_entity: HumanInputFormEntity):
required_event = self._human_input_required_event(form_entity)
pause_requested_event = PauseRequestedEvent(reason=required_event)
return pause_requested_event
def resolve_default_values(self) -> Mapping[str, Any]:
variable_pool = self.graph_runtime_state.variable_pool
resolved_defaults: dict[str, Any] = {}
for input in self._node_data.inputs:
if (default_value := input.default) is None:
continue
if default_value.type == PlaceholderType.CONSTANT:
continue
resolved_value = variable_pool.get(default_value.selector)
if resolved_value is None:
# TODO: How should we handle this?
continue
resolved_defaults[input.output_variable_name] = (
WorkflowRuntimeTypeConverter().value_to_json_encodable_recursive(resolved_value.value)
)
return resolved_defaults
def _should_require_console_recipient(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
if self.invoke_from == InvokeFrom.EXPLORE:
return self._node_data.is_webapp_enabled()
return False
def _display_in_ui(self) -> bool:
if self.invoke_from == InvokeFrom.DEBUGGER:
return True
return self._node_data.is_webapp_enabled()
def _effective_delivery_methods(self) -> Sequence[DeliveryChannelConfig]:
enabled_methods = [method for method in self._node_data.delivery_methods if method.enabled]
if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE}:
enabled_methods = [method for method in enabled_methods if method.type != DeliveryMethodType.WEBAPP]
return [
apply_debug_email_recipient(
method,
enabled=self.invoke_from == InvokeFrom.DEBUGGER,
user_id=self.user_id or "",
)
for method in enabled_methods
]
def _human_input_required_event(self, form_entity: HumanInputFormEntity) -> HumanInputRequired:
node_data = self._node_data
resolved_default_values = self.resolve_default_values()
display_in_ui = self._display_in_ui()
form_token = form_entity.web_app_token
if display_in_ui and form_token is None:
raise AssertionError("Form token should be available for UI execution.")
return HumanInputRequired(
form_id=form_entity.id,
form_content=form_entity.rendered_content,
inputs=node_data.inputs,
actions=node_data.user_actions,
display_in_ui=display_in_ui,
node_id=self.id,
node_title=node_data.title,
form_token=form_token,
resolved_default_values=resolved_default_values,
)
def _run(self) -> Generator[NodeEventBase, None, None]:
"""
Execute the human input node.
This method will:
1. Generate a unique form ID
2. Create form content with variable substitution
3. Create form in database
4. Send form via configured delivery methods
5. Suspend workflow execution
6. Wait for form submission to resume
"""
repo = self._form_repository
form = repo.get_form(self._workflow_execution_id, self.id)
if form is None:
display_in_ui = self._display_in_ui()
params = FormCreateParams(
app_id=self.app_id,
workflow_execution_id=self._workflow_execution_id,
node_id=self.id,
form_config=self._node_data,
rendered_content=self.render_form_content_before_submission(),
delivery_methods=self._effective_delivery_methods(),
display_in_ui=display_in_ui,
resolved_default_values=self.resolve_default_values(),
console_recipient_required=self._should_require_console_recipient(),
console_creator_account_id=(
self.user_id if self.invoke_from in {InvokeFrom.DEBUGGER, InvokeFrom.EXPLORE} else None
),
backstage_recipient_required=True,
)
form_entity = self._form_repository.create_form(params)
# Create human input required event
logger.info(
"Human Input node suspended workflow for form. workflow_run_id=%s, node_id=%s, form_id=%s",
self.graph_runtime_state.variable_pool.system_variables.workflow_execution_id,
self.id,
form_entity.id,
)
yield self._form_to_pause_event(form_entity)
return
if (
form.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}
or form.expiration_time <= naive_utc_now()
):
yield HumanInputFormTimeoutEvent(
node_title=self._node_data.title,
expiration_time=form.expiration_time,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs={self._OUTPUT_FIELD_ACTION_ID: ""},
edge_source_handle=self._TIMEOUT_HANDLE,
)
)
return
if not form.submitted:
yield self._form_to_pause_event(form)
return
selected_action_id = form.selected_action_id
if selected_action_id is None:
raise AssertionError(f"selected_action_id should not be None when form submitted, form_id={form.id}")
submitted_data = form.submitted_data or {}
outputs: dict[str, Any] = dict(submitted_data)
outputs[self._OUTPUT_FIELD_ACTION_ID] = selected_action_id
rendered_content = self.render_form_content_with_outputs(
form.rendered_content,
outputs,
self._node_data.outputs_field_names(),
)
outputs[self._OUTPUT_FIELD_RENDERED_CONTENT] = rendered_content
action_text = self._node_data.find_action_text(selected_action_id)
yield HumanInputFormFilledEvent(
node_title=self._node_data.title,
rendered_content=rendered_content,
action_id=selected_action_id,
action_text=action_text,
)
yield StreamCompletedEvent(
node_run_result=NodeRunResult(
status=WorkflowNodeExecutionStatus.SUCCEEDED,
outputs=outputs,
edge_source_handle=selected_action_id,
)
)
def render_form_content_before_submission(self) -> str:
"""
Process form content by substituting variables.
This method should:
1. Parse the form_content markdown
2. Substitute {{#node_name.var_name#}} with actual values
3. Keep {{#$output.field_name#}} placeholders for form inputs
"""
rendered_form_content = self.graph_runtime_state.variable_pool.convert_template(
self._node_data.form_content,
)
return rendered_form_content.markdown
@staticmethod
def render_form_content_with_outputs(
form_content: str,
outputs: Mapping[str, Any],
field_names: Sequence[str],
) -> str:
"""
Replace {{#$output.xxx#}} placeholders with submitted values.
"""
rendered_content = form_content
for field_name in field_names:
placeholder = "{{#$output." + field_name + "#}}"
value = outputs.get(field_name)
if value is None:
replacement = ""
elif isinstance(value, (dict, list)):
replacement = json.dumps(value, ensure_ascii=False)
else:
replacement = str(value)
rendered_content = rendered_content.replace(placeholder, replacement)
return rendered_content
@classmethod
def _extract_variable_selector_to_variable_mapping(
cls,
*,
graph_config: Mapping[str, Any],
node_id: str,
node_data: Mapping[str, Any],
) -> Mapping[str, Sequence[str]]:
"""
Extract variable selectors referenced in form content and input default values.
This method should parse:
1. Variables referenced in form_content ({{#node_name.var_name#}})
2. Variables referenced in input default values
"""
validated_node_data = HumanInputNodeData.model_validate(node_data)
return validated_node_data.extract_variable_selector_to_variable_mapping(node_id)

View File

@@ -1,7 +1,6 @@
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any
from configs import dify_config
from core.workflow.enums import NodeType, WorkflowNodeExecutionStatus
from core.workflow.node_events import NodeRunResult
from core.workflow.nodes.base.node import Node
@@ -16,12 +15,13 @@ if TYPE_CHECKING:
from core.workflow.entities import GraphInitParams
from core.workflow.runtime import GraphRuntimeState
MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH = dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH = 400_000
class TemplateTransformNode(Node[TemplateTransformNodeData]):
node_type = NodeType.TEMPLATE_TRANSFORM
_template_renderer: Jinja2TemplateRenderer
_max_output_length: int
def __init__(
self,
@@ -31,6 +31,7 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
graph_runtime_state: "GraphRuntimeState",
*,
template_renderer: Jinja2TemplateRenderer | None = None,
max_output_length: int | None = None,
) -> None:
super().__init__(
id=id,
@@ -40,6 +41,10 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
)
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
if max_output_length is not None and max_output_length <= 0:
raise ValueError("max_output_length must be a positive integer")
self._max_output_length = max_output_length or DEFAULT_TEMPLATE_TRANSFORM_MAX_OUTPUT_LENGTH
@classmethod
def get_default_config(cls, filters: Mapping[str, object] | None = None) -> Mapping[str, object]:
"""
@@ -69,11 +74,11 @@ class TemplateTransformNode(Node[TemplateTransformNodeData]):
except TemplateRenderError as e:
return NodeRunResult(inputs=variables, status=WorkflowNodeExecutionStatus.FAILED, error=str(e))
if len(rendered) > MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH:
if len(rendered) > self._max_output_length:
return NodeRunResult(
inputs=variables,
status=WorkflowNodeExecutionStatus.FAILED,
error=f"Output length exceeds {MAX_TEMPLATE_TRANSFORM_OUTPUT_LENGTH} characters",
error=f"Output length exceeds {self._max_output_length} characters",
)
return NodeRunResult(

View File

@@ -0,0 +1,152 @@
import abc
import dataclasses
from collections.abc import Mapping, Sequence
from datetime import datetime
from typing import Any, Protocol
from core.workflow.nodes.human_input.entities import DeliveryChannelConfig, HumanInputNodeData
from core.workflow.nodes.human_input.enums import HumanInputFormKind, HumanInputFormStatus
class HumanInputError(Exception):
pass
class FormNotFoundError(HumanInputError):
pass
@dataclasses.dataclass
class FormCreateParams:
# app_id is the identifier for the app that the form belongs to.
# It is a string with uuid format.
app_id: str
# None when creating a delivery test form; set for runtime forms.
workflow_execution_id: str | None
# node_id is the identifier for a specific
# node in the graph.
#
# TODO: for node inside loop / iteration, this would
# cause problems, as a single node may be executed multiple times.
node_id: str
form_config: HumanInputNodeData
rendered_content: str
# Delivery methods already filtered by runtime context (invoke_from).
delivery_methods: Sequence[DeliveryChannelConfig]
# UI display flag computed by runtime context.
display_in_ui: bool
# resolved_default_values saves the values for defaults with
# type = VARIABLE.
#
# For type = CONSTANT, the value is not stored inside `resolved_default_values`
resolved_default_values: Mapping[str, Any]
form_kind: HumanInputFormKind = HumanInputFormKind.RUNTIME
# Force creating a console-only recipient for submission in Console.
console_recipient_required: bool = False
console_creator_account_id: str | None = None
# Force creating a backstage recipient for submission in Console.
backstage_recipient_required: bool = False
class HumanInputFormEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of the form."""
pass
@property
@abc.abstractmethod
def web_app_token(self) -> str | None:
"""web_app_token returns the token for submission inside webapp.
For console/debug execution, this may point to the console submission token
if the form is configured to require console delivery.
"""
# TODO: what if the users are allowed to add multiple
# webapp delivery?
pass
@property
@abc.abstractmethod
def recipients(self) -> list["HumanInputFormRecipientEntity"]: ...
@property
@abc.abstractmethod
def rendered_content(self) -> str:
"""Rendered markdown content associated with the form."""
...
@property
@abc.abstractmethod
def selected_action_id(self) -> str | None:
"""Identifier of the selected user action if the form has been submitted."""
...
@property
@abc.abstractmethod
def submitted_data(self) -> Mapping[str, Any] | None:
"""Submitted form data if available."""
...
@property
@abc.abstractmethod
def submitted(self) -> bool:
"""Whether the form has been submitted."""
...
@property
@abc.abstractmethod
def status(self) -> HumanInputFormStatus:
"""Current status of the form."""
...
@property
@abc.abstractmethod
def expiration_time(self) -> datetime:
"""When the form expires."""
...
class HumanInputFormRecipientEntity(abc.ABC):
@property
@abc.abstractmethod
def id(self) -> str:
"""id returns the identifer of this recipient."""
...
@property
@abc.abstractmethod
def token(self) -> str:
"""token returns a random string used to submit form"""
...
class HumanInputFormRepository(Protocol):
"""
Repository interface for HumanInputForm.
This interface defines the contract for accessing and manipulating
HumanInputForm data, regardless of the underlying storage mechanism.
Note: Domain-specific concepts like multi-tenancy (tenant_id), application context (app_id),
and other implementation details should be handled at the implementation level, not in
the core interface. This keeps the core domain model clean and independent of specific
application domains or deployment scenarios.
"""
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
"""Get the form created for a given human input node in a workflow execution. Returns
`None` if the form has not been created yet."""
...
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
"""
Create a human input form from form definition.
"""
...

View File

@@ -6,15 +6,18 @@ import threading
from collections.abc import Mapping, Sequence
from copy import deepcopy
from dataclasses import dataclass
from typing import Any, ClassVar, Protocol
from typing import TYPE_CHECKING, Any, ClassVar, Protocol
from pydantic import BaseModel, Field
from pydantic.json import pydantic_encoder
from core.model_runtime.entities.llm_entities import LLMUsage
from core.workflow.entities.pause_reason import PauseReason
from core.workflow.enums import NodeExecutionType, NodeState, NodeType
from core.workflow.runtime.variable_pool import VariablePool
if TYPE_CHECKING:
from core.workflow.entities.pause_reason import PauseReason
class ReadyQueueProtocol(Protocol):
"""Structural interface required from ready queue implementations."""
@@ -61,7 +64,7 @@ class GraphExecutionProtocol(Protocol):
aborted: bool
error: Exception | None
exceptions_count: int
pause_reasons: list[PauseReason]
pause_reasons: Sequence[PauseReason]
def start(self) -> None:
"""Transition execution into the running state."""
@@ -133,6 +136,13 @@ class GraphProtocol(Protocol):
def get_outgoing_edges(self, node_id: str) -> Sequence[EdgeProtocol]: ...
class _GraphStateSnapshot(BaseModel):
"""Serializable graph state snapshot for node/edge states."""
nodes: dict[str, NodeState] = Field(default_factory=dict)
edges: dict[str, NodeState] = Field(default_factory=dict)
@dataclass(slots=True)
class _GraphRuntimeStateSnapshot:
"""Immutable view of a serialized runtime state snapshot."""
@@ -148,10 +158,20 @@ class _GraphRuntimeStateSnapshot:
graph_execution_dump: str | None
response_coordinator_dump: str | None
paused_nodes: tuple[str, ...]
deferred_nodes: tuple[str, ...]
graph_node_states: dict[str, NodeState]
graph_edge_states: dict[str, NodeState]
class GraphRuntimeState:
"""Mutable runtime state shared across graph execution components."""
"""Mutable runtime state shared across graph execution components.
`GraphRuntimeState` encapsulates the runtime state of workflow execution,
including scheduling details, variable values, and timing information.
Values that are initialized prior to workflow execution and remain constant
throughout the execution should be part of `GraphInitParams` instead.
"""
def __init__(
self,
@@ -189,6 +209,16 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump: str | None = None
self._pending_graph_execution_workflow_id: str | None = None
self._paused_nodes: set[str] = set()
self._deferred_nodes: set[str] = set()
# Node and edges states needed to be restored into
# graph object.
#
# These two fields are non-None only when resuming from a snapshot.
# Once the graph is attached, these two fields will be set to None.
self._pending_graph_node_states: dict[str, NodeState] | None = None
self._pending_graph_edge_states: dict[str, NodeState] | None = None
self.stop_event: threading.Event = threading.Event()
if graph is not None:
@@ -210,6 +240,7 @@ class GraphRuntimeState:
if self._pending_response_coordinator_dump is not None and self._response_coordinator is not None:
self._response_coordinator.loads(self._pending_response_coordinator_dump)
self._pending_response_coordinator_dump = None
self._apply_pending_graph_state()
def configure(self, *, graph: GraphProtocol | None = None) -> None:
"""Ensure core collaborators are initialized with the provided context."""
@@ -331,8 +362,13 @@ class GraphRuntimeState:
"ready_queue": self.ready_queue.dumps(),
"graph_execution": self.graph_execution.dumps(),
"paused_nodes": list(self._paused_nodes),
"deferred_nodes": list(self._deferred_nodes),
}
graph_state = self._snapshot_graph_state()
if graph_state is not None:
snapshot["graph_state"] = graph_state
if self._response_coordinator is not None and self._graph is not None:
snapshot["response_coordinator"] = self._response_coordinator.dumps()
@@ -366,6 +402,11 @@ class GraphRuntimeState:
self._paused_nodes.add(node_id)
def get_paused_nodes(self) -> list[str]:
"""Retrieve the list of paused nodes without mutating internal state."""
return list(self._paused_nodes)
def consume_paused_nodes(self) -> list[str]:
"""Retrieve and clear the list of paused nodes awaiting resume."""
@@ -373,6 +414,23 @@ class GraphRuntimeState:
self._paused_nodes.clear()
return nodes
def register_deferred_node(self, node_id: str) -> None:
"""Record a node that became ready during pause and should resume later."""
self._deferred_nodes.add(node_id)
def get_deferred_nodes(self) -> list[str]:
"""Retrieve deferred nodes without mutating internal state."""
return list(self._deferred_nodes)
def consume_deferred_nodes(self) -> list[str]:
"""Retrieve and clear deferred nodes awaiting resume."""
nodes = list(self._deferred_nodes)
self._deferred_nodes.clear()
return nodes
# ------------------------------------------------------------------
# Builders
# ------------------------------------------------------------------
@@ -434,6 +492,10 @@ class GraphRuntimeState:
graph_execution_payload = payload.get("graph_execution")
response_payload = payload.get("response_coordinator")
paused_nodes_payload = payload.get("paused_nodes", [])
deferred_nodes_payload = payload.get("deferred_nodes", [])
graph_state_payload = payload.get("graph_state", {}) or {}
graph_node_states = _coerce_graph_state_map(graph_state_payload, "nodes")
graph_edge_states = _coerce_graph_state_map(graph_state_payload, "edges")
return _GraphRuntimeStateSnapshot(
start_at=start_at,
@@ -447,6 +509,9 @@ class GraphRuntimeState:
graph_execution_dump=graph_execution_payload,
response_coordinator_dump=response_payload,
paused_nodes=tuple(map(str, paused_nodes_payload)),
deferred_nodes=tuple(map(str, deferred_nodes_payload)),
graph_node_states=graph_node_states,
graph_edge_states=graph_edge_states,
)
def _apply_snapshot(self, snapshot: _GraphRuntimeStateSnapshot) -> None:
@@ -462,6 +527,10 @@ class GraphRuntimeState:
self._restore_graph_execution(snapshot.graph_execution_dump)
self._restore_response_coordinator(snapshot.response_coordinator_dump)
self._paused_nodes = set(snapshot.paused_nodes)
self._deferred_nodes = set(snapshot.deferred_nodes)
self._pending_graph_node_states = snapshot.graph_node_states or None
self._pending_graph_edge_states = snapshot.graph_edge_states or None
self._apply_pending_graph_state()
def _restore_ready_queue(self, payload: str | None) -> None:
if payload is not None:
@@ -498,3 +567,68 @@ class GraphRuntimeState:
self._pending_response_coordinator_dump = payload
self._response_coordinator = None
def _snapshot_graph_state(self) -> _GraphStateSnapshot:
graph = self._graph
if graph is None:
if self._pending_graph_node_states is None and self._pending_graph_edge_states is None:
return _GraphStateSnapshot()
return _GraphStateSnapshot(
nodes=self._pending_graph_node_states or {},
edges=self._pending_graph_edge_states or {},
)
nodes = graph.nodes
edges = graph.edges
if not isinstance(nodes, Mapping) or not isinstance(edges, Mapping):
return _GraphStateSnapshot()
node_states = {}
for node_id, node in nodes.items():
if not isinstance(node_id, str):
continue
node_states[node_id] = node.state
edge_states = {}
for edge_id, edge in edges.items():
if not isinstance(edge_id, str):
continue
edge_states[edge_id] = edge.state
return _GraphStateSnapshot(nodes=node_states, edges=edge_states)
def _apply_pending_graph_state(self) -> None:
if self._graph is None:
return
if self._pending_graph_node_states:
for node_id, state in self._pending_graph_node_states.items():
node = self._graph.nodes.get(node_id)
if node is None:
continue
node.state = state
if self._pending_graph_edge_states:
for edge_id, state in self._pending_graph_edge_states.items():
edge = self._graph.edges.get(edge_id)
if edge is None:
continue
edge.state = state
self._pending_graph_node_states = None
self._pending_graph_edge_states = None
def _coerce_graph_state_map(payload: Any, key: str) -> dict[str, NodeState]:
if not isinstance(payload, Mapping):
return {}
raw_map = payload.get(key, {})
if not isinstance(raw_map, Mapping):
return {}
result: dict[str, NodeState] = {}
for node_id, raw_state in raw_map.items():
if not isinstance(node_id, str):
continue
try:
result[node_id] = NodeState(str(raw_state))
except ValueError:
continue
return result

View File

@@ -15,12 +15,14 @@ class WorkflowRuntimeTypeConverter:
def to_json_encodable(self, value: None) -> None: ...
def to_json_encodable(self, value: Mapping[str, Any] | None) -> Mapping[str, Any] | None:
result = self._to_json_encodable_recursive(value)
"""Convert runtime values to JSON-serializable structures."""
result = self.value_to_json_encodable_recursive(value)
if isinstance(result, Mapping) or result is None:
return result
return {}
def _to_json_encodable_recursive(self, value: Any):
def value_to_json_encodable_recursive(self, value: Any):
if value is None:
return value
if isinstance(value, (bool, int, str, float)):
@@ -29,7 +31,7 @@ class WorkflowRuntimeTypeConverter:
# Convert Decimal to float for JSON serialization
return float(value)
if isinstance(value, Segment):
return self._to_json_encodable_recursive(value.value)
return self.value_to_json_encodable_recursive(value.value)
if isinstance(value, File):
return value.to_dict()
if isinstance(value, BaseModel):
@@ -37,11 +39,11 @@ class WorkflowRuntimeTypeConverter:
if isinstance(value, dict):
res = {}
for k, v in value.items():
res[k] = self._to_json_encodable_recursive(v)
res[k] = self.value_to_json_encodable_recursive(v)
return res
if isinstance(value, list):
res_list = []
for item in value:
res_list.append(self._to_json_encodable_recursive(item))
res_list.append(self.value_to_json_encodable_recursive(item))
return res_list
return value

View File

@@ -35,10 +35,10 @@ if [[ "${MODE}" == "worker" ]]; then
if [[ -z "${CELERY_QUEUES}" ]]; then
if [[ "${EDITION}" == "CLOUD" ]]; then
# Cloud edition: separate queues for dataset and trigger tasks
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow_professional,workflow_team,workflow_sandbox,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
else
# Community edition (SELF_HOSTED): dataset, pipeline and workflow have separate queues
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention"
DEFAULT_QUEUES="dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution"
fi
else
DEFAULT_QUEUES="${CELERY_QUEUES}"
@@ -102,7 +102,7 @@ elif [[ "${MODE}" == "job" ]]; then
fi
echo "Running Flask job command: flask $*"
# Temporarily disable exit on error to capture exit code
set +e
flask "$@"

View File

@@ -151,6 +151,12 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.queue_monitor_task.queue_monitor_task",
"schedule": timedelta(minutes=dify_config.QUEUE_MONITOR_INTERVAL or 30),
}
if dify_config.ENABLE_HUMAN_INPUT_TIMEOUT_TASK:
imports.append("tasks.human_input_timeout_tasks")
beat_schedule["human_input_form_timeout"] = {
"task": "human_input_form_timeout.check_and_resume",
"schedule": timedelta(minutes=dify_config.HUMAN_INPUT_TIMEOUT_TASK_INTERVAL),
}
if dify_config.ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK and dify_config.MARKETPLACE_ENABLED:
imports.append("schedule.check_upgradable_plugin_task")
imports.append("tasks.process_tenant_plugin_autoupgrade_check_task")

View File

@@ -8,12 +8,16 @@ from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar, Union
import redis
from redis import RedisError
from redis.cache import CacheConfig
from redis.client import PubSub
from redis.cluster import ClusterNode, RedisCluster
from redis.connection import Connection, SSLConnection
from redis.sentinel import Sentinel
from configs import dify_config
from dify_app import DifyApp
from libs.broadcast_channel.channel import BroadcastChannel as BroadcastChannelProtocol
from libs.broadcast_channel.redis.channel import BroadcastChannel as RedisBroadcastChannel
from libs.broadcast_channel.redis.sharded_channel import ShardedRedisBroadcastChannel
if TYPE_CHECKING:
from redis.lock import Lock
@@ -106,6 +110,7 @@ class RedisClientWrapper:
def zremrangebyscore(self, name: str | bytes, min: float | str, max: float | str) -> Any: ...
def zcard(self, name: str | bytes) -> Any: ...
def getdel(self, name: str | bytes) -> Any: ...
def pubsub(self) -> PubSub: ...
def __getattr__(self, item: str) -> Any:
if self._client is None:
@@ -114,6 +119,7 @@ class RedisClientWrapper:
redis_client: RedisClientWrapper = RedisClientWrapper()
pubsub_redis_client: RedisClientWrapper = RedisClientWrapper()
def _get_ssl_configuration() -> tuple[type[Union[Connection, SSLConnection]], dict[str, Any]]:
@@ -226,6 +232,12 @@ def _create_standalone_client(redis_params: dict[str, Any]) -> Union[redis.Redis
return client
def _create_pubsub_client(pubsub_url: str, use_clusters: bool) -> Union[redis.Redis, RedisCluster]:
if use_clusters:
return RedisCluster.from_url(pubsub_url)
return redis.Redis.from_url(pubsub_url)
def init_app(app: DifyApp):
"""Initialize Redis client and attach it to the app."""
global redis_client
@@ -244,6 +256,24 @@ def init_app(app: DifyApp):
redis_client.initialize(client)
app.extensions["redis"] = redis_client
pubsub_client = client
if dify_config.normalized_pubsub_redis_url:
pubsub_client = _create_pubsub_client(
dify_config.normalized_pubsub_redis_url, dify_config.PUBSUB_REDIS_USE_CLUSTERS
)
pubsub_redis_client.initialize(pubsub_client)
def get_pubsub_redis_client() -> RedisClientWrapper:
return pubsub_redis_client
def get_pubsub_broadcast_channel() -> BroadcastChannelProtocol:
redis_conn = get_pubsub_redis_client()
if dify_config.PUBSUB_REDIS_CHANNEL_TYPE == "sharded":
return ShardedRedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
return RedisBroadcastChannel(redis_conn) # pyright: ignore[reportArgumentType]
P = ParamSpec("P")
R = TypeVar("R")

View File

@@ -13,6 +13,7 @@ from typing import Any
from sqlalchemy.orm import sessionmaker
from core.workflow.enums import WorkflowNodeExecutionStatus
from extensions.logstore.aliyun_logstore import AliyunLogStore
from extensions.logstore.repositories import safe_float, safe_int
from extensions.logstore.sql_escape import escape_identifier, escape_logstore_query_value
@@ -207,8 +208,10 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
reverse=True,
)
if deduplicated_results:
return _dict_to_workflow_node_execution_model(deduplicated_results[0])
for row in deduplicated_results:
model = _dict_to_workflow_node_execution_model(row)
if model.status != WorkflowNodeExecutionStatus.PAUSED:
return model
return None
@@ -309,6 +312,8 @@ class LogstoreAPIWorkflowNodeExecutionRepository(DifyAPIWorkflowNodeExecutionRep
if model and model.id: # Ensure model is valid
models.append(model)
models = [model for model in models if model.status != WorkflowNodeExecutionStatus.PAUSED]
# Sort by index DESC for trace visualization
models.sort(key=lambda x: x.index, reverse=True)

View File

@@ -192,6 +192,7 @@ class StatusCount(ResponseModel):
success: int
failed: int
partial_success: int
paused: int
class ModelConfig(ResponseModel):

View File

@@ -6,6 +6,7 @@ from uuid import uuid4
from pydantic import BaseModel, ConfigDict, Field, field_validator
from core.entities.execution_extra_content import ExecutionExtraContentDomainModel
from core.file import File
from fields.conversation_fields import AgentThought, JSONValue, MessageFile
@@ -61,6 +62,7 @@ class MessageListItem(ResponseModel):
message_files: list[MessageFile]
status: str
error: str | None = None
extra_contents: list[ExecutionExtraContentDomainModel]
@field_validator("inputs", mode="before")
@classmethod

View File

@@ -7,6 +7,7 @@ from typing import Self
from libs.broadcast_channel.channel import Subscription
from libs.broadcast_channel.exc import SubscriptionClosedError
from redis import Redis, RedisCluster
from redis.client import PubSub
_logger = logging.getLogger(__name__)
@@ -22,10 +23,12 @@ class RedisSubscriptionBase(Subscription):
def __init__(
self,
client: Redis | RedisCluster,
pubsub: PubSub,
topic: str,
):
# The _pubsub is None only if the subscription is closed.
self._client = client
self._pubsub: PubSub | None = pubsub
self._topic = topic
self._closed = threading.Event()
@@ -162,7 +165,7 @@ class RedisSubscriptionBase(Subscription):
self._start_if_needed()
return iter(self._message_iterator())
def receive(self, timeout: float | None = None) -> bytes | None:
def receive(self, timeout: float | None = 0.1) -> bytes | None:
"""Receive the next message from the subscription."""
if self._closed.is_set():
raise SubscriptionClosedError(f"The Redis {self._get_subscription_type()} subscription is closed")

View File

@@ -42,6 +42,7 @@ class Topic:
def subscribe(self) -> Subscription:
return _RedisSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)
@@ -63,7 +64,7 @@ class _RedisSubscription(RedisSubscriptionBase):
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=0.1)
return self._pubsub.get_message(ignore_subscribe_messages=True, timeout=1)
def _get_message_type(self) -> str:
return "message"

View File

@@ -1,7 +1,7 @@
from __future__ import annotations
from libs.broadcast_channel.channel import Producer, Subscriber, Subscription
from redis import Redis
from redis import Redis, RedisCluster
from ._subscription import RedisSubscriptionBase
@@ -16,7 +16,7 @@ class ShardedRedisBroadcastChannel:
def __init__(
self,
redis_client: Redis,
redis_client: Redis | RedisCluster,
):
self._client = redis_client
@@ -25,7 +25,7 @@ class ShardedRedisBroadcastChannel:
class ShardedTopic:
def __init__(self, redis_client: Redis, topic: str):
def __init__(self, redis_client: Redis | RedisCluster, topic: str):
self._client = redis_client
self._topic = topic
@@ -40,6 +40,7 @@ class ShardedTopic:
def subscribe(self) -> Subscription:
return _RedisShardedSubscription(
client=self._client,
pubsub=self._client.pubsub(),
topic=self._topic,
)
@@ -61,7 +62,26 @@ class _RedisShardedSubscription(RedisSubscriptionBase):
def _get_message(self) -> dict | None:
assert self._pubsub is not None
return self._pubsub.get_sharded_message(ignore_subscribe_messages=True, timeout=0.1) # type: ignore[attr-defined]
# NOTE(QuantumGhost): this is an issue in
# upstream code. If Sharded PubSub is used with Cluster, the
# `ClusterPubSub.get_sharded_message` will return `None` regardless of
# message['type'].
#
# Since we have already filtered at the caller's site, we can safely set
# `ignore_subscribe_messages=False`.
if isinstance(self._client, RedisCluster):
# NOTE(QuantumGhost): due to an issue in upstream code, calling `get_sharded_message`
# would use busy-looping to wait for incoming message, consuming excessive CPU quota.
#
# Here we specify the `target_node` to mitigate this problem.
node = self._client.get_node_from_key(self._topic)
return self._pubsub.get_sharded_message(
ignore_subscribe_messages=False,
timeout=1,
target_node=node,
)
else:
return self._pubsub.get_sharded_message(ignore_subscribe_messages=False, timeout=1) # type: ignore[attr-defined]
def _get_message_type(self) -> str:
return "smessage"

View File

@@ -0,0 +1,49 @@
"""
Email template rendering helpers with configurable safety modes.
"""
import time
from collections.abc import Mapping
from typing import Any
from flask import render_template_string
from jinja2.runtime import Context
from jinja2.sandbox import ImmutableSandboxedEnvironment
from configs import dify_config
from configs.feature import TemplateMode
class SandboxedEnvironment(ImmutableSandboxedEnvironment):
"""Sandboxed environment with execution timeout."""
def __init__(self, timeout: int, *args: Any, **kwargs: Any):
self._deadline = time.time() + timeout if timeout else None
super().__init__(*args, **kwargs)
def call(self, context: Context, obj: Any, *args: Any, **kwargs: Any) -> Any:
if self._deadline is not None and time.time() > self._deadline:
raise TimeoutError("Template rendering timeout")
return super().call(context, obj, *args, **kwargs)
def render_email_template(template: str, substitutions: Mapping[str, str]) -> str:
"""
Render email template content according to the configured template mode.
In unsafe mode, Jinja expressions are evaluated directly.
In sandbox mode, a sandboxed environment with timeout is used.
In disabled mode, the template is returned without rendering.
"""
mode = dify_config.MAIL_TEMPLATING_MODE
timeout = dify_config.MAIL_TEMPLATING_TIMEOUT
if mode == TemplateMode.UNSAFE:
return render_template_string(template, **substitutions)
if mode == TemplateMode.SANDBOX:
env = SandboxedEnvironment(timeout=timeout)
tmpl = env.from_string(template)
return tmpl.render(substitutions)
if mode == TemplateMode.DISABLED:
return template
raise ValueError(f"Unsupported mail templating mode: {mode}")

View File

@@ -1,12 +1,15 @@
import contextvars
from collections.abc import Iterator
from contextlib import contextmanager
from typing import TypeVar
from typing import TYPE_CHECKING, TypeVar
from flask import Flask, g
T = TypeVar("T")
if TYPE_CHECKING:
from models import Account, EndUser
@contextmanager
def preserve_flask_contexts(
@@ -64,3 +67,7 @@ def preserve_flask_contexts(
finally:
# Any cleanup can be added here if needed
pass
def set_login_user(user: "Account | EndUser"):
g._login_user = user

View File

@@ -7,10 +7,10 @@ import struct
import subprocess
import time
import uuid
from collections.abc import Generator, Mapping
from collections.abc import Callable, Generator, Mapping
from datetime import datetime
from hashlib import sha256
from typing import TYPE_CHECKING, Annotated, Any, Optional, Union, cast
from typing import TYPE_CHECKING, Annotated, Any, Optional, Protocol, Union, cast
from uuid import UUID
from zoneinfo import available_timezones
@@ -126,6 +126,13 @@ class TimestampField(fields.Raw):
return int(value.timestamp())
class OptionalTimestampField(fields.Raw):
def format(self, value) -> int | None:
if value is None:
return None
return int(value.timestamp())
def email(email):
# Define a regex pattern for email addresses
pattern = r"^[\w\.!#$%&'*+\-/=?^_`{|}~]+@([\w-]+\.)+[\w-]{2,}$"
@@ -237,6 +244,26 @@ def convert_datetime_to_date(field, target_timezone: str = ":tz"):
def generate_string(n):
"""
Generates a cryptographically secure random string of the specified length.
This function uses a cryptographically secure pseudorandom number generator (CSPRNG)
to create a string composed of ASCII letters (both uppercase and lowercase) and digits.
Each character in the generated string provides approximately 5.95 bits of entropy
(log2(62)). To ensure a minimum of 128 bits of entropy for security purposes, the
length of the string (`n`) should be at least 22 characters.
Args:
n (int): The length of the random string to generate. For secure usage,
`n` should be 22 or greater.
Returns:
str: A random string of length `n` composed of ASCII letters and digits.
Note:
This function is suitable for generating credentials or other secure tokens.
"""
letters_digits = string.ascii_letters + string.digits
result = ""
for _ in range(n):
@@ -405,11 +432,35 @@ class TokenManager:
return f"{token_type}:account:{account_id}"
class _RateLimiterRedisClient(Protocol):
def zadd(self, name: str | bytes, mapping: dict[str | bytes | int | float, float | int | str | bytes]) -> int: ...
def zremrangebyscore(self, name: str | bytes, min: str | float, max: str | float) -> int: ...
def zcard(self, name: str | bytes) -> int: ...
def expire(self, name: str | bytes, time: int) -> bool: ...
def _default_rate_limit_member_factory() -> str:
current_time = int(time.time())
return f"{current_time}:{secrets.token_urlsafe(nbytes=8)}"
class RateLimiter:
def __init__(self, prefix: str, max_attempts: int, time_window: int):
def __init__(
self,
prefix: str,
max_attempts: int,
time_window: int,
member_factory: Callable[[], str] = _default_rate_limit_member_factory,
redis_client: _RateLimiterRedisClient = redis_client,
):
self.prefix = prefix
self.max_attempts = max_attempts
self.time_window = time_window
self._member_factory = member_factory
self._redis_client = redis_client
def _get_key(self, email: str) -> str:
return f"{self.prefix}:{email}"
@@ -419,8 +470,8 @@ class RateLimiter:
current_time = int(time.time())
window_start_time = current_time - self.time_window
redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = redis_client.zcard(key)
self._redis_client.zremrangebyscore(key, "-inf", window_start_time)
attempts = self._redis_client.zcard(key)
if attempts and int(attempts) >= self.max_attempts:
return True
@@ -428,7 +479,8 @@ class RateLimiter:
def increment_rate_limit(self, email: str):
key = self._get_key(email)
member = self._member_factory()
current_time = int(time.time())
redis_client.zadd(key, {current_time: current_time})
redis_client.expire(key, self.time_window * 2)
self._redis_client.zadd(key, {member: current_time})
self._redis_client.expire(key, self.time_window * 2)

View File

@@ -10,6 +10,10 @@ import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
def _is_pg(conn):
return conn.dialect.name == "postgresql"
# revision identifiers, used by Alembic.
revision = '7df29de0f6be'
down_revision = '03ea244985ce'
@@ -19,16 +23,31 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
conn = op.get_bind()
if _is_pg(conn):
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
else:
# For MySQL and other databases, UUID should be generated at application level
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),
sa.Column('quota_used', sa.BigInteger(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.Column('updated_at', sa.DateTime(), server_default=sa.func.current_timestamp(), nullable=False),
sa.PrimaryKeyConstraint('id', name='tenant_credit_pool_pkey')
)
with op.batch_alter_table('tenant_credit_pools', schema=None) as batch_op:
batch_op.create_index('tenant_credit_pool_pool_type_idx', ['pool_type'], unique=False)
batch_op.create_index('tenant_credit_pool_tenant_id_idx', ['tenant_id'], unique=False)

Some files were not shown because too many files have changed in this diff Show More