Compare commits

..

4 Commits

Author SHA1 Message Date
Yanli 盐粒
5574802631 fix: pass webhook request metadata explicitly 2026-03-20 04:51:37 +08:00
Yanli 盐粒
b8cedefd7d fix: normalize signed call depth for root webhook URLs 2026-03-20 04:27:22 +08:00
Yanli 盐粒
4ecba5858b fix: propagate workflow call depth through HTTP recursion 2026-03-20 04:03:00 +08:00
tmimmanuel
5b9cb55c45 refactor: use EnumText for MessageFeedback and MessageFile columns (#33738) 2026-03-20 01:13:26 +09:00
45 changed files with 997 additions and 281 deletions

View File

@@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
from libs.helper import TimestampField, uuid_value
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from libs.login import current_account_with_tenant, login_required
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
@@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource):
if not args.rating and feedback:
db.session.delete(feedback)
elif args.rating and feedback:
feedback.rating = args.rating
feedback.rating = FeedbackRating(args.rating)
feedback.content = args.content
elif not args.rating and not feedback:
raise ValueError("rating cannot be None when feedback not exists")
@@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource):
app_id=app_model.id,
conversation_id=message.conversation_id,
message_id=message.id,
rating=rating_value,
rating=FeedbackRating(rating_value),
content=args.content,
from_source="admin",
from_source=FeedbackFromSource.ADMIN,
from_account_id=current_user.id,
)
db.session.add(feedback)

View File

@@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
from libs import helper
from libs.helper import UUIDStrOrEmpty
from libs.login import current_account_with_tenant
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
app_model=app_model,
message_id=message_id,
user=current_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@@ -7,7 +7,6 @@ from sqlalchemy import select
from werkzeug.exceptions import Unauthorized
import services
from configs import dify_config
from controllers.common.errors import (
FilenameNotExistsError,
FileTooLargeError,
@@ -30,7 +29,6 @@ from libs.helper import TimestampField
from libs.login import current_account_with_tenant, login_required
from models.account import Tenant, TenantStatus
from services.account_service import TenantService
from services.billing_service import BillingService
from services.enterprise.enterprise_service import EnterpriseService
from services.feature_service import FeatureService
from services.file_service import FileService
@@ -110,27 +108,9 @@ class TenantListApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
tenants = TenantService.get_join_tenants(current_user)
tenant_dicts = []
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
tenant_plans: dict[str, dict] = {}
use_legacy_feature_path = not is_enterprise_only and not is_saas
if is_saas:
tenant_ids = [tenant.id for tenant in tenants]
if tenant_ids:
try:
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
except Exception:
logger.exception("failed to fetch workspace plans in bulk, falling back to legacy feature path")
use_legacy_feature_path = True
for tenant in tenants:
plan = CloudPlan.SANDBOX
if is_saas and not use_legacy_feature_path:
plan = tenant_plans.get(tenant.id, {}).get("plan", CloudPlan.SANDBOX)
elif not is_enterprise_only:
features = FeatureService.get_features(tenant.id)
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
features = FeatureService.get_features(tenant.id)
# Create a dictionary with tenant attributes
tenant_dict = {
@@ -138,7 +118,7 @@ class TenantListApi(Resource):
"name": tenant.name,
"status": tenant.status,
"created_at": tenant.created_at,
"plan": plan,
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
"current": tenant.id == current_tenant_id if current_tenant_id else False,
}

View File

@@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
from fields.conversation_fields import ResultResponse
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
from libs.helper import UUIDStrOrEmpty
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.message import (
FirstMessageNotExistsError,
@@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@@ -52,8 +52,19 @@ def handle_webhook(webhook_id: str):
if error:
return jsonify({"error": "Bad Request", "message": error}), 400
trigger_call_depth = WebhookService.extract_workflow_call_depth(
dict(request.headers),
request_method=request.method,
request_path=request.path,
)
# Process webhook call (send to Celery)
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
WebhookService.trigger_workflow_execution(
webhook_trigger,
webhook_data,
workflow,
call_depth=trigger_call_depth,
)
# Return configured response
response_data, status_code = WebhookService.generate_webhook_response(node_config)

View File

@@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
from libs import helper
from libs.helper import uuid_value
from models.enums import FeedbackRating
from models.model import AppMode
from services.app_generate_service import AppGenerateService
from services.errors.app import MoreLikeThisDisabledError
@@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
app_model=app_model,
message_id=message_id,
user=end_user,
rating=payload.rating,
rating=FeedbackRating(payload.rating) if payload.rating else None,
content=payload.content,
)
except MessageNotExistsError:

View File

@@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models import Account, Conversation, EndUser, Message, MessageFile
from models.enums import CreatorUserRole, MessageStatus
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
from models.execution_extra_content import HumanInputContent
from models.workflow import Workflow
@@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
type=file["type"],
transfer_method=file["transfer_method"],
url=file["remote_url"],
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
upload_file_id=file["related_id"],
created_by_role=CreatorUserRole.ACCOUNT
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}

View File

@@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
if TYPE_CHECKING:
@@ -419,7 +419,7 @@ class AppRunner:
message_id=message_id,
type=FileType.IMAGE,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=f"/files/tools/{tool_file.id}",
upload_file_id=tool_file.id,
created_by_role=(

View File

@@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
from libs.broadcast_channel.channel import Topic
from libs.datetime_utils import naive_utc_now
from models import Account
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
from services.errors.app_model_config import AppModelConfigBrokenError
from services.errors.conversation import ConversationNotExistsError
@@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
message_id=message.id,
type=file.type,
transfer_method=file.transfer_method,
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
url=file.remote_url,
upload_file_id=file.related_id,
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),

View File

@@ -84,6 +84,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
root_node_id=self._root_node_id,
)
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
@@ -91,6 +92,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
workflow=self._workflow,
single_iteration_run=self.application_generate_entity.single_iteration_run,
single_loop_run=self.application_generate_entity.single_loop_run,
call_depth=self.application_generate_entity.call_depth,
)
else:
inputs = self.application_generate_entity.inputs
@@ -120,6 +122,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
user_id=self.application_generate_entity.user_id,
user_from=user_from,
invoke_from=invoke_from,
call_depth=self.application_generate_entity.call_depth,
root_node_id=self._root_node_id,
)

View File

