mirror of
https://github.com/langgenius/dify.git
synced 2026-03-19 22:52:01 +00:00
Compare commits
4 Commits
refactor/w
...
focal-quok
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5574802631 | ||
|
|
b8cedefd7d | ||
|
|
4ecba5858b | ||
|
|
5b9cb55c45 |
@@ -30,6 +30,7 @@ from fields.raws import FilesContainedField
|
||||
from libs.helper import TimestampField, uuid_value
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
@@ -335,7 +336,7 @@ class MessageFeedbackApi(Resource):
|
||||
if not args.rating and feedback:
|
||||
db.session.delete(feedback)
|
||||
elif args.rating and feedback:
|
||||
feedback.rating = args.rating
|
||||
feedback.rating = FeedbackRating(args.rating)
|
||||
feedback.content = args.content
|
||||
elif not args.rating and not feedback:
|
||||
raise ValueError("rating cannot be None when feedback not exists")
|
||||
@@ -347,9 +348,9 @@ class MessageFeedbackApi(Resource):
|
||||
app_id=app_model.id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating=rating_value,
|
||||
rating=FeedbackRating(rating_value),
|
||||
content=args.content,
|
||||
from_source="admin",
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
from_account_id=current_user.id,
|
||||
)
|
||||
db.session.add(feedback)
|
||||
|
||||
@@ -27,6 +27,7 @@ from fields.message_fields import MessageInfiniteScrollPagination, MessageListIt
|
||||
from libs import helper
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from libs.login import current_account_with_tenant
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@@ -116,7 +117,7 @@ class MessageFeedbackApi(InstalledAppResource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=current_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@@ -7,7 +7,6 @@ from sqlalchemy import select
|
||||
from werkzeug.exceptions import Unauthorized
|
||||
|
||||
import services
|
||||
from configs import dify_config
|
||||
from controllers.common.errors import (
|
||||
FilenameNotExistsError,
|
||||
FileTooLargeError,
|
||||
@@ -30,7 +29,6 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.account import Tenant, TenantStatus
|
||||
from services.account_service import TenantService
|
||||
from services.billing_service import BillingService
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import FeatureService
|
||||
from services.file_service import FileService
|
||||
@@ -110,27 +108,9 @@ class TenantListApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
tenants = TenantService.get_join_tenants(current_user)
|
||||
tenant_dicts = []
|
||||
is_enterprise_only = dify_config.ENTERPRISE_ENABLED and not dify_config.BILLING_ENABLED
|
||||
is_saas = dify_config.EDITION == "CLOUD" and dify_config.BILLING_ENABLED
|
||||
tenant_plans: dict[str, dict] = {}
|
||||
use_legacy_feature_path = not is_enterprise_only and not is_saas
|
||||
|
||||
if is_saas:
|
||||
tenant_ids = [tenant.id for tenant in tenants]
|
||||
if tenant_ids:
|
||||
try:
|
||||
tenant_plans = BillingService.get_plan_bulk(tenant_ids)
|
||||
except Exception:
|
||||
logger.exception("failed to fetch workspace plans in bulk, falling back to legacy feature path")
|
||||
use_legacy_feature_path = True
|
||||
|
||||
for tenant in tenants:
|
||||
plan = CloudPlan.SANDBOX
|
||||
if is_saas and not use_legacy_feature_path:
|
||||
plan = tenant_plans.get(tenant.id, {}).get("plan", CloudPlan.SANDBOX)
|
||||
elif not is_enterprise_only:
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
plan = features.billing.subscription.plan or CloudPlan.SANDBOX
|
||||
features = FeatureService.get_features(tenant.id)
|
||||
|
||||
# Create a dictionary with tenant attributes
|
||||
tenant_dict = {
|
||||
@@ -138,7 +118,7 @@ class TenantListApi(Resource):
|
||||
"name": tenant.name,
|
||||
"status": tenant.status,
|
||||
"created_at": tenant.created_at,
|
||||
"plan": plan,
|
||||
"plan": features.billing.subscription.plan if features.billing.enabled else CloudPlan.SANDBOX,
|
||||
"current": tenant.id == current_tenant_id if current_tenant_id else False,
|
||||
}
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import MessageInfiniteScrollPagination, MessageListItem
|
||||
from libs.helper import UUIDStrOrEmpty
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@@ -116,7 +117,7 @@ class MessageFeedbackApi(Resource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@@ -52,8 +52,19 @@ def handle_webhook(webhook_id: str):
|
||||
if error:
|
||||
return jsonify({"error": "Bad Request", "message": error}), 400
|
||||
|
||||
trigger_call_depth = WebhookService.extract_workflow_call_depth(
|
||||
dict(request.headers),
|
||||
request_method=request.method,
|
||||
request_path=request.path,
|
||||
)
|
||||
|
||||
# Process webhook call (send to Celery)
|
||||
WebhookService.trigger_workflow_execution(webhook_trigger, webhook_data, workflow)
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
webhook_data,
|
||||
workflow,
|
||||
call_depth=trigger_call_depth,
|
||||
)
|
||||
|
||||
# Return configured response
|
||||
response_data, status_code = WebhookService.generate_webhook_response(node_config)
|
||||
|
||||
@@ -25,6 +25,7 @@ from fields.conversation_fields import ResultResponse
|
||||
from fields.message_fields import SuggestedQuestionsResponse, WebMessageInfiniteScrollPagination, WebMessageListItem
|
||||
from libs import helper
|
||||
from libs.helper import uuid_value
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import AppMode
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import MoreLikeThisDisabledError
|
||||
@@ -157,7 +158,7 @@ class MessageFeedbackApi(WebApiResource):
|
||||
app_model=app_model,
|
||||
message_id=message_id,
|
||||
user=end_user,
|
||||
rating=payload.rating,
|
||||
rating=FeedbackRating(payload.rating) if payload.rating else None,
|
||||
content=payload.content,
|
||||
)
|
||||
except MessageNotExistsError:
|
||||
|
||||
@@ -76,7 +76,7 @@ from dify_graph.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.workflow import Workflow
|
||||
|
||||
@@ -939,7 +939,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
type=file["type"],
|
||||
transfer_method=file["transfer_method"],
|
||||
url=file["remote_url"],
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
upload_file_id=file["related_id"],
|
||||
created_by_role=CreatorUserRole.ACCOUNT
|
||||
if message.invoke_from in {InvokeFrom.EXPLORE, InvokeFrom.DEBUGGER}
|
||||
|
||||
@@ -40,7 +40,7 @@ from dify_graph.model_runtime.entities.message_entities import (
|
||||
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey
|
||||
from dify_graph.model_runtime.errors.invoke import InvokeBadRequestError
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, Message, MessageAnnotation, MessageFile
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -419,7 +419,7 @@ class AppRunner:
|
||||
message_id=message_id,
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=f"/files/tools/{tool_file.id}",
|
||||
upload_file_id=tool_file.id,
|
||||
created_by_role=(
|
||||
|
||||
@@ -33,7 +33,7 @@ from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import App, AppMode, AppModelConfig, Conversation, EndUser, Message, MessageFile
|
||||
from services.errors.app_model_config import AppModelConfigBrokenError
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
@@ -225,7 +225,7 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
message_id=message.id,
|
||||
type=file.type,
|
||||
transfer_method=file.transfer_method,
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
url=file.remote_url,
|
||||
upload_file_id=file.related_id,
|
||||
created_by_role=(CreatorUserRole.ACCOUNT if account_id else CreatorUserRole.END_USER),
|
||||
|
||||
@@ -84,6 +84,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
@@ -91,6 +92,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
single_loop_run=self.application_generate_entity.single_loop_run,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
)
|
||||
else:
|
||||
inputs = self.application_generate_entity.inputs
|
||||
@@ -120,6 +122,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
call_depth=self.application_generate_entity.call_depth,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
|
||||
|
||||
@@ -102,6 +102,7 @@ class WorkflowBasedAppRunner:
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
user_from: UserFrom,
|
||||
invoke_from: InvokeFrom,
|
||||
call_depth: int = 0,
|
||||
workflow_id: str = "",
|
||||
tenant_id: str = "",
|
||||
user_id: str = "",
|
||||
@@ -130,7 +131,7 @@ class WorkflowBasedAppRunner:
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
),
|
||||
call_depth=0,
|
||||
call_depth=call_depth,
|
||||
)
|
||||
|
||||
# Use the provided graph_runtime_state for consistent state management
|
||||
@@ -156,6 +157,7 @@ class WorkflowBasedAppRunner:
|
||||
workflow: Workflow,
|
||||
single_iteration_run: Any | None = None,
|
||||
single_loop_run: Any | None = None,
|
||||
call_depth: int = 0,
|
||||
) -> tuple[Graph, VariablePool, GraphRuntimeState]:
|
||||
"""
|
||||
Prepare graph, variable pool, and runtime state for single node execution
|
||||
@@ -189,6 +191,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id=single_iteration_run.node_id,
|
||||
user_inputs=dict(single_iteration_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
call_depth=call_depth,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
@@ -198,6 +201,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id=single_loop_run.node_id,
|
||||
user_inputs=dict(single_loop_run.inputs),
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
call_depth=call_depth,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
@@ -214,6 +218,7 @@ class WorkflowBasedAppRunner:
|
||||
node_id: str,
|
||||
user_inputs: dict[str, Any],
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
call_depth: int,
|
||||
node_type_filter_key: str, # 'iteration_id' or 'loop_id'
|
||||
node_type_label: str = "node", # 'iteration' or 'loop' for error messages
|
||||
) -> tuple[Graph, VariablePool]:
|
||||
@@ -283,7 +288,7 @@ class WorkflowBasedAppRunner:
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
),
|
||||
call_depth=0,
|
||||
call_depth=call_depth,
|
||||
)
|
||||
|
||||
node_factory = DifyNodeFactory(
|
||||
|
||||
@@ -34,6 +34,7 @@ from core.llm_generator.llm_generator import LLMGenerator
|
||||
from core.tools.signature import sign_tool_file
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.model import AppMode, Conversation, MessageAnnotation, MessageFile
|
||||
from services.annotation_service import AppAnnotationService
|
||||
|
||||
@@ -233,7 +234,7 @@ class MessageCycleManager:
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
id=message_file.id,
|
||||
type=message_file.type,
|
||||
belongs_to=message_file.belongs_to or "user",
|
||||
belongs_to=message_file.belongs_to or MessageFileBelongsTo.USER,
|
||||
url=url,
|
||||
)
|
||||
|
||||
|
||||
@@ -34,7 +34,7 @@ from core.tools.workflow_as_tool.tool import WorkflowTool
|
||||
from dify_graph.file import FileType
|
||||
from dify_graph.file.models import FileTransferMethod
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.enums import CreatorUserRole, MessageFileBelongsTo
|
||||
from models.model import Message, MessageFile
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -352,7 +352,7 @@ class ToolEngine:
|
||||
message_id=agent_message.id,
|
||||
type=file_type,
|
||||
transfer_method=FileTransferMethod.TOOL_FILE,
|
||||
belongs_to="assistant",
|
||||
belongs_to=MessageFileBelongsTo.ASSISTANT,
|
||||
url=message.url,
|
||||
upload_file_id=tool_file_id,
|
||||
created_by_role=(
|
||||
|
||||
@@ -282,6 +282,7 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
|
||||
self._llm_credentials_provider, self._llm_model_factory = build_dify_model_access(self._dify_context.tenant_id)
|
||||
|
||||
27
api/dify_graph/call_depth.py
Normal file
27
api/dify_graph/call_depth.py
Normal file
@@ -0,0 +1,27 @@
|
||||
"""Helpers for workflow recursion depth propagation.
|
||||
|
||||
The HTTP request node emits a reserved depth header pair on outbound requests,
|
||||
and ``services.trigger.webhook_service`` validates that pair when a webhook is
|
||||
received. The signature binds the propagated depth to the concrete HTTP method
|
||||
and request path so a depth value captured for one endpoint cannot be replayed
|
||||
verbatim against another path.
|
||||
"""
|
||||
|
||||
import hashlib
|
||||
import hmac
|
||||
|
||||
|
||||
def build_workflow_call_depth_signature(*, secret_key: str, method: str, path: str, depth: str) -> str:
|
||||
"""Build the stable HMAC payload for workflow call-depth propagation.
|
||||
|
||||
Args:
|
||||
secret_key: Shared signing key used by both sender and receiver.
|
||||
method: Outbound or inbound HTTP method.
|
||||
path: Request path that the signature is bound to.
|
||||
depth: Workflow call depth value serialized as a string.
|
||||
|
||||
Returns:
|
||||
Hex-encoded HMAC-SHA256 digest for the method/path/depth tuple.
|
||||
"""
|
||||
payload = f"{method.upper()}:{path}:{depth}"
|
||||
return hmac.new(secret_key.encode("utf-8"), payload.encode("utf-8"), hashlib.sha256).hexdigest()
|
||||
@@ -2,3 +2,8 @@ SYSTEM_VARIABLE_NODE_ID = "sys"
|
||||
ENVIRONMENT_VARIABLE_NODE_ID = "env"
|
||||
CONVERSATION_VARIABLE_NODE_ID = "conversation"
|
||||
RAG_PIPELINE_VARIABLE_NODE_ID = "rag"
|
||||
|
||||
# Reserved for internal workflow-to-workflow HTTP calls. External callers should
|
||||
# not rely on or set this header.
|
||||
WORKFLOW_CALL_DEPTH_HEADER = "X-Dify-Workflow-Call-Depth"
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER = "X-Dify-Workflow-Call-Depth-Signature"
|
||||
|
||||
@@ -12,6 +12,7 @@ def build_http_request_config(
|
||||
max_text_size: int = 1 * 1024 * 1024,
|
||||
ssl_verify: bool = True,
|
||||
ssrf_default_max_retries: int = 3,
|
||||
secret_key: str = "",
|
||||
) -> HttpRequestNodeConfig:
|
||||
return HttpRequestNodeConfig(
|
||||
max_connect_timeout=max_connect_timeout,
|
||||
@@ -21,6 +22,7 @@ def build_http_request_config(
|
||||
max_text_size=max_text_size,
|
||||
ssl_verify=ssl_verify,
|
||||
ssrf_default_max_retries=ssrf_default_max_retries,
|
||||
secret_key=secret_key,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -76,6 +76,7 @@ class HttpRequestNodeConfig:
|
||||
max_text_size: int
|
||||
ssl_verify: bool
|
||||
ssrf_default_max_retries: int
|
||||
secret_key: str = ""
|
||||
|
||||
def default_timeout(self) -> "HttpRequestNodeTimeout":
|
||||
return HttpRequestNodeTimeout(
|
||||
|
||||
@@ -1,3 +1,11 @@
|
||||
"""HTTP request execution helpers for workflow nodes.
|
||||
|
||||
Besides normal request assembly, this executor is responsible for propagating
|
||||
workflow recursion depth across outbound HTTP calls. The reserved call-depth
|
||||
headers are always regenerated from the current node context so user-supplied
|
||||
values cannot override or poison the propagation contract.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import json
|
||||
import secrets
|
||||
@@ -10,6 +18,8 @@ from urllib.parse import urlencode, urlparse
|
||||
import httpx
|
||||
from json_repair import repair_json
|
||||
|
||||
from dify_graph.call_depth import build_workflow_call_depth_signature
|
||||
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
|
||||
from dify_graph.file.enums import FileTransferMethod
|
||||
from dify_graph.runtime import VariablePool
|
||||
from dify_graph.variables.segments import ArrayFileSegment, FileSegment
|
||||
@@ -41,6 +51,8 @@ BODY_TYPE_TO_CONTENT_TYPE = {
|
||||
|
||||
|
||||
class Executor:
|
||||
"""Prepare, execute, and log a workflow HTTP request node invocation."""
|
||||
|
||||
method: Literal[
|
||||
"get",
|
||||
"head",
|
||||
@@ -77,6 +89,7 @@ class Executor:
|
||||
timeout: HttpRequestNodeTimeout,
|
||||
variable_pool: VariablePool,
|
||||
http_request_config: HttpRequestNodeConfig,
|
||||
workflow_call_depth: int = 0,
|
||||
max_retries: int | None = None,
|
||||
ssl_verify: bool | None = None,
|
||||
http_client: HttpClientProtocol,
|
||||
@@ -120,6 +133,7 @@ class Executor:
|
||||
# init template
|
||||
self.variable_pool = variable_pool
|
||||
self.node_data = node_data
|
||||
self.workflow_call_depth = workflow_call_depth
|
||||
self._initialize()
|
||||
|
||||
def _initialize(self):
|
||||
@@ -272,8 +286,33 @@ class Executor:
|
||||
self.data = form_data
|
||||
|
||||
def _assembling_headers(self) -> dict[str, Any]:
|
||||
"""Assemble outbound headers for the request.
|
||||
|
||||
Reserved workflow call-depth headers are removed case-insensitively
|
||||
before the canonical pair is re-added from ``workflow_call_depth``.
|
||||
The signature path mirrors Flask request matching, so URLs without an
|
||||
explicit path are normalized to ``/`` before signing. This keeps
|
||||
propagation deterministic even if a workflow author manually configured
|
||||
colliding headers on the node.
|
||||
"""
|
||||
authorization = deepcopy(self.auth)
|
||||
headers = deepcopy(self.headers) or {}
|
||||
reserved_header_names = {
|
||||
WORKFLOW_CALL_DEPTH_HEADER.lower(),
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower(),
|
||||
}
|
||||
headers = {k: v for k, v in headers.items() if k.lower() not in reserved_header_names}
|
||||
parsed_url = urlparse(self.url)
|
||||
signed_path = parsed_url.path or "/"
|
||||
next_call_depth = str(self.workflow_call_depth + 1)
|
||||
headers[WORKFLOW_CALL_DEPTH_HEADER] = next_call_depth
|
||||
headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] = build_workflow_call_depth_signature(
|
||||
secret_key=self._http_request_config.secret_key,
|
||||
method=self.method,
|
||||
path=signed_path,
|
||||
depth=next_call_depth,
|
||||
)
|
||||
|
||||
if self.auth.type == "api-key":
|
||||
if self.auth.config is None:
|
||||
raise AuthorizationConfigError("self.authorization config is required")
|
||||
@@ -388,6 +427,12 @@ class Executor:
|
||||
return self._validate_and_parse_response(response)
|
||||
|
||||
def to_log(self):
|
||||
"""Render the request in raw HTTP form for node logs.
|
||||
|
||||
Internal workflow call-depth headers and authentication headers are
|
||||
masked so operational logs remain useful without exposing replayable or
|
||||
credential-bearing values.
|
||||
"""
|
||||
url_parts = urlparse(self.url)
|
||||
path = url_parts.path or "/"
|
||||
|
||||
@@ -410,6 +455,12 @@ class Executor:
|
||||
if body.type == "form-data":
|
||||
headers["Content-Type"] = f"multipart/form-data; boundary={boundary}"
|
||||
for k, v in headers.items():
|
||||
if k.lower() in {
|
||||
WORKFLOW_CALL_DEPTH_HEADER.lower(),
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower(),
|
||||
}:
|
||||
raw += f"{k}: [internal]\r\n"
|
||||
continue
|
||||
if self.auth.type == "api-key":
|
||||
authorization_header = "Authorization"
|
||||
if self.auth.config and self.auth.config.header:
|
||||
|
||||
@@ -101,6 +101,7 @@ class HttpRequestNode(Node[HttpRequestNodeData]):
|
||||
timeout=self._get_request_timeout(self.node_data),
|
||||
variable_pool=self.graph_runtime_state.variable_pool,
|
||||
http_request_config=self._http_request_config,
|
||||
workflow_call_depth=self.workflow_call_depth,
|
||||
ssl_verify=self.node_data.ssl_verify,
|
||||
http_client=self._http_client,
|
||||
file_manager=self._file_manager,
|
||||
|
||||
@@ -158,6 +158,13 @@ class FeedbackFromSource(StrEnum):
|
||||
ADMIN = "admin"
|
||||
|
||||
|
||||
class FeedbackRating(StrEnum):
|
||||
"""MessageFeedback rating"""
|
||||
|
||||
LIKE = "like"
|
||||
DISLIKE = "dislike"
|
||||
|
||||
|
||||
class InvokeFrom(StrEnum):
|
||||
"""How a conversation/message was invoked"""
|
||||
|
||||
|
||||
@@ -36,7 +36,10 @@ from .enums import (
|
||||
BannerStatus,
|
||||
ConversationStatus,
|
||||
CreatorUserRole,
|
||||
FeedbackFromSource,
|
||||
FeedbackRating,
|
||||
MessageChainType,
|
||||
MessageFileBelongsTo,
|
||||
MessageStatus,
|
||||
)
|
||||
from .provider_ids import GenericProviderID
|
||||
@@ -1165,7 +1168,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@@ -1176,7 +1179,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "user",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@@ -1191,7 +1194,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "like",
|
||||
MessageFeedback.rating == FeedbackRating.LIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@@ -1202,7 +1205,7 @@ class Conversation(Base):
|
||||
select(func.count(MessageFeedback.id)).where(
|
||||
MessageFeedback.conversation_id == self.id,
|
||||
MessageFeedback.from_source == "admin",
|
||||
MessageFeedback.rating == "dislike",
|
||||
MessageFeedback.rating == FeedbackRating.DISLIKE,
|
||||
)
|
||||
)
|
||||
or 0
|
||||
@@ -1725,8 +1728,8 @@ class MessageFeedback(TypeBase):
|
||||
app_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
conversation_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
message_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
rating: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
from_source: Mapped[str] = mapped_column(String(255), nullable=False)
|
||||
rating: Mapped[FeedbackRating] = mapped_column(EnumText(FeedbackRating, length=255), nullable=False)
|
||||
from_source: Mapped[FeedbackFromSource] = mapped_column(EnumText(FeedbackFromSource, length=255), nullable=False)
|
||||
content: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
from_end_user_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
from_account_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
@@ -1779,7 +1782,9 @@ class MessageFile(TypeBase):
|
||||
)
|
||||
created_by_role: Mapped[CreatorUserRole] = mapped_column(EnumText(CreatorUserRole, length=255), nullable=False)
|
||||
created_by: Mapped[str] = mapped_column(StringUUID, nullable=False)
|
||||
belongs_to: Mapped[Literal["user", "assistant"] | None] = mapped_column(String(255), nullable=True, default=None)
|
||||
belongs_to: Mapped[MessageFileBelongsTo | None] = mapped_column(
|
||||
EnumText(MessageFileBelongsTo, length=255), nullable=True, default=None
|
||||
)
|
||||
url: Mapped[str | None] = mapped_column(LongText, nullable=True, default=None)
|
||||
upload_file_id: Mapped[str | None] = mapped_column(StringUUID, nullable=True, default=None)
|
||||
created_at: Mapped[datetime] = mapped_column(
|
||||
|
||||
@@ -7,6 +7,7 @@ from flask import Response
|
||||
from sqlalchemy import or_
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import Account, App, Conversation, Message, MessageFeedback
|
||||
|
||||
|
||||
@@ -100,7 +101,7 @@ class FeedbackService:
|
||||
"ai_response": message.answer[:500] + "..."
|
||||
if len(message.answer) > 500
|
||||
else message.answer, # Truncate long responses
|
||||
"feedback_rating": "👍" if feedback.rating == "like" else "👎",
|
||||
"feedback_rating": "👍" if feedback.rating == FeedbackRating.LIKE else "👎",
|
||||
"feedback_rating_raw": feedback.rating,
|
||||
"feedback_comment": feedback.content or "",
|
||||
"feedback_source": feedback.from_source,
|
||||
|
||||
@@ -16,6 +16,7 @@ from dify_graph.model_runtime.entities.model_entities import ModelType
|
||||
from extensions.ext_database import db
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models import Account
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, AppModelConfig, EndUser, Message, MessageFeedback
|
||||
from repositories.execution_extra_content_repository import ExecutionExtraContentRepository
|
||||
from repositories.sqlalchemy_execution_extra_content_repository import (
|
||||
@@ -172,7 +173,7 @@ class MessageService:
|
||||
app_model: App,
|
||||
message_id: str,
|
||||
user: Union[Account, EndUser] | None,
|
||||
rating: str | None,
|
||||
rating: FeedbackRating | None,
|
||||
content: str | None,
|
||||
):
|
||||
if not user:
|
||||
@@ -197,7 +198,7 @@ class MessageService:
|
||||
message_id=message.id,
|
||||
rating=rating,
|
||||
content=content,
|
||||
from_source=("user" if isinstance(user, EndUser) else "admin"),
|
||||
from_source=(FeedbackFromSource.USER if isinstance(user, EndUser) else FeedbackFromSource.ADMIN),
|
||||
from_end_user_id=(user.id if isinstance(user, EndUser) else None),
|
||||
from_account_id=(user.id if isinstance(user, Account) else None),
|
||||
)
|
||||
|
||||
@@ -402,6 +402,7 @@ class RagPipelineService:
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
}
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
@@ -435,6 +436,7 @@ class RagPipelineService:
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
default_config = node_class.get_default_config(filters=final_filters or None)
|
||||
if not default_config:
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import hmac
|
||||
import json
|
||||
import logging
|
||||
import mimetypes
|
||||
@@ -23,6 +24,8 @@ from core.workflow.nodes.trigger_webhook.entities import (
|
||||
WebhookData,
|
||||
WebhookParameter,
|
||||
)
|
||||
from dify_graph.call_depth import build_workflow_call_depth_signature
|
||||
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
|
||||
from dify_graph.entities.graph_config import NodeConfigDict
|
||||
from dify_graph.file.models import FileTransferMethod
|
||||
from dify_graph.variables.types import ArrayValidation, SegmentType
|
||||
@@ -57,10 +60,55 @@ class WebhookService:
|
||||
@staticmethod
|
||||
def _sanitize_key(key: str) -> str:
|
||||
"""Normalize external keys (headers/params) to workflow-safe variables."""
|
||||
if not isinstance(key, str):
|
||||
return key
|
||||
return key.replace("-", "_")
|
||||
|
||||
@classmethod
|
||||
def extract_workflow_call_depth(
|
||||
cls,
|
||||
headers: Mapping[str, Any],
|
||||
*,
|
||||
request_method: str,
|
||||
request_path: str,
|
||||
) -> int:
|
||||
"""Extract the reserved workflow recursion depth header.
|
||||
|
||||
The depth header is only trusted when accompanied by a valid HMAC
|
||||
signature for the current request method/path/depth tuple supplied by the
|
||||
caller while a request context is available. Header lookup is normalized
|
||||
case-insensitively so mixed-case spellings still round-trip after headers
|
||||
are materialized into a plain mapping. Invalid, missing, unsigned, or
|
||||
negative values are treated as external requests and therefore fall back
|
||||
to depth 0.
|
||||
"""
|
||||
normalized_headers = {str(key).lower(): value for key, value in headers.items()}
|
||||
|
||||
raw_value = normalized_headers.get(WORKFLOW_CALL_DEPTH_HEADER.lower())
|
||||
if raw_value is None:
|
||||
return 0
|
||||
|
||||
raw_signature = normalized_headers.get(WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER.lower())
|
||||
if raw_signature is None:
|
||||
return 0
|
||||
|
||||
normalized_value = str(raw_value).strip()
|
||||
# The receiver recomputes the signature from the current request context
|
||||
# instead of trusting the sender's path or method directly.
|
||||
expected_signature = build_workflow_call_depth_signature(
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
method=request_method,
|
||||
path=request_path,
|
||||
depth=normalized_value,
|
||||
)
|
||||
if not hmac.compare_digest(str(raw_signature).strip(), expected_signature):
|
||||
return 0
|
||||
|
||||
try:
|
||||
call_depth = int(normalized_value)
|
||||
except (TypeError, ValueError):
|
||||
return 0
|
||||
|
||||
return max(call_depth, 0)
|
||||
|
||||
@classmethod
|
||||
def get_webhook_trigger_and_workflow(
|
||||
cls, webhook_id: str, is_debug: bool = False
|
||||
@@ -744,7 +792,12 @@ class WebhookService:
|
||||
|
||||
@classmethod
|
||||
def trigger_workflow_execution(
|
||||
cls, webhook_trigger: WorkflowWebhookTrigger, webhook_data: dict[str, Any], workflow: Workflow
|
||||
cls,
|
||||
webhook_trigger: WorkflowWebhookTrigger,
|
||||
webhook_data: dict[str, Any],
|
||||
workflow: Workflow,
|
||||
*,
|
||||
call_depth: int = 0,
|
||||
) -> None:
|
||||
"""Trigger workflow execution via AsyncWorkflowService.
|
||||
|
||||
@@ -752,6 +805,8 @@ class WebhookService:
|
||||
webhook_trigger: The webhook trigger object
|
||||
webhook_data: Processed webhook data for workflow inputs
|
||||
workflow: The workflow to execute
|
||||
call_depth: Validated recursion depth derived earlier from the
|
||||
incoming request metadata
|
||||
|
||||
Raises:
|
||||
ValueError: If tenant owner is not found
|
||||
@@ -770,6 +825,7 @@ class WebhookService:
|
||||
root_node_id=webhook_trigger.node_id, # Start from the webhook node
|
||||
inputs=workflow_inputs,
|
||||
tenant_id=webhook_trigger.tenant_id,
|
||||
call_depth=call_depth,
|
||||
)
|
||||
|
||||
end_user = EndUserService.get_or_create_end_user_by_type(
|
||||
|
||||
@@ -26,7 +26,14 @@ class TriggerMetadata(BaseModel):
|
||||
|
||||
|
||||
class TriggerData(BaseModel):
|
||||
"""Base trigger data model for async workflow execution"""
|
||||
"""Base trigger data model for async workflow execution.
|
||||
|
||||
`call_depth` tracks only the current workflow-to-workflow HTTP recursion
|
||||
depth. It starts at 0 for external triggers and increments once per nested
|
||||
webhook-triggered workflow call. For webhook triggers, the value is derived
|
||||
from the reserved depth headers after `WebhookService.extract_workflow_call_depth`
|
||||
validates the signature against the inbound request context.
|
||||
"""
|
||||
|
||||
app_id: str
|
||||
tenant_id: str
|
||||
@@ -34,6 +41,7 @@ class TriggerData(BaseModel):
|
||||
root_node_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
files: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
call_depth: int = 0
|
||||
trigger_type: AppTriggerType
|
||||
trigger_from: WorkflowRunTriggeredFrom
|
||||
trigger_metadata: TriggerMetadata | None = None
|
||||
|
||||
@@ -638,6 +638,7 @@ class WorkflowService:
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
}
|
||||
default_config = node_class.get_default_config(filters=filters)
|
||||
@@ -673,6 +674,7 @@ class WorkflowService:
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
default_config = node_class.get_default_config(filters=resolved_filters or None)
|
||||
if not default_config:
|
||||
|
||||
@@ -164,7 +164,7 @@ def _execute_workflow_common(
|
||||
args=args,
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=False,
|
||||
call_depth=0,
|
||||
call_depth=trigger_data.call_depth,
|
||||
triggered_from=trigger_data.trigger_from,
|
||||
root_node_id=trigger_data.root_node_id,
|
||||
graph_engine_layers=[
|
||||
|
||||
@@ -14,6 +14,7 @@ from controllers.console.app import wraps
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import App, Tenant
|
||||
from models.account import Account, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import AppMode, MessageFeedback
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@@ -77,8 +78,8 @@ class TestFeedbackExportApi:
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content=None,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
from_account_id=None,
|
||||
@@ -90,8 +91,8 @@ class TestFeedbackExportApi:
|
||||
app_id=app_id,
|
||||
conversation_id=conversation_id,
|
||||
message_id=message_id,
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="The response was not helpful",
|
||||
from_end_user_id=None,
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
@@ -277,8 +278,8 @@ class TestFeedbackExportApi:
|
||||
# Verify service was called with correct parameters
|
||||
mock_export_feedbacks.assert_called_once_with(
|
||||
app_id=mock_app_model.id,
|
||||
from_source="user",
|
||||
rating="dislike",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
|
||||
@@ -7,6 +7,7 @@ from sqlalchemy.orm import Session
|
||||
|
||||
from core.plugin.impl.exc import PluginDaemonClientSideError
|
||||
from models import Account
|
||||
from models.enums import MessageFileBelongsTo
|
||||
from models.model import AppModelConfig, Conversation, EndUser, Message, MessageAgentThought
|
||||
from services.account_service import AccountService, TenantService
|
||||
from services.agent_service import AgentService
|
||||
@@ -852,7 +853,7 @@ class TestAgentService:
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="http://example.com/file1.jpg",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
@@ -861,7 +862,7 @@ class TestAgentService:
|
||||
type=FileType.IMAGE,
|
||||
transfer_method=FileTransferMethod.REMOTE_URL,
|
||||
url="http://example.com/file2.png",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role=CreatorUserRole.ACCOUNT,
|
||||
created_by=message.from_account_id,
|
||||
)
|
||||
|
||||
@@ -8,6 +8,7 @@ from unittest import mock
|
||||
import pytest
|
||||
|
||||
from extensions.ext_database import db
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, Conversation, Message
|
||||
from services.feedback_service import FeedbackService
|
||||
|
||||
@@ -47,8 +48,8 @@ class TestFeedbackService:
|
||||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="Great answer!",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
@@ -61,8 +62,8 @@ class TestFeedbackService:
|
||||
app_id=app_id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="Could be more detailed",
|
||||
from_end_user_id=None,
|
||||
from_account_id="admin-456",
|
||||
@@ -179,8 +180,8 @@ class TestFeedbackService:
|
||||
# Test with filters
|
||||
result = FeedbackService.export_feedbacks(
|
||||
app_id=sample_data["app"].id,
|
||||
from_source="admin",
|
||||
rating="dislike",
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
has_comment=True,
|
||||
start_date="2024-01-01",
|
||||
end_date="2024-12-31",
|
||||
@@ -293,8 +294,8 @@ class TestFeedbackService:
|
||||
app_id=sample_data["app"].id,
|
||||
conversation_id="test-conversation-id",
|
||||
message_id="test-message-id",
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="回答不够详细,需要更多信息",
|
||||
from_end_user_id="user-123",
|
||||
from_account_id=None,
|
||||
|
||||
@@ -7,6 +7,7 @@ import pytest
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@@ -172,8 +173,8 @@ class TestAppMessageExportServiceIntegration:
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="first",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
@@ -181,8 +182,8 @@ class TestAppMessageExportServiceIntegration:
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="dislike",
|
||||
from_source="user",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
content="second",
|
||||
from_end_user_id=conversation.from_end_user_id,
|
||||
)
|
||||
@@ -190,8 +191,8 @@ class TestAppMessageExportServiceIntegration:
|
||||
app_id=app.id,
|
||||
conversation_id=conversation.id,
|
||||
message_id=first_message.id,
|
||||
rating="like",
|
||||
from_source="admin",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.ADMIN,
|
||||
content="should-be-filtered",
|
||||
from_account_id=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ import pytest
|
||||
from faker import Faker
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import MessageFeedback
|
||||
from services.app_service import AppService
|
||||
from services.errors.message import (
|
||||
@@ -405,7 +406,7 @@ class TestMessageService:
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create feedback
|
||||
rating = "like"
|
||||
rating = FeedbackRating.LIKE
|
||||
content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=rating, content=content
|
||||
@@ -435,7 +436,11 @@ class TestMessageService:
|
||||
# Test creating feedback with no user
|
||||
with pytest.raises(ValueError, match="user cannot be None"):
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=None, rating="like", content=fake.text(max_nb_chars=100)
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=None,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=fake.text(max_nb_chars=100),
|
||||
)
|
||||
|
||||
def test_create_feedback_update_existing(
|
||||
@@ -452,14 +457,14 @@ class TestMessageService:
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
# Create initial feedback
|
||||
initial_rating = "like"
|
||||
initial_rating = FeedbackRating.LIKE
|
||||
initial_content = fake.text(max_nb_chars=100)
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=initial_rating, content=initial_content
|
||||
)
|
||||
|
||||
# Update feedback
|
||||
updated_rating = "dislike"
|
||||
updated_rating = FeedbackRating.DISLIKE
|
||||
updated_content = fake.text(max_nb_chars=100)
|
||||
updated_feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating=updated_rating, content=updated_content
|
||||
@@ -487,7 +492,11 @@ class TestMessageService:
|
||||
|
||||
# Create initial feedback
|
||||
feedback = MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=fake.text(max_nb_chars=100)
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=fake.text(max_nb_chars=100),
|
||||
)
|
||||
|
||||
# Delete feedback by setting rating to None
|
||||
@@ -538,7 +547,7 @@ class TestMessageService:
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating="like" if i % 2 == 0 else "dislike",
|
||||
rating=FeedbackRating.LIKE if i % 2 == 0 else FeedbackRating.DISLIKE,
|
||||
content=f"Feedback {i}: {fake.text(max_nb_chars=50)}",
|
||||
)
|
||||
feedbacks.append(feedback)
|
||||
@@ -568,7 +577,11 @@ class TestMessageService:
|
||||
message = self._create_test_message(db_session_with_containers, app, conversation, account, fake)
|
||||
|
||||
MessageService.create_feedback(
|
||||
app_model=app, message_id=message.id, user=account, rating="like", content=f"Feedback {i}"
|
||||
app_model=app,
|
||||
message_id=message.id,
|
||||
user=account,
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=f"Feedback {i}",
|
||||
)
|
||||
|
||||
# Get feedbacks with pagination
|
||||
|
||||
@@ -11,7 +11,7 @@ from sqlalchemy.orm import Session
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_redis import redis_client
|
||||
from models.account import Account, Tenant, TenantAccountJoin, TenantAccountRole
|
||||
from models.enums import DataSourceType, MessageChainType
|
||||
from models.enums import DataSourceType, FeedbackFromSource, FeedbackRating, MessageChainType, MessageFileBelongsTo
|
||||
from models.model import (
|
||||
App,
|
||||
AppAnnotationHitHistory,
|
||||
@@ -166,7 +166,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
name="Test conversation",
|
||||
inputs={},
|
||||
status="normal",
|
||||
from_source="api",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(conversation)
|
||||
@@ -196,7 +196,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
answer_unit_price=Decimal("0.002"),
|
||||
total_price=Decimal("0.003"),
|
||||
currency="USD",
|
||||
from_source="api",
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_account_id=conversation.from_end_user_id,
|
||||
created_at=created_at,
|
||||
)
|
||||
@@ -216,8 +216,8 @@ class TestMessagesCleanServiceIntegration:
|
||||
app_id=message.app_id,
|
||||
conversation_id=message.conversation_id,
|
||||
message_id=message.id,
|
||||
rating="like",
|
||||
from_source="api",
|
||||
rating=FeedbackRating.LIKE,
|
||||
from_source=FeedbackFromSource.USER,
|
||||
from_end_user_id=str(uuid.uuid4()),
|
||||
)
|
||||
db_session_with_containers.add(feedback)
|
||||
@@ -249,7 +249,7 @@ class TestMessagesCleanServiceIntegration:
|
||||
type="image",
|
||||
transfer_method="local_file",
|
||||
url="http://example.com/test.jpg",
|
||||
belongs_to="user",
|
||||
belongs_to=MessageFileBelongsTo.USER,
|
||||
created_by_role="end_user",
|
||||
created_by=str(uuid.uuid4()),
|
||||
)
|
||||
|
||||
@@ -36,98 +36,7 @@ def unwrap(func):
|
||||
|
||||
|
||||
class TestTenantListApi:
|
||||
def test_get_success_saas_path(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
return_value={
|
||||
"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0},
|
||||
"t2": {"plan": CloudPlan.PROFESSIONAL, "expiration_date": 0},
|
||||
},
|
||||
) as get_plan_bulk_mock,
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.PROFESSIONAL
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
def test_get_saas_path_falls_back_to_sandbox_for_missing_tenant(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
return_value={"t1": {"plan": CloudPlan.TEAM, "expiration_date": 0}},
|
||||
) as get_plan_bulk_mock,
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
def test_get_saas_path_falls_back_to_legacy_feature_path_on_bulk_error(self, app):
|
||||
def test_get_success(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -145,41 +54,27 @@ class TestTenantListApi:
|
||||
)
|
||||
|
||||
features = MagicMock()
|
||||
features.billing.enabled = False
|
||||
features.billing.subscription.plan = CloudPlan.TEAM
|
||||
features.billing.enabled = True
|
||||
features.billing.subscription.plan = CloudPlan.SANDBOX
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t1")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "CLOUD"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.BillingService.get_plan_bulk",
|
||||
side_effect=RuntimeError("billing down"),
|
||||
) as get_plan_bulk_mock,
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
) as get_features_mock,
|
||||
patch("controllers.console.workspace.workspace.logger.exception") as logger_exception_mock,
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features", return_value=features),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.TEAM
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.TEAM
|
||||
get_plan_bulk_mock.assert_called_once_with(["t1", "t2"])
|
||||
assert get_features_mock.call_count == 2
|
||||
logger_exception_mock.assert_called_once()
|
||||
assert len(result["workspaces"]) == 2
|
||||
assert result["workspaces"][0]["current"] is True
|
||||
|
||||
def test_get_billing_disabled_community_path(self, app):
|
||||
def test_get_billing_disabled(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
@@ -203,83 +98,15 @@ class TestTenantListApi:
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.FeatureService.get_features",
|
||||
return_value=features,
|
||||
) as get_features_mock,
|
||||
),
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
get_features_mock.assert_called_once_with("t1")
|
||||
|
||||
def test_get_enterprise_only_skips_feature_service(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
tenant1 = MagicMock(
|
||||
id="t1",
|
||||
name="Tenant 1",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
tenant2 = MagicMock(
|
||||
id="t2",
|
||||
name="Tenant 2",
|
||||
status="active",
|
||||
created_at=datetime.utcnow(),
|
||||
)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), "t2")
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[tenant1, tenant2],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"][0]["plan"] == CloudPlan.SANDBOX
|
||||
assert result["workspaces"][1]["plan"] == CloudPlan.SANDBOX
|
||||
assert result["workspaces"][0]["current"] is False
|
||||
assert result["workspaces"][1]["current"] is True
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
def test_get_enterprise_only_with_empty_tenants(self, app):
|
||||
api = TenantListApi()
|
||||
method = unwrap(api.get)
|
||||
|
||||
with (
|
||||
app.test_request_context("/workspaces"),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.current_account_with_tenant", return_value=(MagicMock(), None)
|
||||
),
|
||||
patch(
|
||||
"controllers.console.workspace.workspace.TenantService.get_join_tenants",
|
||||
return_value=[],
|
||||
),
|
||||
patch("controllers.console.workspace.workspace.dify_config.ENTERPRISE_ENABLED", True),
|
||||
patch("controllers.console.workspace.workspace.dify_config.BILLING_ENABLED", False),
|
||||
patch("controllers.console.workspace.workspace.dify_config.EDITION", "SELF_HOSTED"),
|
||||
patch("controllers.console.workspace.workspace.FeatureService.get_features") as get_features_mock,
|
||||
):
|
||||
result, status = method(api)
|
||||
|
||||
assert status == 200
|
||||
assert result["workspaces"] == []
|
||||
get_features_mock.assert_not_called()
|
||||
|
||||
|
||||
class TestWorkspaceListApi:
|
||||
|
||||
@@ -31,6 +31,7 @@ from controllers.service_api.app.message import (
|
||||
MessageListQuery,
|
||||
MessageSuggestedApi,
|
||||
)
|
||||
from models.enums import FeedbackRating
|
||||
from models.model import App, AppMode, EndUser
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import (
|
||||
@@ -310,7 +311,7 @@ class TestMessageService:
|
||||
app_model=Mock(spec=App),
|
||||
message_id=str(uuid.uuid4()),
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content="Great response!",
|
||||
)
|
||||
|
||||
@@ -326,7 +327,7 @@ class TestMessageService:
|
||||
app_model=Mock(spec=App),
|
||||
message_id="invalid_message_id",
|
||||
user=Mock(spec=EndUser),
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content=None,
|
||||
)
|
||||
|
||||
|
||||
@@ -11,6 +11,7 @@ import controllers.trigger.webhook as module
|
||||
def mock_request():
|
||||
module.request = types.SimpleNamespace(
|
||||
method="POST",
|
||||
path="/triggers/webhook/wh-1",
|
||||
headers={"x-test": "1"},
|
||||
args={"a": "b"},
|
||||
)
|
||||
@@ -56,14 +57,17 @@ class TestHandleWebhook:
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data")
|
||||
@patch.object(module.WebhookService, "trigger_workflow_execution")
|
||||
@patch.object(module.WebhookService, "generate_webhook_response")
|
||||
@patch.object(module.WebhookService, "extract_workflow_call_depth", return_value=4)
|
||||
def test_success(
|
||||
self,
|
||||
mock_extract_depth,
|
||||
mock_generate,
|
||||
mock_trigger,
|
||||
mock_extract,
|
||||
mock_get,
|
||||
):
|
||||
mock_get.return_value = (DummyWebhookTrigger(), "workflow", "node_config")
|
||||
webhook_trigger = DummyWebhookTrigger()
|
||||
mock_get.return_value = (webhook_trigger, "workflow", "node_config")
|
||||
mock_extract.return_value = {"input": "x"}
|
||||
mock_generate.return_value = ({"ok": True}, 200)
|
||||
|
||||
@@ -71,7 +75,12 @@ class TestHandleWebhook:
|
||||
|
||||
assert status == 200
|
||||
assert response["ok"] is True
|
||||
mock_trigger.assert_called_once()
|
||||
mock_extract_depth.assert_called_once_with(
|
||||
{"x-test": "1"},
|
||||
request_method="POST",
|
||||
request_path=module.request.path,
|
||||
)
|
||||
mock_trigger.assert_called_once_with(webhook_trigger, {"input": "x"}, "workflow", call_depth=4)
|
||||
|
||||
@patch.object(module.WebhookService, "get_webhook_trigger_and_workflow")
|
||||
@patch.object(module.WebhookService, "extract_and_validate_webhook_data", side_effect=ValueError("bad"))
|
||||
|
||||
@@ -124,6 +124,7 @@ class TestWorkflowBasedAppRunner:
|
||||
node_id="node-1",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
call_depth=3,
|
||||
node_type_filter_key="iteration_id",
|
||||
node_type_label="iteration",
|
||||
)
|
||||
@@ -131,6 +132,35 @@ class TestWorkflowBasedAppRunner:
|
||||
assert graph is not None
|
||||
assert variable_pool is graph_runtime_state.variable_pool
|
||||
|
||||
def test_init_graph_passes_call_depth_into_node_factory(self, monkeypatch):
|
||||
runner = WorkflowBasedAppRunner(queue_manager=SimpleNamespace(), app_id="app")
|
||||
runtime_state = GraphRuntimeState(
|
||||
variable_pool=VariablePool(system_variables=SystemVariable.default()),
|
||||
start_at=0.0,
|
||||
)
|
||||
captured: dict[str, int] = {}
|
||||
|
||||
class _FakeNodeFactory:
|
||||
def __init__(self, *, graph_init_params, graph_runtime_state):
|
||||
captured["call_depth"] = graph_init_params.call_depth
|
||||
|
||||
monkeypatch.setattr("core.app.apps.workflow_app_runner.DifyNodeFactory", _FakeNodeFactory)
|
||||
monkeypatch.setattr(
|
||||
"core.app.apps.workflow_app_runner.Graph.init",
|
||||
lambda **kwargs: SimpleNamespace(),
|
||||
)
|
||||
|
||||
graph = runner._init_graph(
|
||||
graph_config={"nodes": [{"id": "start", "data": {"type": "start", "version": "1"}}], "edges": []},
|
||||
graph_runtime_state=runtime_state,
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=4,
|
||||
)
|
||||
|
||||
assert graph is not None
|
||||
assert captured["call_depth"] == 4
|
||||
|
||||
def test_handle_graph_run_events_and_pause_notifications(self, monkeypatch):
|
||||
published: list[object] = []
|
||||
|
||||
|
||||
@@ -100,6 +100,7 @@ def test_run_uses_single_node_execution_branch(
|
||||
workflow=workflow,
|
||||
single_iteration_run=single_iteration_run,
|
||||
single_loop_run=single_loop_run,
|
||||
call_depth=0,
|
||||
)
|
||||
init_graph.assert_not_called()
|
||||
|
||||
@@ -156,6 +157,7 @@ def test_single_node_run_validates_target_node_config(monkeypatch) -> None:
|
||||
node_id="loop-node",
|
||||
user_inputs={},
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
call_depth=0,
|
||||
node_type_filter_key="loop_id",
|
||||
node_type_label="loop",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@ import pytest
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from dify_graph.call_depth import build_workflow_call_depth_signature
|
||||
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
|
||||
from dify_graph.file.file_manager import file_manager
|
||||
from dify_graph.nodes.http_request import (
|
||||
BodyData,
|
||||
@@ -24,7 +26,9 @@ HTTP_REQUEST_CONFIG = HttpRequestNodeConfig(
|
||||
max_text_size=dify_config.HTTP_REQUEST_NODE_MAX_TEXT_SIZE,
|
||||
ssl_verify=dify_config.HTTP_REQUEST_NODE_SSL_VERIFY,
|
||||
ssrf_default_max_retries=dify_config.SSRF_DEFAULT_MAX_RETRIES,
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
)
|
||||
TEST_SECRET_KEY = "test-secret-key"
|
||||
|
||||
|
||||
def test_executor_with_json_body_and_number_variable():
|
||||
@@ -661,3 +665,320 @@ def test_executor_with_json_body_preserves_numbers_and_strings():
|
||||
|
||||
assert executor.json["count"] == 42
|
||||
assert executor.json["id"] == "abc-123"
|
||||
|
||||
|
||||
def test_executor_propagates_workflow_call_depth_header():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Depth propagation",
|
||||
method="get",
|
||||
url="http://localhost:5001/triggers/webhook/test-webhook",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="X-Test: value",
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
assert headers["X-Test"] == "value"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="get",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="3",
|
||||
)
|
||||
|
||||
|
||||
def test_executor_replaces_lowercase_reserved_headers_for_internal_webhook_target():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Reserved header replacement",
|
||||
method="get",
|
||||
url="http://localhost:5001/triggers/webhook/test-webhook",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers=(
|
||||
"x-dify-workflow-call-depth: user-value\n"
|
||||
"x-dify-workflow-call-depth-signature: user-signature\n"
|
||||
"X-Test: value"
|
||||
),
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
assert headers["X-Test"] == "value"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="get",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="3",
|
||||
)
|
||||
assert "x-dify-workflow-call-depth" not in headers
|
||||
assert "x-dify-workflow-call-depth-signature" not in headers
|
||||
|
||||
|
||||
def test_executor_propagates_workflow_call_depth_with_empty_secret():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Depth propagation with empty secret",
|
||||
method="get",
|
||||
url="http://localhost:5001/triggers/webhook/test-webhook",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="X-Test: value",
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key="",
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
|
||||
secret_key="",
|
||||
method="get",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="3",
|
||||
)
|
||||
|
||||
|
||||
def test_executor_log_masks_internal_depth_headers():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Depth propagation",
|
||||
method="get",
|
||||
url="http://localhost:5001/triggers/webhook/test-webhook",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="X-Test: value",
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
raw_log = executor.to_log()
|
||||
|
||||
assert f"{WORKFLOW_CALL_DEPTH_HEADER}: [internal]" in raw_log
|
||||
assert f"{WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER}: [internal]" in raw_log
|
||||
assert "X-Dify-Workflow-Call-Depth: 3" not in raw_log
|
||||
|
||||
|
||||
def test_executor_log_masks_reserved_headers_regardless_of_case():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="Reserved header replacement",
|
||||
method="get",
|
||||
url="http://localhost:5001/triggers/webhook/test-webhook",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers=(
|
||||
"x-dify-workflow-call-depth: user-value\n"
|
||||
"x-dify-workflow-call-depth-signature: user-signature\n"
|
||||
"X-Test: value"
|
||||
),
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
raw_log = executor.to_log()
|
||||
|
||||
assert "x-dify-workflow-call-depth: [internal]" not in raw_log
|
||||
assert "x-dify-workflow-call-depth-signature: [internal]" not in raw_log
|
||||
assert f"{WORKFLOW_CALL_DEPTH_HEADER}: [internal]" in raw_log
|
||||
assert f"{WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER}: [internal]" in raw_log
|
||||
assert "user-signature" not in raw_log
|
||||
assert "user-value" not in raw_log
|
||||
|
||||
|
||||
def test_executor_propagates_workflow_call_depth_to_arbitrary_target_with_secret():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="External target",
|
||||
method="get",
|
||||
url="https://api.example.com/data",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="X-Test: value",
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
assert headers["X-Test"] == "value"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="get",
|
||||
path="/data",
|
||||
depth="3",
|
||||
)
|
||||
|
||||
|
||||
def test_executor_normalizes_empty_url_path_when_signing_workflow_call_depth():
|
||||
variable_pool = VariablePool(
|
||||
system_variables=SystemVariable.default(),
|
||||
user_inputs={},
|
||||
)
|
||||
node_data = HttpRequestNodeData(
|
||||
title="External target without explicit path",
|
||||
method="get",
|
||||
url="https://api.example.com",
|
||||
authorization=HttpRequestNodeAuthorization(type="no-auth"),
|
||||
headers="X-Test: value",
|
||||
params="",
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
node_data=node_data,
|
||||
timeout=HttpRequestNodeTimeout(connect=10, read=30, write=30),
|
||||
http_request_config=HttpRequestNodeConfig(
|
||||
max_connect_timeout=HTTP_REQUEST_CONFIG.max_connect_timeout,
|
||||
max_read_timeout=HTTP_REQUEST_CONFIG.max_read_timeout,
|
||||
max_write_timeout=HTTP_REQUEST_CONFIG.max_write_timeout,
|
||||
max_binary_size=HTTP_REQUEST_CONFIG.max_binary_size,
|
||||
max_text_size=HTTP_REQUEST_CONFIG.max_text_size,
|
||||
ssl_verify=HTTP_REQUEST_CONFIG.ssl_verify,
|
||||
ssrf_default_max_retries=HTTP_REQUEST_CONFIG.ssrf_default_max_retries,
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
),
|
||||
variable_pool=variable_pool,
|
||||
workflow_call_depth=2,
|
||||
http_client=ssrf_proxy,
|
||||
file_manager=file_manager,
|
||||
)
|
||||
|
||||
headers = executor._assembling_headers()
|
||||
|
||||
assert headers["X-Test"] == "value"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_HEADER] == "3"
|
||||
assert headers[WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER] == build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="get",
|
||||
path="/",
|
||||
depth="3",
|
||||
)
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, UserFrom
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from dify_graph.graph_engine.command_channels.redis_channel import RedisChannel
|
||||
@@ -134,3 +136,27 @@ class TestWorkflowEntryRedisChannel:
|
||||
assert len(events) == 2
|
||||
assert events[0] == mock_event1
|
||||
assert events[1] == mock_event2
|
||||
|
||||
def test_workflow_entry_rejects_depth_over_limit(self):
|
||||
mock_graph = MagicMock()
|
||||
mock_variable_pool = MagicMock(spec=VariablePool)
|
||||
mock_graph_runtime_state = MagicMock(spec=GraphRuntimeState)
|
||||
mock_graph_runtime_state.variable_pool = mock_variable_pool
|
||||
|
||||
with (
|
||||
patch("core.workflow.workflow_entry.dify_config.WORKFLOW_CALL_MAX_DEPTH", 1),
|
||||
pytest.raises(ValueError, match="Max workflow call depth 1 reached."),
|
||||
):
|
||||
WorkflowEntry(
|
||||
tenant_id="test-tenant",
|
||||
app_id="test-app",
|
||||
workflow_id="test-workflow",
|
||||
graph_config={"nodes": [], "edges": []},
|
||||
graph=mock_graph,
|
||||
user_id="test-user",
|
||||
user_from=UserFrom.ACCOUNT,
|
||||
invoke_from=InvokeFrom.DEBUGGER,
|
||||
call_depth=2,
|
||||
variable_pool=mock_variable_pool,
|
||||
graph_runtime_state=mock_graph_runtime_state,
|
||||
)
|
||||
|
||||
@@ -4,6 +4,7 @@ from unittest.mock import MagicMock, patch
|
||||
import pytest
|
||||
|
||||
from libs.infinite_scroll_pagination import InfiniteScrollPagination
|
||||
from models.enums import FeedbackFromSource, FeedbackRating
|
||||
from models.model import App, AppMode, EndUser, Message
|
||||
from services.errors.message import (
|
||||
FirstMessageNotExistsError,
|
||||
@@ -820,14 +821,14 @@ class TestMessageServiceFeedback:
|
||||
app_model=app,
|
||||
message_id="msg-123",
|
||||
user=user,
|
||||
rating="like",
|
||||
rating=FeedbackRating.LIKE,
|
||||
content="Good answer",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result.rating == "like"
|
||||
assert result.rating == FeedbackRating.LIKE
|
||||
assert result.content == "Good answer"
|
||||
assert result.from_source == "user"
|
||||
assert result.from_source == FeedbackFromSource.USER
|
||||
mock_db.session.add.assert_called_once()
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
@@ -852,13 +853,13 @@ class TestMessageServiceFeedback:
|
||||
app_model=app,
|
||||
message_id="msg-123",
|
||||
user=user,
|
||||
rating="dislike",
|
||||
rating=FeedbackRating.DISLIKE,
|
||||
content="Bad answer",
|
||||
)
|
||||
|
||||
# Assert
|
||||
assert result == feedback
|
||||
assert feedback.rating == "dislike"
|
||||
assert feedback.rating == FeedbackRating.DISLIKE
|
||||
assert feedback.content == "Bad answer"
|
||||
mock_db.session.commit.assert_called_once()
|
||||
|
||||
|
||||
@@ -5,8 +5,13 @@ import pytest
|
||||
from flask import Flask
|
||||
from werkzeug.datastructures import FileStorage
|
||||
|
||||
from configs import dify_config
|
||||
from dify_graph.call_depth import build_workflow_call_depth_signature
|
||||
from dify_graph.constants import WORKFLOW_CALL_DEPTH_HEADER, WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER
|
||||
from services.trigger.webhook_service import WebhookService
|
||||
|
||||
TEST_SECRET_KEY = "test-secret-key"
|
||||
|
||||
|
||||
class TestWebhookServiceUnit:
|
||||
"""Unit tests for WebhookService focusing on business logic without database dependencies."""
|
||||
@@ -559,3 +564,266 @@ class TestWebhookServiceUnit:
|
||||
|
||||
result = _prepare_webhook_execution("test_webhook", is_debug=True)
|
||||
assert result == (mock_trigger, mock_workflow, mock_config, mock_data, None)
|
||||
|
||||
def test_extract_workflow_call_depth_defaults_to_zero_for_invalid_values(self):
|
||||
assert WebhookService.extract_workflow_call_depth({}, request_method="POST", request_path="/webhook") == 0
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{WORKFLOW_CALL_DEPTH_HEADER: "abc"},
|
||||
request_method="POST",
|
||||
request_path="/webhook",
|
||||
)
|
||||
== 0
|
||||
)
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{WORKFLOW_CALL_DEPTH_HEADER.lower(): "-1"},
|
||||
request_method="POST",
|
||||
request_path="/webhook",
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_extract_workflow_call_depth_ignores_unsigned_external_header(self):
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{WORKFLOW_CALL_DEPTH_HEADER: "5"},
|
||||
request_method="POST",
|
||||
request_path="/webhook",
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
def test_extract_workflow_call_depth_honors_signed_internal_header(self):
|
||||
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
|
||||
signature = build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="POST",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="4",
|
||||
)
|
||||
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
WORKFLOW_CALL_DEPTH_HEADER: "4",
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
)
|
||||
== 4
|
||||
)
|
||||
|
||||
def test_extract_workflow_call_depth_accepts_mixed_case_reserved_headers(self):
|
||||
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
|
||||
signature = build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="POST",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="4",
|
||||
)
|
||||
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
"X-Dify-Workflow-Call-Depth": "4",
|
||||
"X-Dify-Workflow-Call-Depth-Signature": signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
)
|
||||
== 4
|
||||
)
|
||||
|
||||
def test_extract_workflow_call_depth_rejects_signature_for_other_path(self):
|
||||
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
|
||||
wrong_signature = build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="POST",
|
||||
path="/triggers/webhook/wrong-webhook",
|
||||
depth="4",
|
||||
)
|
||||
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
WORKFLOW_CALL_DEPTH_HEADER: "4",
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: wrong_signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/right-webhook",
|
||||
)
|
||||
== 0
|
||||
)
|
||||
|
||||
@patch("services.trigger.webhook_service.dify_config")
|
||||
def test_extract_workflow_call_depth_honors_signature_with_empty_secret(self, mock_config):
|
||||
mock_config.SECRET_KEY = ""
|
||||
|
||||
signature = build_workflow_call_depth_signature(
|
||||
secret_key="",
|
||||
method="POST",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="4",
|
||||
)
|
||||
|
||||
assert (
|
||||
WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
WORKFLOW_CALL_DEPTH_HEADER: "4",
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
)
|
||||
== 4
|
||||
)
|
||||
|
||||
@patch("services.trigger.webhook_service.QuotaType")
|
||||
@patch("services.trigger.webhook_service.EndUserService")
|
||||
@patch("services.trigger.webhook_service.AsyncWorkflowService")
|
||||
@patch("services.trigger.webhook_service.Session")
|
||||
@patch("services.trigger.webhook_service.db")
|
||||
def test_trigger_workflow_execution_preserves_header_depth(
|
||||
self,
|
||||
mock_db,
|
||||
mock_session,
|
||||
mock_async_workflow_service,
|
||||
mock_end_user_service,
|
||||
mock_quota_type,
|
||||
):
|
||||
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
|
||||
workflow = MagicMock(id="workflow")
|
||||
mock_end_user = MagicMock()
|
||||
mock_end_user_service.get_or_create_end_user_by_type.return_value = mock_end_user
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
signature = build_workflow_call_depth_signature(
|
||||
secret_key=TEST_SECRET_KEY,
|
||||
method="POST",
|
||||
path="/triggers/webhook/test-webhook",
|
||||
depth="4",
|
||||
)
|
||||
|
||||
with patch("services.trigger.webhook_service.dify_config.SECRET_KEY", TEST_SECRET_KEY):
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
|
||||
workflow,
|
||||
call_depth=WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
WORKFLOW_CALL_DEPTH_HEADER: "4",
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
),
|
||||
)
|
||||
|
||||
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
|
||||
assert trigger_data.call_depth == 4
|
||||
|
||||
@patch("services.trigger.webhook_service.QuotaType")
|
||||
@patch("services.trigger.webhook_service.EndUserService")
|
||||
@patch("services.trigger.webhook_service.AsyncWorkflowService")
|
||||
@patch("services.trigger.webhook_service.Session")
|
||||
@patch("services.trigger.webhook_service.db")
|
||||
def test_trigger_workflow_execution_ignores_spoofed_external_depth(
|
||||
self,
|
||||
mock_db,
|
||||
mock_session,
|
||||
mock_async_workflow_service,
|
||||
mock_end_user_service,
|
||||
mock_quota_type,
|
||||
):
|
||||
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
|
||||
workflow = MagicMock(id="workflow")
|
||||
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
|
||||
workflow,
|
||||
call_depth=WebhookService.extract_workflow_call_depth(
|
||||
{WORKFLOW_CALL_DEPTH_HEADER: "5"},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
),
|
||||
)
|
||||
|
||||
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
|
||||
assert trigger_data.call_depth == 0
|
||||
|
||||
@patch("services.trigger.webhook_service.QuotaType")
|
||||
@patch("services.trigger.webhook_service.EndUserService")
|
||||
@patch("services.trigger.webhook_service.AsyncWorkflowService")
|
||||
@patch("services.trigger.webhook_service.Session")
|
||||
@patch("services.trigger.webhook_service.db")
|
||||
def test_trigger_workflow_execution_rejects_signature_captured_from_non_webhook_request(
|
||||
self,
|
||||
mock_db,
|
||||
mock_session,
|
||||
mock_async_workflow_service,
|
||||
mock_end_user_service,
|
||||
mock_quota_type,
|
||||
):
|
||||
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
|
||||
workflow = MagicMock(id="workflow")
|
||||
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
captured_signature = build_workflow_call_depth_signature(
|
||||
secret_key=dify_config.SECRET_KEY,
|
||||
method="GET",
|
||||
path="/v1/external-endpoint",
|
||||
depth="5",
|
||||
)
|
||||
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
|
||||
workflow,
|
||||
call_depth=WebhookService.extract_workflow_call_depth(
|
||||
{
|
||||
WORKFLOW_CALL_DEPTH_HEADER: "5",
|
||||
WORKFLOW_CALL_DEPTH_SIGNATURE_HEADER: captured_signature,
|
||||
},
|
||||
request_method="POST",
|
||||
request_path="/triggers/webhook/test-webhook",
|
||||
),
|
||||
)
|
||||
|
||||
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
|
||||
assert trigger_data.call_depth == 0
|
||||
|
||||
@patch("services.trigger.webhook_service.QuotaType")
|
||||
@patch("services.trigger.webhook_service.EndUserService")
|
||||
@patch("services.trigger.webhook_service.AsyncWorkflowService")
|
||||
@patch("services.trigger.webhook_service.Session")
|
||||
@patch("services.trigger.webhook_service.db")
|
||||
def test_trigger_workflow_execution_does_not_require_request_context_when_call_depth_is_passed(
|
||||
self,
|
||||
mock_db,
|
||||
mock_session,
|
||||
mock_async_workflow_service,
|
||||
mock_end_user_service,
|
||||
mock_quota_type,
|
||||
):
|
||||
webhook_trigger = MagicMock(app_id="app", tenant_id="tenant", node_id="root", webhook_id="webhook")
|
||||
workflow = MagicMock(id="workflow")
|
||||
mock_end_user_service.get_or_create_end_user_by_type.return_value = MagicMock()
|
||||
mock_db.engine = MagicMock()
|
||||
mock_session.return_value.__enter__.return_value = MagicMock()
|
||||
|
||||
WebhookService.trigger_workflow_execution(
|
||||
webhook_trigger,
|
||||
{"method": "POST", "headers": {}, "query_params": {}, "body": {}, "files": {}},
|
||||
workflow,
|
||||
call_depth=4,
|
||||
)
|
||||
|
||||
trigger_data = mock_async_workflow_service.trigger_workflow_async.call_args.args[2]
|
||||
assert trigger_data.call_depth == 4
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from core.app.apps.workflow.app_generator import SKIP_PREPARE_USER_INPUTS_KEY
|
||||
from services.workflow.entities import WebhookTriggerData
|
||||
from tasks import async_workflow_tasks
|
||||
@@ -16,3 +18,41 @@ def test_build_generator_args_sets_skip_flag_for_webhook():
|
||||
|
||||
assert args[SKIP_PREPARE_USER_INPUTS_KEY] is True
|
||||
assert args["inputs"]["webhook_data"]["body"]["foo"] == "bar"
|
||||
|
||||
|
||||
def test_execute_workflow_common_uses_trigger_call_depth():
|
||||
trigger_data = WebhookTriggerData(
|
||||
app_id="app",
|
||||
tenant_id="tenant",
|
||||
workflow_id="workflow",
|
||||
root_node_id="node",
|
||||
inputs={"webhook_data": {"body": {}}},
|
||||
call_depth=3,
|
||||
)
|
||||
trigger_log = MagicMock(
|
||||
id="log-id",
|
||||
app_id="app",
|
||||
workflow_id="workflow",
|
||||
trigger_data=trigger_data.model_dump_json(),
|
||||
)
|
||||
trigger_log_repo = MagicMock()
|
||||
trigger_log_repo.get_by_id.return_value = trigger_log
|
||||
session = MagicMock()
|
||||
session.scalar.side_effect = [MagicMock(), MagicMock()]
|
||||
session_context = MagicMock()
|
||||
session_context.__enter__.return_value = session
|
||||
workflow_generator = MagicMock()
|
||||
|
||||
with (
|
||||
patch.object(async_workflow_tasks.session_factory, "create_session", return_value=session_context),
|
||||
patch.object(async_workflow_tasks, "SQLAlchemyWorkflowTriggerLogRepository", return_value=trigger_log_repo),
|
||||
patch.object(async_workflow_tasks, "_get_user", return_value=MagicMock()),
|
||||
patch.object(async_workflow_tasks, "WorkflowAppGenerator", return_value=workflow_generator),
|
||||
):
|
||||
async_workflow_tasks._execute_workflow_common(
|
||||
async_workflow_tasks.WorkflowTaskData(workflow_trigger_log_id="log-id"),
|
||||
MagicMock(),
|
||||
MagicMock(),
|
||||
)
|
||||
|
||||
assert workflow_generator.generate.call_args.kwargs["call_depth"] == 3
|
||||
|
||||
Reference in New Issue
Block a user