@@ -102,6 +102,7 @@ class WorkflowBasedAppRunner:
graph_runtime_state: GraphRuntimeState,
user_from: UserFrom,
invoke_from: InvokeFrom,
call_depth: int = 0,
workflow_id: str = "",
tenant_id: str = "",
user_id: str = "",
@@ -130,7 +131,7 @@ class WorkflowBasedAppRunner:
user_from=user_from,
invoke_from=invoke_from,
),
call_depth=0,
call_depth=call_depth,
)
# Use the provided graph_runtime_state for consistent state management
@@ -156,6 +157,7 @@ class WorkflowBasedAppRunner:
workflow: Workflow,
single_iteration_run: Any | None = None,
single_loop_run: Any | None = None,
call_depth: int = 0,
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
"""
Prepare graph, variable pool, and runtime state for single node execution
@@ -189,6 +191,7 @@ class WorkflowBasedAppRunner:
node_id=single_iteration_run.node_id,
user_inputs=dict(single_iteration_run.inputs),
graph_runtime_state=graph_runtime_state,
call_depth=call_depth,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
@@ -198,6 +201,7 @@ class WorkflowBasedAppRunner:
node_id=single_loop_run.node_id,
user_inputs=dict(single_loop_run.inputs),
graph_runtime_state=graph_runtime_state,
call_depth=call_depth,
node_type_filter_key="loop_id",
node_type_label="loop",
)
@@ -214,6 +218,7 @@ class WorkflowBasedAppRunner:
node_id: str,
user_inputs: dict[str, Any],
graph_runtime_state: GraphRuntimeState,
call_depth: int,
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
) -> tuple[Graph, VariablePool]:
@@ -283,7 +288,7 @@ class WorkflowBasedAppRunner:
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
),
call_depth=0,
call_depth=call_depth,
)
node_factory = DifyNodeFactory(

View File

@@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.tools.signature import sign_tool_file
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.enums import MessageFileBelongsTo
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
from services.annotation_service import AppAnnotationService
@@ -233,7 +234,7 @@ class MessageCycleManager:
task_id=self._application_generate_entity.task_id,
id=message_file.id,
type=message_file.type,
belongs_to=message_file.belongs_to or "user",
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
url=url,
)

View File

@@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
from dify_graph.file import FileType
from dify_graph.file.models import FileTransferMethod
from extensions.ext_database import db
from models.enums import CreatorUserRole
from models.enums import CreatorUserRole, MessageFileBelongsTo
from models.model import Message, MessageFile
logger = logging.getLogger(__name__)
@@ -352,7 +352,7 @@ class ToolEngine:
message_id=agent_message.id,
type=file_type,
transfer_method=FileTransferMethod.TOOL_FILE,
belongs_to="assistant",
belongs_to=MessageFileBelongsTo.ASSISTANT,
url=message.url,
upload_file_id=tool_file_id,
created_by_role=(

View File

@@ -282,6 +282,7 @@ class DifyNodeFactory(NodeFactory):
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)

View File

@@ -0,0 +1,27 @@
"""Helpers for workflow recursion depth propagation.
The HTTP request node emits a reserved depth header pair on outbound requests,
and ``services.trigger.webhook_service`` validates that pair when a webhook is
received. The signature binds the propagated depth to the concrete HTTP method
and request path so a depth value captured for one endpoint cannot be replayed
verbatim against another path.
"""
import hashlib
import hmac
def build_workflow_call_depth_signature(*, secret_key: str, method: str, path: str, depth: str) -> str:
"""Build the stable HMAC payload for workflow call-depth propagation.
Args:
secret_key: Shared signing key used by both sender and receiver.
method: Outbound or inbound HTTP method.
path: Request path that the signature is bound to.
depth: Workflow call depth value serialized as a string.
Returns:
Hex-encoded HMAC-SHA256 digest for the method/path/depth tuple.
"""
payload = f"{method.upper()}:{path}:{depth}"
return hmac.new(secret_key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).hexdigest()

View File

@@ -2,3 +2,8 @@ SYSTEM_VARIABLE_NODE_ID = "sys"
ENVIRONMENT_VARIABLE_NODE_ID = "env"
CONVERSATION_VARIABLE_NODE_ID = "conversation"
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
# Reserved for internal workflow-to-workflow HTTP calls. External callers should
# not rely on or set this header.
WORKFLOW_CALL_DEPTH_HEADER = "X-Dify-Workflow-Call-Depth"
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER = "X-Dify-Workflow-Call-Depth-Signature"

View File

@@ -12,6 +12,7 @@ def build_http_request_config(
max_text_size: int = 1 * 1024 * 1024,
ssl_verify: bool = True,
ssrf_default_max_retries: int = 3,
secret_key: str = "",
) -> HttpRequestNodeConfig:
return HttpRequestNodeConfig(
max_connect_timeout=max_connect_timeout,
@@ -21,6 +22,7 @@ def build_http_request_config(
max_text_size=max_text_size,
ssl_verify=ssl_verify,
ssrf_default_max_retries=ssrf_default_max_retries,
secret_key=secret_key,
)

View File

@@ -76,6 +76,7 @@ class HttpRequestNodeConfig:
max_text_size: int
ssl_verify: bool
ssrf_default_max_retries: int
secret_key: str = ""
def default_timeout(self) -> "HttpRequestNodeTimeout":
return HttpRequestNodeTimeout(

View File

@@ -1,3 +1,11 @@
"""HTTP request execution helpers for workflow nodes.
Besides normal request assembly, this executor is responsible for propagating
workflow recursion depth across outbound HTTP calls. The reserved call-depth
headers are always regenerated from the current node context so user-supplied
values cannot override or poison the propagation contract.
"""
import base64
import json
import secrets
@@ -10,6 +18,8 @@ from urllib.parse import urlencode, urlparse
import httpx
from json_repair import repair_json
from dify_graph.call_depth import build_workflow_call_depth_signature
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
from dify_graph.file.enums import FileTransferMethod
from dify_graph.runtime import VariablePool
from dify_graph.variables.segments import ArrayFileSegment, FileSegment
@@ -41,6 +51,8 @@ BODY_TYPE_TO_CONTENT_TYPE = {
class Executor:
"""Prepare, execute, and log a workflow HTTP request node invocation."""
method: Literal[
"get",
"head",
@@ -77,6 +89,7 @@ class Executor:
timeout: HttpRequestNodeTimeout,
variable_pool: VariablePool,
http_request_config: HttpRequestNodeConfig,
workflow_call_depth: int = 0,
max_retries: int | None = None,
ssl_verify: bool | None = None,
http_client: HttpClientProtocol,
@@ -120,6 +133,7 @@ class Executor:
# init template
self.variable_pool = variable_pool
self.node_data = node_data
self.workflow_call_depth = workflow_call_depth
self._initialize()
def _initialize(self):
@@ -272,8 +286,33 @@ class Executor:
self.data = form_data
def _assembling_headers(self) -> dict[str, Any]:
"""Assemble outbound headers for the request.
Reserved workflow call-depth headers are removed case-insensitively
before the canonical pair is re-added from ``workflow_call_depth``.
The signature path mirrors Flask request matching, so URLs without an
explicit path are normalized to ``/`` before signing. This keeps
propagation deterministic even if a workflow author manually configured
colliding headers on the node.
"""
authorization = deepcopy(self.auth)
headers = deepcopy(self.headers) or {}
reserved_header_names = {
WORKFLOW_CALL_DEPTH_HEADER.lower(),
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower(),
}
headers = {k: v for k, v in headers.items() if k.lower() not in reserved_header_names}
parsed_url = urlparse(self.url)
signed_path = parsed_url.path or "/"
next_call_depth = str(self.workflow_call_depth + 1)
headers[WORKFLOW_CALL_DEPTH_HEADER] = next_call_depth
headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] = build_workflow_call_depth_signature(
secret_key=self._http_request_config.secret_key,
method=self.method,
path=signed_path,
depth=next_call_depth,
)
if self.auth.type == "api-key":
if self.auth.config is None:
raise AuthorizationConfigError("self.authorization config is required")
@@ -388,6 +427,12 @@ class Executor:
return self._validate_and_parse_response(response)
def to_log(self):
"""Render the request in raw HTTP form for node logs.
Internal workflow call-depth headers and authentication headers are
masked so operational logs remain useful without exposing replayable or
credential-bearing values.
"""
url_parts = urlparse(self.url)
path = url_parts.path or "/"
@@ -410,6 +455,12 @@ class Executor:
if body.type == "form-data":
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
for k, v in headers.items():
if k.lower() in {
WORKFLOW_CALL_DEPTH_HEADER.lower(),
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower(),
}:
raw += f"{k}: [internal]\r\n"
continue
if self.auth.type == "api-key":
authorization_header = "Authorization"
if self.auth.config and self.auth.config.header:

View File

@@ -101,6 +101,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
timeout=self._get_request_timeout(self.node_data),
variable_pool=self.graph_runtime_state.variable_pool,
http_request_config=self._http_request_config,
workflow_call_depth=self.workflow_call_depth,
ssl_verify=self.node_data.ssl_verify,
http_client=self._http_client,
file_manager=self._file_manager,

View File

@@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
ADMIN = "admin"
class FeedbackRating(StrEnum):
"""MessageFeedback rating"""
LIKE = "like"
DISLIKE = "dislike"
class InvokeFrom(StrEnum):
"""How a conversation/message was invoked"""

View File

@@ -36,7 +36,10 @@ from .enums import (
BannerStatus,
ConversationStatus,
CreatorUserRole,
FeedbackFromSource,
FeedbackRating,
MessageChainType,
MessageFileBelongsTo,
MessageStatus,
)
from .provider_ids import GenericProviderID
@@ -1165,7 +1168,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@@ -1176,7 +1179,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "user",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@@ -1191,7 +1194,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "like",
MessageFeedback.rating == FeedbackRating.LIKE,
)
)
or 0
@@ -1202,7 +1205,7 @@ class Conversation(Base):
select(func.count(MessageFeedback.id)).where(
MessageFeedback.conversation_id == self.id,
MessageFeedback.from_source == "admin",
MessageFeedback.rating == "dislike",
MessageFeedback.rating == FeedbackRating.DISLIKE,
)
)
or 0
@@ -1725,8 +1728,8 @@ class MessageFeedback(TypeBase):
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
rating: Mapped[str] = mapped_column(String(255), nullable=False)
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
@@ -1779,7 +1782,9 @@ class MessageFile(TypeBase):
)
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
)
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
created_at: Mapped[datetime] = mapped_column(

View File

@@ -7,6 +7,7 @@ from flask import Response
from sqlalchemy import or_
from extensions.ext_database import db
from models.enums import FeedbackRating
from models.model import Account, App, Conversation, Message, MessageFeedback
@@ -100,7 +101,7 @@ class FeedbackService:
"ai_response": message.answer[:500] + "..."
if len(message.answer) > 500
else message.answer, # Truncate long responses
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
"feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎",
"feedback_rating_raw": feedback.rating,
"feedback_comment": feedback.content or "",
"feedback_source": feedback.from_source,

View File

@@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
from repositories.sqlalchemy_execution_extra_content_repository import (
@@ -172,7 +173,7 @@ class MessageService:
app_model: App,
message_id: str,
user: Union[Account, EndUser] | None,
rating: str | None,
rating: FeedbackRating | None,
content: str | None,
):
if not user:
@@ -197,7 +198,7 @@ class MessageService:
message_id=message.id,
rating=rating,
content=content,
from_source=("user" if isinstance(user, EndUser) else "admin"),
from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN),
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
from_account_id=(user.id if isinstance(user, Account) else None),
)

View File

@@ -402,6 +402,7 @@ class RagPipelineService:
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
}
default_config = node_class.get_default_config(filters=filters)
@@ -435,6 +436,7 @@ class RagPipelineService:
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
default_config = node_class.get_default_config(filters=final_filters or None)
if not default_config:

View File

@@ -1,3 +1,4 @@
import hmac
import json
import logging
import mimetypes
@@ -23,6 +24,8 @@ from core.workflow.nodes.trigger_webhook.entities import (
WebhookData,
WebhookParameter,
)
from dify_graph.call_depth import build_workflow_call_depth_signature
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
from dify_graph.entities.graph_config import NodeConfigDict
from dify_graph.file.models import FileTransferMethod
from dify_graph.variables.types import ArrayValidation, SegmentType
@@ -57,10 +60,55 @@ class WebhookService:
@staticmethod
def _sanitize_key(key: str) -> str:
"""Normalize external keys (headers/params) to workflow-safe variables."""
if not isinstance(key, str):
return key
return key.replace("-", "_")
@classmethod
def extract_workflow_call_depth(
cls,
headers: Mapping[str, Any],
*,
request_method: str,
request_path: str,
) -> int:
"""Extract the reserved workflow recursion depth header.
The depth header is only trusted when accompanied by a valid HMAC
signature for the current request method/path/depth tuple supplied by the
caller while a request context is available. Header lookup is normalized
case-insensitively so mixed-case spellings still round-trip after headers
are materialized into a plain mapping. Invalid, missing, unsigned, or
negative values are treated as external requests and therefore fall back
to depth 0.
"""
normalized_headers = {str(key).lower(): value for key, value in headers.items()}
raw_value = normalized_headers.get(WORKFLOW_CALL_DEPTH_HEADER.lower())
if raw_value is None:
return 0
raw_signature = normalized_headers.get(WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower())
if raw_signature is None:
return 0
normalized_value = str(raw_value).strip()
# The receiver recomputes the signature from the current request context
# instead of trusting the sender's path or method directly.
expected_signature = build_workflow_call_depth_signature(
secret_key=dify_config.SECRET_KEY,
method=request_method,
path=request_path,
depth=normalized_value,
)
if not hmac.compare_digest(str(raw_signature).strip(), expected_signature):
return 0
try:
call_depth = int(normalized_value)
except (TypeError, ValueError):
return 0
return max(call_depth, 0)
@classmethod
def get_webhook_trigger_and_workflow(
cls, webhook_id: str, is_debug: bool = False
@@ -744,7 +792,12 @@ class WebhookService:
@classmethod
def trigger_workflow_execution(
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
cls,
webhook_trigger: WorkflowWebhookTrigger,
webhook_data: dict[str, Any],
workflow: Workflow,
*,
call_depth: int = 0,
) -> None:
"""Trigger workflow execution via AsyncWorkflowService.
@@ -752,6 +805,8 @@ class WebhookService:
webhook_trigger: The webhook trigger object
webhook_data: Processed webhook data for workflow inputs
workflow: The workflow to execute
call_depth: Validated recursion depth derived earlier from the
incoming request metadata
Raises:
ValueError: If tenant owner is not found
@@ -770,6 +825,7 @@ class WebhookService:
root_node_id=webhook_trigger.node_id, # Start from the webhook node
inputs=workflow_inputs,
tenant_id=webhook_trigger.tenant_id,
call_depth=call_depth,
)
end_user = EndUserService.get_or_create_end_user_by_type(

View File

@@ -26,7 +26,14 @@ class TriggerMetadata(BaseModel):
class TriggerData(BaseModel):
"""Base trigger data model for async workflow execution"""
"""Base trigger data model for async workflow execution.
`call_depth` tracks only the current workflow-to-workflow HTTP recursion
depth. It starts at 0 for external triggers and increments once per nested
webhook-triggered workflow call. For webhook triggers, the value is derived
from the reserved depth headers after `WebhookService.extract_workflow_call_depth`
validates the signature against the inbound request context.
"""
app_id: str
tenant_id: str
@@ -34,6 +41,7 @@ class TriggerData(BaseModel):
root_node_id: str
inputs: Mapping[str, Any]
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
call_depth: int = 0
trigger_type: AppTriggerType
trigger_from: WorkflowRunTriggeredFrom
trigger_metadata: TriggerMetadata | None = None

View File

@@ -638,6 +638,7 @@ class WorkflowService:
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
}
default_config = node_class.get_default_config(filters=filters)
@@ -673,6 +674,7 @@ class WorkflowService:
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
default_config = node_class.get_default_config(filters=resolved_filters or None)
if not default_config:

View File

@@ -164,7 +164,7 @@ def _execute_workflow_common(
args=args,
invoke_from=InvokeFrom.SERVICE_API,
streaming=False,
call_depth=0,
call_depth=trigger_data.call_depth,
triggered_from=trigger_data.trigger_from,
root_node_id=trigger_data.root_node_id,
graph_engine_layers=[

View File

@@ -14,6 +14,7 @@ from controllers.console.app import wraps
from libs.datetime_utils import naive_utc_now
from models import App, Tenant
from models.account import Account, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import AppMode, MessageFeedback
from services.feedback_service import FeedbackService
@@ -77,8 +78,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content=None,
from_end_user_id=str(uuid.uuid4()),
from_account_id=None,
@@ -90,8 +91,8 @@ class TestFeedbackExportApi:
app_id=app_id,
conversation_id=conversation_id,
message_id=message_id,
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="The response was not helpful",
from_end_user_id=None,
from_account_id=str(uuid.uuid4()),
@@ -277,8 +278,8 @@ class TestFeedbackExportApi:
# Verify service was called with correct parameters
mock_export_feedbacks.assert_called_once_with(
app_id=mock_app_model.id,
from_source="user",
rating="dislike",
from_source=FeedbackFromSource.USER,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",

View File

@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
from core.plugin.impl.exc import PluginDaemonClientSideError
from models import Account
from models.enums import MessageFileBelongsTo
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
from services.account_service import AccountService, TenantService
from services.agent_service import AgentService
@@ -852,7 +853,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file1.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)
@@ -861,7 +862,7 @@ class TestAgentService:
type=FileType.IMAGE,
transfer_method=FileTransferMethod.REMOTE_URL,
url="http://example.com/file2.png",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role=CreatorUserRole.ACCOUNT,
created_by=message.from_account_id,
)

View File

@@ -8,6 +8,7 @@ from unittest import mock
import pytest
from extensions.ext_database import db
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, Conversation, Message
from services.feedback_service import FeedbackService
@@ -47,8 +48,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="Great answer!",
from_end_user_id="user-123",
from_account_id=None,
@@ -61,8 +62,8 @@ class TestFeedbackService:
app_id=app_id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="admin",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.ADMIN,
content="Could be more detailed",
from_end_user_id=None,
from_account_id="admin-456",
@@ -179,8 +180,8 @@ class TestFeedbackService:
# Test with filters
result = FeedbackService.export_feedbacks(
app_id=sample_data["app"].id,
from_source="admin",
rating="dislike",
from_source=FeedbackFromSource.ADMIN,
rating=FeedbackRating.DISLIKE,
has_comment=True,
start_date="2024-01-01",
end_date="2024-12-31",
@@ -293,8 +294,8 @@ class TestFeedbackService:
app_id=sample_data["app"].id,
conversation_id="test-conversation-id",
message_id="test-message-id",
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="回答不够详细,需要更多信息",
from_end_user_id="user-123",
from_account_id=None,

View File

@@ -7,6 +7,7 @@ import pytest
from sqlalchemy.orm import Session
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import (
App,
AppAnnotationHitHistory,
@@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="user",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
content="first",
from_end_user_id=conversation.from_end_user_id,
)
@@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="dislike",
from_source="user",
rating=FeedbackRating.DISLIKE,
from_source=FeedbackFromSource.USER,
content="second",
from_end_user_id=conversation.from_end_user_id,
)
@@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration:
app_id=app.id,
conversation_id=conversation.id,
message_id=first_message.id,
rating="like",
from_source="admin",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.ADMIN,
content="should-be-filtered",
from_account_id=str(uuid.uuid4()),
)

View File

@@ -4,6 +4,7 @@ import pytest
from faker import Faker
from sqlalchemy.orm import Session
from models.enums import FeedbackRating
from models.model import MessageFeedback
from services.app_service import AppService
from services.errors.message import (
@@ -405,7 +406,7 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create feedback
rating = "like"
rating = FeedbackRating.LIKE
content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=rating, content=content
@@ -435,7 +436,11 @@ class TestMessageService:
# Test creating feedback with no user
with pytest.raises(ValueError, match="user cannot be None"):
MessageService.create_feedback(
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=None,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
def test_create_feedback_update_existing(
@@ -452,14 +457,14 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
# Create initial feedback
initial_rating = "like"
initial_rating = FeedbackRating.LIKE
initial_content = fake.text(max_nb_chars=100)
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
)
# Update feedback
updated_rating = "dislike"
updated_rating = FeedbackRating.DISLIKE
updated_content = fake.text(max_nb_chars=100)
updated_feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
@@ -487,7 +492,11 @@ class TestMessageService:
# Create initial feedback
feedback = MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=fake.text(max_nb_chars=100),
)
# Delete feedback by setting rating to None
@@ -538,7 +547,7 @@ class TestMessageService:
app_model=app,
message_id=message.id,
user=account,
rating="like" if i % 2 == 0 else "dislike",
rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE,
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
)
feedbacks.append(feedback)
@@ -568,7 +577,11 @@ class TestMessageService:
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
MessageService.create_feedback(
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
app_model=app,
message_id=message.id,
user=account,
rating=FeedbackRating.LIKE,
content=f"Feedback {i}",
)
# Get feedbacks with pagination

View File

@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
from enums.cloud_plan import CloudPlan
from extensions.ext_redis import redis_client
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.enums import DataSourceType, MessageChainType
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
from models.model import (
App,
AppAnnotationHitHistory,
@@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration:
name="Test conversation",
inputs={},
status="normal",
from_source="api",
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(conversation)
@@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration:
answer_unit_price=Decimal("0.002"),
total_price=Decimal("0.003"),
currency="USD",
from_source="api",
from_source=FeedbackFromSource.USER,
from_account_id=conversation.from_end_user_id,
created_at=created_at,
)
@@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration:
app_id=message.app_id,
conversation_id=message.conversation_id,
message_id=message.id,
rating="like",
from_source="api",
rating=FeedbackRating.LIKE,
from_source=FeedbackFromSource.USER,
from_end_user_id=str(uuid.uuid4()),
)
db_session_with_containers.add(feedback)
@@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration:
type="image",
transfer_method="local_file",
url="http://example.com/test.jpg",
belongs_to="user",
belongs_to=MessageFileBelongsTo.USER,
created_by_role="end_user",
created_by=str(uuid.uuid4()),
)

View File

@@ -36,98 +36,7 @@ def unwrap(func):
class TestTenantListApi:
def test_get_success_saas_path(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={
"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
"t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
},
) as get_plan_bulk_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_falls_back_to_sandbox_for_missing_tenant(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
) as get_plan_bulk_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
get_features_mock.assert_not_called()
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
def test_get_success(self, app):
api = TenantListApi()
method = unwrap(api.get)
@@ -145,41 +54,27 @@ class TestTenantListApi:
)
features = MagicMock()
features.billing.enabled = False
features.billing.subscription.plan = CloudPlan.TEAM
features.billing.enabled = True
features.billing.subscription.plan = CloudPlan.SANDBOX
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
patch(
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
side_effect=RuntimeError("billing down"),
) as get_plan_bulk_mock,
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
) as get_features_mock,
patch("controllers.console.workspace.workspace.logger.exception") as logger_exception_mock,
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
assert get_features_mock.call_count == 2
logger_exception_mock.assert_called_once()
assert len(result["workspaces"]) == 2
assert result["workspaces"][0]["current"] is True
def test_get_billing_disabled_community_path(self, app):
def test_get_billing_disabled(self, app):
api = TenantListApi()
method = unwrap(api.get)
@@ -203,83 +98,15 @@ class TestTenantListApi:
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch(
"controllers.console.workspace.workspace.FeatureService.get_features",
return_value=features,
) as get_features_mock,
),
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
get_features_mock.assert_called_once_with("t1")
def test_get_enterprise_only_skips_feature_service(self, app):
api = TenantListApi()
method = unwrap(api.get)
tenant1 = MagicMock(
id="t1",
name="Tenant 1",
status="active",
created_at=datetime.utcnow(),
)
tenant2 = MagicMock(
id="t2",
name="Tenant 2",
status="active",
created_at=datetime.utcnow(),
)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[tenant1, tenant2],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
assert result["workspaces"][0]["current"] is False
assert result["workspaces"][1]["current"] is True
get_features_mock.assert_not_called()
def test_get_enterprise_only_with_empty_tenants(self, app):
api = TenantListApi()
method = unwrap(api.get)
with (
app.test_request_context("/workspaces"),
patch(
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
),
patch(
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
return_value=[],
),
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
):
result, status = method(api)
assert status == 200
assert result["workspaces"] == []
get_features_mock.assert_not_called()
class TestWorkspaceListApi:

View File

@@ -31,6 +31,7 @@ from controllers.service_api.app.message import (
MessageListQuery,
MessageSuggestedApi,
)
from models.enums import FeedbackRating
from models.model import App, AppMode, EndUser
from services.errors.conversation import ConversationNotExistsError
from services.errors.message import (
@@ -310,7 +311,7 @@ class TestMessageService:
app_model=Mock(spec=App),
message_id=str(uuid.uuid4()),
user=Mock(spec=EndUser),
rating="like",
rating=FeedbackRating.LIKE,
content="Great response!",
)
@@ -326,7 +327,7 @@ class TestMessageService:
app_model=Mock(spec=App),
message_id="invalid_message_id",
user=Mock(spec=EndUser),
rating="like",
rating=FeedbackRating.LIKE,
content=None,
)

View File

@@ -11,6 +11,7 @@ import controllers.trigger.webhook as module
def mock_request():
module.request = types.SimpleNamespace(
method="POST",
path="/triggers/webhook/wh-1",
headers={"x-test": "1"},
args={"a": "b"},
)
@@ -56,14 +57,17 @@ class TestHandleWebhook:
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
@patch.object(module.WebhookService, "trigger_workflow_execution")
@patch.object(module.WebhookService, "generate_webhook_response")
@patch.object(module.WebhookService, "extract_workflow_call_depth", return_value=4)
def test_success(
self,
mock_extract_depth,
mock_generate,
mock_trigger,
mock_extract,
mock_get,
):
mock_get.return_value = (DummyWebhookTrigger(), "workflow", "node_config")
webhook_trigger = DummyWebhookTrigger()
mock_get.return_value = (webhook_trigger, "workflow", "node_config")
mock_extract.return_value = {"input": "x"}
mock_generate.return_value = ({"ok": True}, 200)
@@ -71,7 +75,12 @@ class TestHandleWebhook:
assert status == 200
assert response["ok"] is True
mock_trigger.assert_called_once()
mock_extract_depth.assert_called_once_with(
{"x-test": "1"},
request_method="POST",
request_path=module.request.path,
)
mock_trigger.assert_called_once_with(webhook_trigger, {"input": "x"}, "workflow", call_depth=4)
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
@patch.object(module.WebhookService, "extract_and_validate_webhook_data", side_effect=ValueError("bad"))

View File

@@ -124,6 +124,7 @@ class TestWorkflowBasedAppRunner:
node_id="node-1",
user_inputs={},
graph_runtime_state=graph_runtime_state,
call_depth=3,
node_type_filter_key="iteration_id",
node_type_label="iteration",
)
@@ -131,6 +132,35 @@ class TestWorkflowBasedAppRunner:
assert graph is not None
assert variable_pool is graph_runtime_state.variable_pool
def test_init_graph_passes_call_depth_into_node_factory(self, monkeypatch):
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
runtime_state = GraphRuntimeState(
variable_pool=VariablePool(system_variables=SystemVariable.default()),
start_at=0.0,
)
captured: dict[str, int] = {}
class _FakeNodeFactory:
def __init__(self, *, graph_init_params, graph_runtime_state):
captured["call_depth"] = graph_init_params.call_depth
monkeypatch.setattr("core.app.apps.workflow_app_runner.DifyNodeFactory", _FakeNodeFactory)
monkeypatch.setattr(
"core.app.apps.workflow_app_runner.Graph.init",
lambda **kwargs: SimpleNamespace(),
)
graph = runner._init_graph(
graph_config={"nodes": [{"id": "start", "data": {"type": "start", "version": "1"}}], "edges": []},
graph_runtime_state=runtime_state,
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=4,
)
assert graph is not None
assert captured["call_depth"] == 4
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
published: list[object] = []

View File

@@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch(
workflow=workflow,
single_iteration_run=single_iteration_run,
single_loop_run=single_loop_run,
call_depth=0,
)
init_graph.assert_not_called()
@@ -156,6 +157,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
node_id="loop-node",
user_inputs={},
graph_runtime_state=graph_runtime_state,
call_depth=0,
node_type_filter_key="loop_id",
node_type_label="loop",
)

View File

@@ -2,6 +2,8 @@ import pytest
from configs import dify_config
from core.helper.ssrf_proxy import ssrf_proxy
from dify_graph.call_depth import build_workflow_call_depth_signature
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
from dify_graph.file.file_manager import file_manager
from dify_graph.nodes.http_request import (
BodyData,
@@ -24,7 +26,9 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
secret_key=dify_config.SECRET_KEY,
)
TEST_SECRET_KEY = "test-secret-key"
def test_executor_with_json_body_and_number_variable():
@@ -661,3 +665,320 @@ def test_executor_with_json_body_preserves_numbers_and_strings():
assert executor.json["count"] == 42
assert executor.json["id"] == "abc-123"
def test_executor_propagates_workflow_call_depth_header():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="Depth propagation",
method="get",
url="http://localhost:5001/triggers/webhook/test-webhook",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="X-Test: value",
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
headers = executor._assembling_headers()
assert headers["X-Test"] == "value"
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="get",
path="/triggers/webhook/test-webhook",
depth="3",
)
def test_executor_replaces_lowercase_reserved_headers_for_internal_webhook_target():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="Reserved header replacement",
method="get",
url="http://localhost:5001/triggers/webhook/test-webhook",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers=(
"x-dify-workflow-call-depth: user-value\n"
"x-dify-workflow-call-depth-signature: user-signature\n"
"X-Test: value"
),
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
headers = executor._assembling_headers()
assert headers["X-Test"] == "value"
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="get",
path="/triggers/webhook/test-webhook",
depth="3",
)
assert "x-dify-workflow-call-depth" not in headers
assert "x-dify-workflow-call-depth-signature" not in headers
def test_executor_propagates_workflow_call_depth_with_empty_secret():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="Depth propagation with empty secret",
method="get",
url="http://localhost:5001/triggers/webhook/test-webhook",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="X-Test: value",
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key="",
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
headers = executor._assembling_headers()
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
secret_key="",
method="get",
path="/triggers/webhook/test-webhook",
depth="3",
)
def test_executor_log_masks_internal_depth_headers():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="Depth propagation",
method="get",
url="http://localhost:5001/triggers/webhook/test-webhook",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="X-Test: value",
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
raw_log = executor.to_log()
assert f"{WORKFLOW_CALL_DEPTH_HEADER}: [internal]" in raw_log
assert f"{WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER}: [internal]" in raw_log
assert "X-Dify-Workflow-Call-Depth: 3" not in raw_log
def test_executor_log_masks_reserved_headers_regardless_of_case():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="Reserved header replacement",
method="get",
url="http://localhost:5001/triggers/webhook/test-webhook",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers=(
"x-dify-workflow-call-depth: user-value\n"
"x-dify-workflow-call-depth-signature: user-signature\n"
"X-Test: value"
),
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
raw_log = executor.to_log()
assert "x-dify-workflow-call-depth: [internal]" not in raw_log
assert "x-dify-workflow-call-depth-signature: [internal]" not in raw_log
assert f"{WORKFLOW_CALL_DEPTH_HEADER}: [internal]" in raw_log
assert f"{WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER}: [internal]" in raw_log
assert "user-signature" not in raw_log
assert "user-value" not in raw_log
def test_executor_propagates_workflow_call_depth_to_arbitrary_target_with_secret():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="External target",
method="get",
url="https://api.example.com/data",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="X-Test: value",
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
headers = executor._assembling_headers()
assert headers["X-Test"] == "value"
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="get",
path="/data",
depth="3",
)
def test_executor_normalizes_empty_url_path_when_signing_workflow_call_depth():
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
user_inputs={},
)
node_data = HttpRequestNodeData(
title="External target without explicit path",
method="get",
url="https://api.example.com",
authorization=HttpRequestNodeAuthorization(type="no-auth"),
headers="X-Test: value",
params="",
)
executor = Executor(
node_data=node_data,
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
http_request_config=HttpRequestNodeConfig(
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
secret_key=TEST_SECRET_KEY,
),
variable_pool=variable_pool,
workflow_call_depth=2,
http_client=ssrf_proxy,
file_manager=file_manager,
)
headers = executor._assembling_headers()
assert headers["X-Test"] == "value"
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="get",
path="/",
depth="3",
)

View File

@@ -2,6 +2,8 @@
from unittest.mock import MagicMock, patch
import pytest
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
from core.workflow.workflow_entry import WorkflowEntry
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
@@ -134,3 +136,27 @@ class TestWorkflowEntryRedisChannel:
assert len(events) == 2
assert events[0] == mock_event1
assert events[1] == mock_event2
def test_workflow_entry_rejects_depth_over_limit(self):
mock_graph = MagicMock()
mock_variable_pool = MagicMock(spec=VariablePool)
mock_graph_runtime_state = MagicMock(spec=GraphRuntimeState)
mock_graph_runtime_state.variable_pool = mock_variable_pool
with (
patch("core.workflow.workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH", 1),
pytest.raises(ValueError, match="Max workflow call depth 1 reached."),
):
WorkflowEntry(
tenant_id="test-tenant",
app_id="test-app",
workflow_id="test-workflow",
graph_config={"nodes": [], "edges": []},
graph=mock_graph,
user_id="test-user",
user_from=UserFrom.ACCOUNT,
invoke_from=InvokeFrom.DEBUGGER,
call_depth=2,
variable_pool=mock_variable_pool,
graph_runtime_state=mock_graph_runtime_state,
)

View File

@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
import pytest
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models.enums import FeedbackFromSource, FeedbackRating
from models.model import App, AppMode, EndUser, Message
from services.errors.message import (
FirstMessageNotExistsError,
@@ -820,14 +821,14 @@ class TestMessageServiceFeedback:
app_model=app,
message_id="msg-123",
user=user,
rating="like",
rating=FeedbackRating.LIKE,
content="Good answer",
)
# Assert
assert result.rating == "like"
assert result.rating == FeedbackRating.LIKE
assert result.content == "Good answer"
assert result.from_source == "user"
assert result.from_source == FeedbackFromSource.USER
mock_db.session.add.assert_called_once()
mock_db.session.commit.assert_called_once()
@@ -852,13 +853,13 @@ class TestMessageServiceFeedback:
app_model=app,
message_id="msg-123",
user=user,
rating="dislike",
rating=FeedbackRating.DISLIKE,
content="Bad answer",
)
# Assert
assert result == feedback
assert feedback.rating == "dislike"
assert feedback.rating == FeedbackRating.DISLIKE
assert feedback.content == "Bad answer"
mock_db.session.commit.assert_called_once()

View File

@@ -5,8 +5,13 @@ import pytest
from flask import Flask
from werkzeug.datastructures import FileStorage
from configs import dify_config
from dify_graph.call_depth import build_workflow_call_depth_signature
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
from services.trigger.webhook_service import WebhookService
TEST_SECRET_KEY = "test-secret-key"
class TestWebhookServiceUnit:
"""Unit tests for WebhookService focusing on business logic without database dependencies."""
@@ -559,3 +564,266 @@ class TestWebhookServiceUnit:
result = _prepare_webhook_execution("test_webhook", is_debug=True)
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
def test_extract_workflow_call_depth_defaults_to_zero_for_invalid_values(self):
assert WebhookService.extract_workflow_call_depth({}, request_method="POST", request_path="/webhook") == 0
assert (
WebhookService.extract_workflow_call_depth(
{WORKFLOW_CALL_DEPTH_HEADER: "abc"},
request_method="POST",
request_path="/webhook",
)
== 0
)
assert (
WebhookService.extract_workflow_call_depth(
{WORKFLOW_CALL_DEPTH_HEADER.lower(): "-1"},
request_method="POST",
request_path="/webhook",
)
== 0
)
def test_extract_workflow_call_depth_ignores_unsigned_external_header(self):
assert (
WebhookService.extract_workflow_call_depth(
{WORKFLOW_CALL_DEPTH_HEADER: "5"},
request_method="POST",
request_path="/webhook",
)
== 0
)
def test_extract_workflow_call_depth_honors_signed_internal_header(self):
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
signature = build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="POST",
path="/triggers/webhook/test-webhook",
depth="4",
)
assert (
WebhookService.extract_workflow_call_depth(
{
WORKFLOW_CALL_DEPTH_HEADER: "4",
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
)
== 4
)
def test_extract_workflow_call_depth_accepts_mixed_case_reserved_headers(self):
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
signature = build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="POST",
path="/triggers/webhook/test-webhook",
depth="4",
)
assert (
WebhookService.extract_workflow_call_depth(
{
"X-Dify-Workflow-Call-Depth": "4",
"X-Dify-Workflow-Call-Depth-Signature": signature,
},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
)
== 4
)
def test_extract_workflow_call_depth_rejects_signature_for_other_path(self):
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
wrong_signature = build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="POST",
path="/triggers/webhook/wrong-webhook",
depth="4",
)
assert (
WebhookService.extract_workflow_call_depth(
{
WORKFLOW_CALL_DEPTH_HEADER: "4",
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: wrong_signature,
},
request_method="POST",
request_path="/triggers/webhook/right-webhook",
)
== 0
)
@patch("services.trigger.webhook_service.dify_config")
def test_extract_workflow_call_depth_honors_signature_with_empty_secret(self, mock_config):
mock_config.SECRET_KEY = ""
signature = build_workflow_call_depth_signature(
secret_key="",
method="POST",
path="/triggers/webhook/test-webhook",
depth="4",
)
assert (
WebhookService.extract_workflow_call_depth(
{
WORKFLOW_CALL_DEPTH_HEADER: "4",
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
)
== 4
)
@patch("services.trigger.webhook_service.QuotaType")
@patch("services.trigger.webhook_service.EndUserService")
@patch("services.trigger.webhook_service.AsyncWorkflowService")
@patch("services.trigger.webhook_service.Session")
@patch("services.trigger.webhook_service.db")
def test_trigger_workflow_execution_preserves_header_depth(
self,
mock_db,
mock_session,
mock_async_workflow_service,
mock_end_user_service,
mock_quota_type,
):
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
workflow = MagicMock(id="workflow")
mock_end_user = MagicMock()
mock_end_user_service.get_or_create_end_user_by_type.return_value = mock_end_user
mock_db.engine = MagicMock()
mock_session.return_value.__enter__.return_value = MagicMock()
signature = build_workflow_call_depth_signature(
secret_key=TEST_SECRET_KEY,
method="POST",
path="/triggers/webhook/test-webhook",
depth="4",
)
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
workflow,
call_depth=WebhookService.extract_workflow_call_depth(
{
WORKFLOW_CALL_DEPTH_HEADER: "4",
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
),
)
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
assert trigger_data.call_depth == 4
@patch("services.trigger.webhook_service.QuotaType")
@patch("services.trigger.webhook_service.EndUserService")
@patch("services.trigger.webhook_service.AsyncWorkflowService")
@patch("services.trigger.webhook_service.Session")
@patch("services.trigger.webhook_service.db")
def test_trigger_workflow_execution_ignores_spoofed_external_depth(
self,
mock_db,
mock_session,
mock_async_workflow_service,
mock_end_user_service,
mock_quota_type,
):
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
workflow = MagicMock(id="workflow")
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
mock_db.engine = MagicMock()
mock_session.return_value.__enter__.return_value = MagicMock()
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
workflow,
call_depth=WebhookService.extract_workflow_call_depth(
{WORKFLOW_CALL_DEPTH_HEADER: "5"},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
),
)
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
assert trigger_data.call_depth == 0
@patch("services.trigger.webhook_service.QuotaType")
@patch("services.trigger.webhook_service.EndUserService")
@patch("services.trigger.webhook_service.AsyncWorkflowService")
@patch("services.trigger.webhook_service.Session")
@patch("services.trigger.webhook_service.db")
def test_trigger_workflow_execution_rejects_signature_captured_from_non_webhook_request(
self,
mock_db,
mock_session,
mock_async_workflow_service,
mock_end_user_service,
mock_quota_type,
):
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
workflow = MagicMock(id="workflow")
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
mock_db.engine = MagicMock()
mock_session.return_value.__enter__.return_value = MagicMock()
captured_signature = build_workflow_call_depth_signature(
secret_key=dify_config.SECRET_KEY,
method="GET",
path="/v1/external-endpoint",
depth="5",
)
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
workflow,
call_depth=WebhookService.extract_workflow_call_depth(
{
WORKFLOW_CALL_DEPTH_HEADER: "5",
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: captured_signature,
},
request_method="POST",
request_path="/triggers/webhook/test-webhook",
),
)
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
assert trigger_data.call_depth == 0
@patch("services.trigger.webhook_service.QuotaType")
@patch("services.trigger.webhook_service.EndUserService")
@patch("services.trigger.webhook_service.AsyncWorkflowService")
@patch("services.trigger.webhook_service.Session")
@patch("services.trigger.webhook_service.db")
def test_trigger_workflow_execution_does_not_require_request_context_when_call_depth_is_passed(
self,
mock_db,
mock_session,
mock_async_workflow_service,
mock_end_user_service,
mock_quota_type,
):
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
workflow = MagicMock(id="workflow")
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
mock_db.engine = MagicMock()
mock_session.return_value.__enter__.return_value = MagicMock()
WebhookService.trigger_workflow_execution(
webhook_trigger,
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
workflow,
call_depth=4,
)
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
assert trigger_data.call_depth == 4

View File

@@ -1,3 +1,5 @@
from unittest.mock import MagicMock, patch
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
from services.workflow.entities import WebhookTriggerData
from tasks import async_workflow_tasks
@@ -16,3 +18,41 @@ def test_build_generator_args_sets_skip_flag_for_webhook():
assert args[SKIP_PREPARE_USER_INPUTS_KEY] is True
assert args["inputs"]["webhook_data"]["body"]["foo"] == "bar"
def test_execute_workflow_common_uses_trigger_call_depth():
trigger_data = WebhookTriggerData(
app_id="app",
tenant_id="tenant",
workflow_id="workflow",
root_node_id="node",
inputs={"webhook_data": {"body": {}}},
call_depth=3,
)
trigger_log = MagicMock(
id="log-id",
app_id="app",
workflow_id="workflow",
trigger_data=trigger_data.model_dump_json(),
)
trigger_log_repo = MagicMock()
trigger_log_repo.get_by_id.return_value = trigger_log
session = MagicMock()
session.scalar.side_effect = [MagicMock(), MagicMock()]
session_context = MagicMock()
session_context.__enter__.return_value = session
workflow_generator = MagicMock()
with (
patch.object(async_workflow_tasks.session_factory, "create_session", return_value=session_context),
patch.object(async_workflow_tasks, "SQLAlchemyWorkflowTriggerLogRepository", return_value=trigger_log_repo),
patch.object(async_workflow_tasks, "_get_user", return_value=MagicMock()),
patch.object(async_workflow_tasks, "WorkflowAppGenerator", return_value=workflow_generator),
):
async_workflow_tasks._execute_workflow_common(
async_workflow_tasks.WorkflowTaskData(workflow_trigger_log_id="log-id"),
MagicMock(),
MagicMock(),
)
assert workflow_generator.generate.call_args.kwargs["call_depth"] == 3