Compare commits

...

3 Commits

Author SHA1 Message Date
yunlu.wen
3060a19dd6 resolve conflict
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
2026-03-27 16:57:14 +08:00
Xiyuan Chen
fb2eb3aadf feat: enterprise otel exporter (#33138)
Co-authored-by: QuantumGhost <obelisk.reg+git@gmail.com>
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-27 16:56:03 +08:00
Wu Tianwei
d1f6edd7ab fix(prompt-editor): fix unexpected blur effect in prompt editor (#34114)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-03-26 14:44:40 +08:00
60 changed files with 9891 additions and 208 deletions

View File

@@ -143,6 +143,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_enterprise_telemetry,
ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
@@ -193,6 +194,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_fastopenapi,
ext_otel,
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
]

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
# Enterprise telemetry configs
EnterpriseTelemetryConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file

View File

@@ -22,3 +22,52 @@ class EnterpriseFeatureConfig(BaseSettings):
ENTERPRISE_REQUEST_TIMEOUT: int = Field(
ge=1, description="Maximum timeout in seconds for enterprise requests", default=5
)
class EnterpriseTelemetryConfig(BaseSettings):
    """
    Configuration for enterprise telemetry.

    Groups the ENTERPRISE_* environment variables controlling export of
    traces to an enterprise OTEL collector. Per the field description,
    collection is only active when both ENTERPRISE_TELEMETRY_ENABLED and
    ENTERPRISE_ENABLED are true.
    """

    # Master switch for enterprise telemetry collection.
    ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
        description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
        default=False,
    )

    # Where to send OTLP data; empty string means "not configured".
    ENTERPRISE_OTLP_ENDPOINT: str = Field(
        description="Enterprise OTEL collector endpoint.",
        default="",
    )

    # Comma-separated key=value pairs attached as export headers.
    ENTERPRISE_OTLP_HEADERS: str = Field(
        description="Auth headers for OTLP export (key=value,key2=value2).",
        default="",
    )

    # Transport protocol; NOTE(review): value is not validated here —
    # presumably the exporter extension interprets 'http' vs 'grpc'.
    ENTERPRISE_OTLP_PROTOCOL: str = Field(
        description="OTLP protocol: 'http' or 'grpc' (default: http).",
        default="http",
    )

    # Bearer token used by the exporter for authentication.
    ENTERPRISE_OTLP_API_KEY: str = Field(
        description="Bearer token for enterprise OTLP export authentication.",
        default="",
    )

    ENTERPRISE_INCLUDE_CONTENT: bool = Field(
        description="Include input/output content in traces (privacy toggle).",
        # Setting the default value to False to avoid accidentally log PII data in traces.
        default=False,
    )

    # Reported as the OTEL resource service name.
    ENTERPRISE_SERVICE_NAME: str = Field(
        description="Service name for OTEL resource.",
        default="dify",
    )

    # Fraction of traces to sample; bounded to [0.0, 1.0] by pydantic.
    ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
        description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
        default=1.0,
        ge=0.0,
        le=1.0,
    )

View File

@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class BaseTraceInfo(BaseModel):
message_id: str | None = None
message_data: Any | None = None
inputs: Union[str, dict[str, Any], list] | None = None
outputs: Union[str, dict[str, Any], list] | None = None
inputs: Union[str, dict[str, Any], list[Any]] | None = None
outputs: Union[str, dict[str, Any], list[Any]] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
metadata: dict[str, Any]
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
if v is None:
return None
if isinstance(v, str | dict | list):
@@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@property
def resolved_trace_id(self) -> str | None:
    """Return the effective trace id, falling back through known sources.

    Order of preference:
    1. the explicit ``trace_id`` (e.g. from an X-Trace-Id header);
    2. ``workflow_run_id``, when the concrete trace type carries one;
    3. ``message_id`` as the last resort (stringified), else ``None``.
    """
    explicit = self.trace_id
    if explicit:
        return explicit
    # Only workflow-related trace types define workflow_run_id.
    run_id = getattr(self, "workflow_run_id", None)
    if run_id:
        return run_id
    if self.message_id:
        return str(self.message_id)
    return None
@property
def resolved_parent_context(self) -> tuple[str | None, str | None]:
    """Resolve cross-workflow parent linking from metadata.

    Extracts typed parent IDs from the untyped ``parent_trace_context``
    metadata dict (set by tool_node when invoking nested workflows).

    Returns:
        (trace_correlation_override, parent_span_id_source) where
        trace_correlation_override is the outer workflow_run_id and
        parent_span_id_source is the outer node_execution_id.
    """
    raw = self.metadata.get("parent_trace_context")
    if not isinstance(raw, dict):
        return None, None

    def _as_str(value: Any) -> str | None:
        # Defensive: the dict is untyped, so only pass through real strings.
        return value if isinstance(value, str) else None

    return (
        _as_str(raw.get("parent_workflow_run_id")),
        _as_str(raw.get("parent_node_execution_id")),
    )
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
@@ -48,7 +90,10 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_version: str
error: str | None = None
total_tokens: int
prompt_tokens: int | None = None
completion_tokens: int | None = None
file_list: list[str]
invoked_by: str | None = None
query: str
metadata: dict[str, Any]
@@ -59,7 +104,7 @@ class MessageTraceInfo(BaseTraceInfo):
answer_tokens: int
total_tokens: int
error: str | None = None
file_list: Union[str, dict[str, Any], list] | None = None
file_list: Union[str, dict[str, Any], list[Any]] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
@@ -106,7 +151,7 @@ class ToolTraceInfo(BaseTraceInfo):
tool_config: dict[str, Any]
time_cost: Union[int, float]
tool_parameters: dict[str, Any]
file_url: Union[str, None, list] = None
file_url: Union[str, None, list[str]] = None
class GenerateNameTraceInfo(BaseTraceInfo):
@@ -114,6 +159,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class PromptGenerationTraceInfo(BaseTraceInfo):
    """Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""

    # Who/where the generation was requested.
    tenant_id: str
    user_id: str
    app_id: str | None = None
    # Kind of generation, e.g. "rule_generate" / "code_generate".
    operation_type: str
    # The user instruction fed to the model.
    instruction: str
    # Token accounting for the generation call.
    prompt_tokens: int
    completion_tokens: int
    total_tokens: int
    # Model identity and timing.
    model_provider: str
    model_name: str
    latency: float
    # Optional cost info.
    total_price: float | None = None
    currency: str | None = None
    error: str | None = None

    # Allow "model_*" field names (pydantic would otherwise reserve them).
    model_config = ConfigDict(protected_namespaces=())
class WorkflowNodeTraceInfo(BaseTraceInfo):
    """Trace information for a single workflow node execution."""

    # Identity of the node execution within its workflow run.
    workflow_id: str
    workflow_run_id: str
    tenant_id: str
    node_execution_id: str
    node_id: str
    node_type: str
    title: str
    # Outcome of the execution.
    status: str
    error: str | None = None
    elapsed_time: float
    index: int
    predecessor_node_id: str | None = None
    # Token/cost accounting (zero/None when the node is not an LLM node).
    total_tokens: int = 0
    total_price: float = 0.0
    currency: str | None = None
    model_provider: str | None = None
    model_name: str | None = None
    prompt_tokens: int | None = None
    completion_tokens: int | None = None
    # Tool / control-flow context, when applicable.
    tool_name: str | None = None
    iteration_id: str | None = None
    iteration_index: int | None = None
    loop_id: str | None = None
    loop_index: int | None = None
    parallel_id: str | None = None
    # Raw node payloads (may be large; optional).
    node_inputs: Mapping[str, Any] | None = None
    node_outputs: Mapping[str, Any] | None = None
    process_data: Mapping[str, Any] | None = None
    # "end_user:<id>" / "account:<id>" / "user:<id>" / "anonymous".
    invoked_by: str | None = None

    # Allow "model_*" field names (pydantic would otherwise reserve them).
    model_config = ConfigDict(protected_namespaces=())
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
    """Node-execution trace originating from a draft (single-step debug) run.

    Identical payload to WorkflowNodeTraceInfo; the distinct type name lets
    consumers route it separately (see trace_info_info_map).
    """

    pass
class TaskData(BaseModel):
app_id: str
trace_info_type: str
@@ -128,11 +246,31 @@ trace_info_info_map = {
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
}
class OperationType(StrEnum):
    """Operation type for token metric labels.

    Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
    counters so consumers can break down token usage by operation.
    """

    # Full workflow run.
    WORKFLOW = "workflow"
    # Single workflow node execution.
    NODE_EXECUTION = "node_execution"
    # Chat/completion message.
    MESSAGE = "message"
    # Prompt-generation operations.
    RULE_GENERATE = "rule_generate"
    CODE_GENERATE = "code_generate"
    STRUCTURED_OUTPUT = "structured_output"
    INSTRUCTION_MODIFY = "instruction_modify"
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
WORKFLOW_TRACE = "workflow"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
@@ -140,4 +278,6 @@ class TraceTaskName(StrEnum):
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
PROMPT_GENERATION_TRACE = "prompt_generation"
NODE_EXECUTION_TRACE = "node_execution"
DATASOURCE_TRACE = "datasource"

View File

@@ -15,22 +15,32 @@ from sqlalchemy import select
from sqlalchemy.orm import Session, sessionmaker
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
from core.ops.entities.config_entity import (
OPS_FILE_PATH,
TracingProviderEnum,
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.engine import db
from models.account import Tenant
from models.dataset import Dataset
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -40,9 +50,144 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
    """Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
    # Nothing to look up — avoid opening a DB session at all.
    if not app_id and not tenant_id:
        return "", ""

    app_name = ""
    workspace_name = ""
    with Session(db.engine) as session:
        if app_id:
            found = session.scalar(select(App.name).where(App.id == app_id))
            if found:
                app_name = found
        if tenant_id:
            found = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
            if found:
                workspace_name = found
    return app_name, workspace_name
# Maps a tool-provider type string to the ORM model holding its credentials.
# "builtin" and "plugin" providers share the BuiltinToolProvider table.
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
    "builtin": BuiltinToolProvider,
    "plugin": BuiltinToolProvider,
    "api": ApiToolProvider,
    "workflow": WorkflowToolProvider,
    "mcp": MCPToolProvider,
}
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
    """Resolve a tool-provider credential id to its display name.

    Returns "" when the id is missing, the provider type is unknown, or no
    matching row exists.
    """
    if not credential_id:
        return ""
    model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
    if model_cls is None:
        return ""
    with Session(db.engine) as session:
        found = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))  # type: ignore[attr-defined]
    return str(found) if found else ""
def _lookup_llm_credential_info(
    tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
) -> tuple[str | None, str]:
    """
    Lookup LLM credential ID and name for the given provider and model.

    Resolution order: model-level credential (ProviderModel row for this
    tenant/provider/model/model_type) first, then the provider-level
    credential on the tenant's CUSTOM Provider record.

    Returns (credential_id, credential_name).
    Handles async timing issues gracefully - if credential is deleted between lookups,
    returns the ID but empty name rather than failing.
    """
    if not tenant_id or not provider:
        return None, ""

    try:
        with Session(db.engine) as session:
            # Try to find provider-level or model-level configuration
            provider_record = session.scalar(
                select(Provider).where(
                    Provider.tenant_id == tenant_id,
                    Provider.provider_name == provider,
                    Provider.provider_type == ProviderType.CUSTOM,
                )
            )
            if not provider_record:
                return None, ""

            # Check if there's a model-specific config
            credential_id = None
            credential_name = ""
            # Tracks which credential table to consult for the name below.
            is_model_level = False

            if model:
                # Try model-level first
                model_record = session.scalar(
                    select(ProviderModel).where(
                        ProviderModel.tenant_id == tenant_id,
                        ProviderModel.provider_name == provider,
                        ProviderModel.model_name == model,
                        ProviderModel.model_type == model_type,
                    )
                )
                if model_record and model_record.credential_id:
                    credential_id = model_record.credential_id
                    is_model_level = True

            if not credential_id and provider_record.credential_id:
                # Fall back to provider-level credential
                credential_id = provider_record.credential_id
                is_model_level = False

            # Lookup credential_name if we have credential_id
            if credential_id:
                try:
                    if is_model_level:
                        # Query ProviderModelCredential
                        cred_name = session.scalar(
                            select(ProviderModelCredential.credential_name).where(
                                ProviderModelCredential.id == credential_id
                            )
                        )
                    else:
                        # Query ProviderCredential
                        cred_name = session.scalar(
                            select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
                        )
                    if cred_name:
                        credential_name = str(cred_name)
                except Exception as e:
                    # Credential might have been deleted between lookups (async timing)
                    # Return ID but empty name rather than failing
                    logger.warning(
                        "Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
                        credential_id,
                        provider,
                        model,
                        str(e),
                        exc_info=True,
                    )

            return credential_id, credential_name
    except Exception as e:
        # Database query failed or other unexpected error
        # Return empty rather than propagating error to telemetry emission
        logger.warning(
            "Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
            tenant_id,
            provider,
            model,
            str(e),
            exc_info=True,
        )
        return None, ""
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, key: str) -> dict[str, Any]:
match key:
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
case TracingProviderEnum.LANGFUSE:
from core.ops.entities.config_entity import LangfuseConfig
from core.ops.langfuse_trace.langfuse_trace import LangFuseDataTrace
@@ -149,7 +294,7 @@ class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
}
case _:
raise KeyError(f"Unsupported tracing provider: {key}")
raise KeyError(f"Unsupported tracing provider: {provider}")
provider_config_map = OpsTraceProviderConfigMap()
@@ -314,6 +459,10 @@ class OpsTraceManager:
if app_id is None:
return None
# Handle storage_id format (tenant-{uuid}) - not a real app_id
if isinstance(app_id, str) and app_id.startswith("tenant-"):
return None
app: App | None = db.session.query(App).where(App.id == app_id).first()
if app is None:
@@ -466,8 +615,6 @@ class TraceTask:
@classmethod
def _get_workflow_run_repo(cls):
from repositories.factory import DifyAPIRepositoryFactory
if cls._workflow_run_repo is None:
with cls._repo_lock:
if cls._workflow_run_repo is None:
@@ -478,6 +625,77 @@ class TraceTask:
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
@classmethod
def _calculate_workflow_token_split(
    cls, session: "Session", workflow_run_id: str, tenant_id: str
) -> tuple[int, int]:
    """Sum prompt/completion tokens across all node executions for a workflow run.

    Reads from the ``outputs`` column (where LLM nodes store ``usage.prompt_tokens``
    and ``usage.completion_tokens``) rather than ``execution_metadata``, which only
    carries ``total_tokens``. Projects only the ``outputs`` column to avoid loading
    large JSON blobs unnecessarily.
    """
    import json

    from models.workflow import WorkflowNodeExecutionModel

    stmt = select(WorkflowNodeExecutionModel.outputs).where(
        WorkflowNodeExecutionModel.tenant_id == tenant_id,
        WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
    )
    prompt_sum = 0
    completion_sum = 0
    for raw in session.execute(stmt).scalars().all():
        if not raw:
            continue
        # Outputs may already be deserialized or stored as a JSON string.
        if isinstance(raw, str):
            try:
                outputs = json.loads(raw)
            except (ValueError, TypeError):
                continue
        else:
            outputs = raw
        if not isinstance(outputs, dict):
            continue
        usage = outputs.get("usage")
        if not isinstance(usage, dict):
            continue
        prompt = usage.get("prompt_tokens")
        if isinstance(prompt, (int, float)):
            prompt_sum += int(prompt)
        completion = usage.get("completion_tokens")
        if isinstance(completion, (int, float)):
            completion_sum += int(completion)
    return (prompt_sum, completion_sum)
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
"""Extract user ID from metadata, prioritizing end_user over account.
Returns the actual user ID (end_user or account) who invoked the workflow,
regardless of invoke_from context.
"""
# Priority 1: End user (external users via API/WebApp)
if user_id := metadata.get("from_end_user_id"):
return f"end_user:{user_id}"
# Priority 2: Account user (internal users via console/debugger)
if user_id := metadata.get("from_account_id"):
return f"account:{user_id}"
# Priority 3: User (internal users via console/debugger)
if user_id := metadata.get("user_id"):
return f"user:{user_id}"
return "anonymous"
def __init__(
self,
trace_type: Any,
@@ -491,6 +709,7 @@ class TraceTask:
self.trace_type = trace_type
self.message_id = message_id
self.workflow_run_id = workflow_execution.id_ if workflow_execution else None
self.workflow_total_tokens: int | None = workflow_execution.total_tokens if workflow_execution else None
self.conversation_id = conversation_id
self.user_id = user_id
self.timer = timer
@@ -498,6 +717,8 @@ class TraceTask:
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
if user_id is not None and "user_id" not in self.kwargs:
self.kwargs["user_id"] = user_id
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:
self.trace_id = external_trace_id
@@ -509,9 +730,12 @@ class TraceTask:
preprocess_map = {
TraceTaskName.CONVERSATION_TRACE: lambda: self.conversation_trace(**self.kwargs),
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
workflow_run_id=self.workflow_run_id,
conversation_id=self.conversation_id,
user_id=self.user_id,
total_tokens_override=self.workflow_total_tokens,
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
@@ -527,6 +751,9 @@ class TraceTask:
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
}
return preprocess_map.get(self.trace_type, lambda: None)()
@@ -541,6 +768,7 @@ class TraceTask:
workflow_run_id: str | None,
conversation_id: str | None,
user_id: str | None,
total_tokens_override: int | None = None,
):
if not workflow_run_id:
return {}
@@ -560,7 +788,7 @@ class TraceTask:
workflow_run_version = workflow_run.version
error = workflow_run.error or ""
total_tokens = workflow_run.total_tokens
total_tokens = total_tokens_override if total_tokens_override is not None else workflow_run.total_tokens
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
@@ -581,8 +809,18 @@ class TraceTask:
Message.workflow_run_id == workflow_run_id,
)
message_id = session.scalar(message_data_stmt)
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
session, workflow_run_id=workflow_run_id, tenant_id=tenant_id
)
metadata = {
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata: dict[str, Any] = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
@@ -595,8 +833,14 @@ class TraceTask:
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
@@ -611,6 +855,8 @@ class TraceTask:
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
file_list=file_list,
query=query,
metadata=metadata,
@@ -618,10 +864,11 @@ class TraceTask:
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
invoked_by=self._get_user_id_from_metadata(metadata),
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
def message_trace(self, message_id: str | None, **kwargs):
if not message_id:
return {}
message_data = get_message_data(message_id)
@@ -644,6 +891,19 @@ class TraceTask:
streaming_metrics = self._extract_streaming_metrics(message_data)
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
@@ -655,7 +915,14 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"message_id": message_id,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
message_tokens = message_data.message_tokens
@@ -672,7 +939,9 @@ class TraceTask:
outputs=message_data.answer,
file_list=file_list,
start_time=created_at,
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
end_time=message_data.updated_at
if message_data.updated_at and message_data.updated_at > created_at
else created_at + timedelta(seconds=message_data.provider_response_latency),
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
@@ -697,6 +966,8 @@ class TraceTask:
"preset_response": moderation_result.preset_response,
"query": moderation_result.query,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -738,6 +1009,8 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -777,6 +1050,52 @@ class TraceTask:
if not message_data:
return {}
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
doc_list = [doc.model_dump() for doc in documents] if documents else []
dataset_ids: set[str] = set()
for doc in doc_list:
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
if did:
dataset_ids.add(did)
embedding_models: dict[str, dict[str, str]] = {}
if dataset_ids:
with Session(db.engine) as session:
rows = session.execute(
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
Dataset.id.in_(list(dataset_ids))
)
).all()
for row in rows:
embedding_models[str(row[0])] = {
"embedding_model": row[1] or "",
"embedding_model_provider": row[2] or "",
}
# Extract rerank model info from retrieval_model kwargs
rerank_model_provider = ""
rerank_model_name = ""
if "retrieval_model" in kwargs:
retrieval_model = kwargs["retrieval_model"]
if isinstance(retrieval_model, dict):
reranking_model = retrieval_model.get("reranking_model")
if isinstance(reranking_model, dict):
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
rerank_model_name = reranking_model.get("reranking_model_name", "")
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
@@ -787,13 +1106,23 @@ class TraceTask:
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
"embedding_models": embedding_models,
"rerank_model_provider": rerank_model_provider,
"rerank_model_name": rerank_model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
documents=doc_list,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -836,6 +1165,10 @@ class TraceTask:
"error": error,
"tool_parameters": tool_parameters,
}
if message_data.workflow_run_id:
metadata["workflow_run_id"] = message_data.workflow_run_id
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
@@ -890,6 +1223,8 @@ class TraceTask:
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
@@ -904,6 +1239,182 @@ class TraceTask:
return generate_name_trace_info
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
    """Build trace info for prompt generation operations (rule-generate, code-generate, etc.).

    All inputs arrive via ``kwargs``: tenant_id, user_id, app_id,
    operation_type, instruction, generated_output, prompt/completion/total
    token counts, model_provider, model_name, latency, optional
    ``timer`` ({"start": ..., "end": ...}), total_price, currency, error
    and node_execution_id.

    Returns:
        A populated PromptGenerationTraceInfo.
    """
    tenant_id = kwargs.get("tenant_id", "")
    user_id = kwargs.get("user_id", "")
    app_id = kwargs.get("app_id")
    operation_type = kwargs.get("operation_type", "")
    instruction = kwargs.get("instruction", "")
    generated_output = kwargs.get("generated_output", "")
    prompt_tokens = kwargs.get("prompt_tokens", 0)
    completion_tokens = kwargs.get("completion_tokens", 0)
    total_tokens = kwargs.get("total_tokens", 0)
    model_provider = kwargs.get("model_provider", "")
    model_name = kwargs.get("model_name", "")
    latency = kwargs.get("latency", 0.0)
    timer = kwargs.get("timer")
    start_time = timer.get("start") if timer else None
    end_time = timer.get("end") if timer else None
    total_price = kwargs.get("total_price")
    currency = kwargs.get("currency")
    error = kwargs.get("error")

    from core.telemetry.gateway import is_enterprise_telemetry_enabled

    # Consistency fix: like the other trace builders (workflow_trace,
    # message_trace, node_execution_trace), only pay for the
    # app/workspace-name DB lookup when enterprise telemetry is enabled,
    # and default to "" (not None) so metadata values stay strings.
    app_name, workspace_name = "", ""
    if app_id and is_enterprise_telemetry_enabled():
        app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)

    metadata = {
        "tenant_id": tenant_id,
        "user_id": user_id,
        "app_id": app_id or "",
        "app_name": app_name,
        "workspace_name": workspace_name,
        "operation_type": operation_type,
        "model_provider": model_provider,
        "model_name": model_name,
    }
    if node_execution_id := kwargs.get("node_execution_id"):
        metadata["node_execution_id"] = node_execution_id

    return PromptGenerationTraceInfo(
        trace_id=self.trace_id,
        inputs=instruction,
        outputs=generated_output,
        start_time=start_time,
        end_time=end_time,
        metadata=metadata,
        tenant_id=tenant_id,
        user_id=user_id,
        app_id=app_id,
        operation_type=operation_type,
        instruction=instruction,
        prompt_tokens=prompt_tokens,
        completion_tokens=completion_tokens,
        total_tokens=total_tokens,
        model_provider=model_provider,
        model_name=model_name,
        latency=latency,
        total_price=total_price,
        currency=currency,
        error=error,
    )
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
    """Build a WorkflowNodeTraceInfo from the ``node_execution_data`` kwarg.

    Returns {} when no payload is provided. Enrichment (app/workspace
    names, credential names) is only performed when enterprise telemetry
    is enabled, since it requires additional DB queries.
    """
    node_data: dict = kwargs.get("node_execution_data", {})
    if not node_data:
        return {}

    from core.telemetry.gateway import is_enterprise_telemetry_enabled

    if is_enterprise_telemetry_enabled():
        app_name, workspace_name = _lookup_app_and_workspace_names(
            node_data.get("app_id"), node_data.get("tenant_id")
        )
    else:
        app_name, workspace_name = "", ""

    # Try tool credential lookup first
    credential_id = node_data.get("credential_id")
    if is_enterprise_telemetry_enabled():
        credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))

        # If no credential_id found (e.g., LLM nodes), try LLM credential lookup
        if not credential_id:
            llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
                tenant_id=node_data.get("tenant_id"),
                provider=node_data.get("model_provider"),
                model=node_data.get("model_name"),
                model_type="llm",
            )
            if llm_cred_id:
                credential_id = llm_cred_id
                credential_name = llm_cred_name
    else:
        credential_name = ""

    metadata: dict[str, Any] = {
        "tenant_id": node_data.get("tenant_id"),
        "app_id": node_data.get("app_id"),
        "app_name": app_name,
        "workspace_name": workspace_name,
        "user_id": node_data.get("user_id"),
        "invoke_from": node_data.get("invoke_from"),
        "credential_id": credential_id,
        "credential_name": credential_name,
        "dataset_ids": node_data.get("dataset_ids"),
        "dataset_names": node_data.get("dataset_names"),
        "plugin_name": node_data.get("plugin_name"),
    }

    parent_trace_context = node_data.get("parent_trace_context")
    if parent_trace_context:
        metadata["parent_trace_context"] = parent_trace_context

    # Correlate this node execution back to its chat message, but only for
    # top-level runs (nested workflows carry parent_trace_context instead).
    message_id: str | None = None
    conversation_id = node_data.get("conversation_id")
    workflow_execution_id = node_data.get("workflow_execution_id")
    if conversation_id and workflow_execution_id and not parent_trace_context:
        with Session(db.engine) as session:
            msg_id = session.scalar(
                select(Message.id).where(
                    Message.conversation_id == conversation_id,
                    Message.workflow_run_id == workflow_execution_id,
                )
            )
            if msg_id:
                message_id = str(msg_id)
    # NOTE(review): always recorded, possibly as None — confirm this matches
    # the intended placement (original indentation was lost in transit).
    metadata["message_id"] = message_id
    if conversation_id:
        metadata["conversation_id"] = conversation_id

    return WorkflowNodeTraceInfo(
        trace_id=self.trace_id,
        message_id=message_id,
        start_time=node_data.get("created_at"),
        end_time=node_data.get("finished_at"),
        metadata=metadata,
        workflow_id=node_data.get("workflow_id", ""),
        workflow_run_id=node_data.get("workflow_execution_id", ""),
        tenant_id=node_data.get("tenant_id", ""),
        node_execution_id=node_data.get("node_execution_id", ""),
        node_id=node_data.get("node_id", ""),
        node_type=node_data.get("node_type", ""),
        title=node_data.get("title", ""),
        status=node_data.get("status", ""),
        error=node_data.get("error"),
        elapsed_time=node_data.get("elapsed_time", 0.0),
        index=node_data.get("index", 0),
        predecessor_node_id=node_data.get("predecessor_node_id"),
        total_tokens=node_data.get("total_tokens", 0),
        total_price=node_data.get("total_price", 0.0),
        currency=node_data.get("currency"),
        model_provider=node_data.get("model_provider"),
        model_name=node_data.get("model_name"),
        prompt_tokens=node_data.get("prompt_tokens"),
        completion_tokens=node_data.get("completion_tokens"),
        tool_name=node_data.get("tool_name"),
        iteration_id=node_data.get("iteration_id"),
        iteration_index=node_data.get("iteration_index"),
        loop_id=node_data.get("loop_id"),
        loop_index=node_data.get("loop_index"),
        parallel_id=node_data.get("parallel_id"),
        node_inputs=node_data.get("node_inputs"),
        node_outputs=node_data.get("node_outputs"),
        process_data=node_data.get("process_data"),
        invoked_by=self._get_user_id_from_metadata(metadata),
    )
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
    """Build a draft-node trace by re-wrapping a regular node execution trace.

    Delegates to ``node_execution_trace``; when that yields a
    ``WorkflowNodeTraceInfo`` it is converted into a ``DraftNodeExecutionTrace``,
    otherwise the raw result is passed through unchanged.
    """
    base_trace = self.node_execution_trace(**kwargs)
    if isinstance(base_trace, WorkflowNodeTraceInfo):
        return DraftNodeExecutionTrace(**base_trace.model_dump())
    return base_trace
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
@@ -937,13 +1448,17 @@ class TraceQueueManager:
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
@@ -979,20 +1494,27 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
    """Persist each trace task's payload to storage and enqueue a Celery job.

    For every task, a storage identifier is derived: the task's ``app_id``
    when present, otherwise a ``tenant-{tenant_id}`` fallback taken from the
    task kwargs. Tasks with neither are skipped with a warning. The executed
    trace info is serialized to ``{OPS_FILE_PATH}{storage_id}/{file_id}.json``
    and ``process_trace_tasks`` is dispatched with the file reference.

    NOTE: this block previously contained both pre- and post-change diff
    lines; this is the post-change implementation only.
    """
    with self.flask_app.app_context():
        for task in tasks:
            # Prefer the app id; fall back to a tenant-scoped storage key.
            storage_id = task.app_id
            if storage_id is None:
                tenant_id = task.kwargs.get("tenant_id")
                if tenant_id:
                    storage_id = f"tenant-{tenant_id}"
                else:
                    logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
                    continue
            file_id = uuid4().hex
            trace_info = task.execute()
            task_data = TaskData(
                app_id=storage_id,
                trace_info_type=type(trace_info).__name__,
                trace_info=trace_info.model_dump() if trace_info else None,
            )
            file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
            storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
            file_info = {
                "file_id": file_id,
                "app_id": storage_id,
            }
            process_trace_tasks.delay(file_info)  # type: ignore

View File

@@ -0,0 +1,43 @@
"""Telemetry facade.
Thin public API for emitting telemetry events. All routing logic
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
from core.telemetry.gateway import emit as gateway_emit
from core.telemetry.gateway import get_trace_task_to_case
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
    """Emit a telemetry event.

    Looks up the ``TelemetryCase`` for the event's ``TraceTaskName``; unknown
    names are ignored. Known events are forwarded to
    ``core.telemetry.gateway.emit()`` with a flattened context mapping.
    """
    mapped_case = get_trace_task_to_case().get(event.name)
    if mapped_case is None:
        return

    ctx = event.context
    context: dict[str, object] = {
        "tenant_id": ctx.tenant_id,
        "user_id": ctx.user_id,
        "app_id": ctx.app_id,
    }
    gateway_emit(mapped_case, context, event.payload, trace_manager)
__all__ = [
"TelemetryContext",
"TelemetryEvent",
"TraceTaskName",
"emit",
]

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.ops.entities.trace_entity import TraceTaskName
@dataclass(frozen=True)
class TelemetryContext:
    """Immutable identity context attached to a telemetry event."""

    # Workspace (tenant) the event belongs to, when known.
    tenant_id: str | None = None
    # User who triggered the traced operation, when known.
    user_id: str | None = None
    # Application the event relates to, when known.
    app_id: str | None = None
@dataclass(frozen=True)
class TelemetryEvent:
    """A single immutable telemetry event handed to the facade."""

    # Trace task this event corresponds to; the facade maps it to a TelemetryCase.
    name: TraceTaskName
    # Who/where the event originated from.
    context: TelemetryContext
    # Event-specific keyword data forwarded to the gateway.
    payload: dict[str, Any]

View File

@@ -0,0 +1,239 @@
"""Telemetry gateway — single routing layer for all editions.
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
metric/log Celery queue.
This module lives in ``core/`` so both CE and EE share one routing table
and one ``emit()`` entry point. No separate enterprise gateway module is
needed — enterprise-specific dispatch (Celery task, payload offloading)
is handled here behind lazy imports that no-op in CE.
"""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from core.ops.entities.trace_entity import TraceTaskName
from enterprise.telemetry.contracts import CaseRoute, SignalType
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from enterprise.telemetry.contracts import TelemetryCase
logger = logging.getLogger(__name__)
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
# ---------------------------------------------------------------------------
# Routing table — authoritative mapping for all editions
# ---------------------------------------------------------------------------
_case_to_trace_task: dict[TelemetryCase, TraceTaskName] | None = None
_case_routing: dict[TelemetryCase, CaseRoute] | None = None
def _get_case_to_trace_task() -> dict[TelemetryCase, TraceTaskName]:
    """Return the lazily-built TelemetryCase → TraceTaskName table.

    Built once and cached at module level; the enterprise import is deferred
    so merely importing this module never pulls in enterprise contracts.
    """
    global _case_to_trace_task
    if _case_to_trace_task is not None:
        return _case_to_trace_task

    from enterprise.telemetry.contracts import TelemetryCase

    _case_to_trace_task = {
        TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
        TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
        TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
        TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
        TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
        TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
        TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
        TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
        TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
        TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
    }
    return _case_to_trace_task
def get_trace_task_to_case() -> dict[TraceTaskName, TelemetryCase]:
    """Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
    inverted: dict[TraceTaskName, TelemetryCase] = {}
    for case, task_name in _get_case_to_trace_task().items():
        inverted[task_name] = case
    return inverted
def _get_case_routing() -> dict[TelemetryCase, CaseRoute]:
    """Return the lazily-built TelemetryCase → CaseRoute routing table.

    Cached at module level after first use. ``ce_eligible`` marks cases that
    flow in both CE and EE; everything else is dropped unless enterprise
    telemetry is enabled.
    """
    global _case_routing
    if _case_routing is not None:
        return _case_routing

    from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase

    _case_routing = {
        # TRACE — CE-eligible (flow in both CE and EE)
        TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
        # TRACE — enterprise-only
        TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
        TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
        TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
        # METRIC_LOG — enterprise-only (signal-driven, not trace)
        TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
        TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
        TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
        TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
    }
    return _case_routing
def __getattr__(name: str) -> dict:
"""Lazy module-level access to routing tables."""
if name == "CASE_ROUTING":
return _get_case_routing()
if name == "CASE_TO_TRACE_TASK":
return _get_case_to_trace_task()
raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def is_enterprise_telemetry_enabled() -> bool:
    """Return True when the EE telemetry exporter is importable and reports enabled.

    Any failure — missing enterprise package (CE mode) or an error inside the
    EE check itself — is treated as "disabled".
    """
    try:
        from enterprise.telemetry.exporter import (
            is_enterprise_telemetry_enabled as _ee_enabled,
        )

        return _ee_enabled()
    except Exception:
        return False
def _handle_payload_sizing(
    payload: dict[str, Any],
    tenant_id: str,
    event_id: str,
) -> tuple[dict[str, Any], str | None]:
    """Inline or offload payload based on size.

    Returns ``(payload_for_envelope, storage_key | None)``. Payloads
    exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
    storage and replaced with an empty dict in the envelope.
    """
    try:
        encoded = json.dumps(payload).encode("utf-8")
    except (TypeError, ValueError):
        # Unserializable payload: keep it inline and let downstream decide.
        logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
        return payload, None

    if len(encoded) <= PAYLOAD_SIZE_THRESHOLD_BYTES:
        return payload, None

    storage_key = f"telemetry/{tenant_id}/{event_id}.json"
    try:
        storage.save(storage_key, encoded)
    except Exception:
        # Storage failure must not drop the event — fall back to inlining.
        logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
        return payload, None
    logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, len(encoded))
    return {}, storage_key
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def emit(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None = None,
) -> None:
    """Route a telemetry event to the correct pipeline.

    TRACE events are enqueued into ``TraceQueueManager`` (works in both CE
    and EE). Enterprise-only traces are silently dropped when EE is
    disabled.

    METRIC_LOG events are dispatched to the enterprise Celery queue;
    silently dropped when enterprise telemetry is unavailable.
    """
    route = _get_case_routing().get(case)
    if route is None:
        logger.warning("Unknown telemetry case: %s, dropping event", case)
        return

    # EE-only cases require an active enterprise telemetry setup.
    ee_required = not route.ce_eligible
    if ee_required and not is_enterprise_telemetry_enabled():
        logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
        return

    if route.signal_type != SignalType.TRACE:
        _emit_metric_log(case, context, payload)
    else:
        _emit_trace(case, context, payload, trace_manager)
def _emit_trace(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
    trace_manager: TraceQueueManager | None,
) -> None:
    """Enqueue a TRACE-signal event into the trace pipeline.

    Reuses the caller-provided ``trace_manager`` when given, otherwise
    constructs one from the context's app/user ids.
    """
    # Imported lazily to avoid a circular import with ops_trace_manager.
    from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
    from core.ops.ops_trace_manager import TraceTask

    task_name = _get_case_to_trace_task().get(case)
    if task_name is None:
        logger.warning("No TraceTaskName mapping for case: %s", case)
        return

    app_id = context.get("app_id")
    user_id = context.get("user_id")
    queue_manager = trace_manager or LocalTraceQueueManager(
        app_id=app_id,
        user_id=user_id,
    )
    queue_manager.add_trace_task(TraceTask(task_name, user_id=user_id, **payload))
    logger.debug("Enqueued trace task: case=%s, app_id=%s", case, app_id)
def _emit_metric_log(
    case: TelemetryCase,
    context: dict[str, Any],
    payload: dict[str, Any],
) -> None:
    """Build envelope and dispatch to enterprise Celery queue.

    No-ops when the enterprise telemetry task is not importable (CE mode).
    Oversized payloads are offloaded to object storage and referenced via
    ``metadata["payload_ref"]``.
    """
    try:
        from tasks.enterprise_telemetry_task import process_enterprise_telemetry
    except ImportError:
        logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
        return

    from enterprise.telemetry.contracts import TelemetryEnvelope

    tenant_id = context.get("tenant_id") or ""
    event_id = str(uuid.uuid4())
    inline_payload, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)

    envelope = TelemetryEnvelope(
        case=case,
        tenant_id=tenant_id,
        event_id=event_id,
        payload=inline_payload,
        metadata={"payload_ref": payload_ref} if payload_ref else None,
    )
    process_enterprise_telemetry.delay(envelope.model_dump_json())
    logger.debug(
        "Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
        case,
        tenant_id,
        event_id,
    )

View File

View File

@@ -0,0 +1,525 @@
# Dify Enterprise Telemetry Data Dictionary
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
## Resource Attributes
Attached to every signal (Span, Metric, Log).
| Attribute | Type | Example |
|-----------|------|---------|
| `service.name` | string | `dify` |
| `host.name` | string | `dify-api-7f8b` |
## Traces (Spans)
### `dify.workflow.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Unique ID for this run |
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
| `dify.workflow.error` | string | Error message if failed |
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
| `dify.invoke_from` | string | `api`, `webapp`, `debug` |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.message.id` | string | Message ID (optional) |
| `dify.invoked_by` | string | User ID who triggered the run |
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
| `dify.parent.app.id` | string | Parent app ID (optional) |
### `dify.node.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Workflow Run ID |
| `dify.message.id` | string | Message ID (optional) |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.node.execution_id` | string | Unique node execution ID |
| `dify.node.id` | string | Node ID in workflow graph |
| `dify.node.type` | string | Node type (see appendix) |
| `dify.node.title` | string | Display title |
| `dify.node.status` | string | `succeeded`, `failed` |
| `dify.node.error` | string | Error message if failed |
| `dify.node.elapsed_time` | float | Execution time (seconds) |
| `dify.node.index` | int | Execution order index |
| `dify.node.predecessor_node_id` | string | Triggering node ID |
| `dify.node.iteration_id` | string | Iteration ID (optional) |
| `dify.node.loop_id` | string | Loop ID (optional) |
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
| `dify.node.invoked_by` | string | User ID who triggered execution |
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
### `dify.node.execution.draft`
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
## Counters
All counters are cumulative and are never sampled — every event is recorded, regardless of the trace sampling rate.
### Token Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.tokens.total` | `{token}` | Total tokens consumed |
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
**Labels:**
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
#### Token Hierarchy & Query Patterns
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
```
App-level total
├── workflow ← sum of all node_execution tokens (DO NOT add both)
│ └── node_execution ← per-node breakdown
├── message ← independent (non-workflow chat apps only)
├── rule_generate ← independent helper LLM call
├── code_generate ← independent helper LLM call
├── structured_output ← independent helper LLM call
└── instruction_modify← independent helper LLM call
```
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
**Common queries** (PromQL):
```promql
# ── Totals ──────────────────────────────────────────────────
# App-level total (exclude node_execution to avoid double-counting)
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
# Single app total
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
# Per-tenant totals
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
# ── Drill-down ──────────────────────────────────────────────
# Workflow-level tokens for an app
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
# Node-level breakdown within an app
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
# Model breakdown for an app
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
# Input vs output per model
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
# ── Rates ───────────────────────────────────────────────────
# Token consumption rate (per hour)
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
# Per-app consumption rate
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
```
**Finding `app_id` from app name** (trace query — Tempo / Jaeger):
```
{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id)
```
### Request Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.requests.total` | `{request}` | Total operations count |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `moderation` | `tenant_id`, `app_id` |
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dataset_retrieval` | `tenant_id`, `app_id` |
| `generate_name` | `tenant_id`, `app_id` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
### Error Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.errors.total` | `{error}` | Total failed operations |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
### Other Counters
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
## Histograms
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
## Structured Logs
### Span Companion Logs
Logs that accompany spans. Signal type: `span_detail`
#### `dify.workflow.run` Companion Log
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.workflow.version` | string | Yes | Workflow definition version |
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
| `dify.workflow.query` | string | No | User query text (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.workflow.run"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.invoke_from` | string | No | Invocation source |
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
### Standalone Logs
Logs without structural spans. Signal type: `metric_only`
#### `dify.message.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.message.run"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID (32-char hex) |
| `span_id` | string | OTEL span ID (16-char hex) |
| `tenant_id` | string | Tenant identifier |
| `user_id` | string | User identifier (optional) |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.message.status` | string | `succeeded`, `failed` |
| `dify.message.error` | string | Error message (if failed) |
| `dify.message.duration` | float | Duration (seconds) |
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
#### `dify.tool.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.tool.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.tool.name` | string | Tool name |
| `dify.tool.duration` | float | Duration (seconds) |
| `dify.tool.status` | string | `succeeded`, `failed` |
| `dify.tool.error` | string | Error message (if failed) |
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
#### `dify.moderation.check`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.moderation.check"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.moderation.type` | string | `input`, `output` |
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
| `dify.moderation.flagged` | boolean | Whether flagged |
| `dify.moderation.categories` | JSON array | Flagged categories |
| `dify.moderation.query` | string | Content (content-gated) |
#### `dify.suggested_question.generation`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.suggested_question.count` | int | Number of questions |
| `dify.suggested_question.duration` | float | Duration (seconds) |
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
| `dify.suggested_question.error` | string | Error message (if failed) |
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
#### `dify.dataset.retrieval`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.dataset.id` | string | Dataset identifier |
| `dify.dataset.name` | string | Dataset name |
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
| `dify.retrieval.rerank_model` | string | Rerank model name |
| `dify.retrieval.query` | string | Search query (content-gated) |
| `dify.retrieval.document_count` | int | Documents retrieved |
| `dify.retrieval.duration` | float | Duration (seconds) |
| `dify.retrieval.status` | string | `succeeded`, `failed` |
| `dify.retrieval.error` | string | Error message (if failed) |
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
#### `dify.generate_name.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.generate_name.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.conversation.id` | string | Conversation identifier |
| `dify.generate_name.duration` | float | Duration (seconds) |
| `dify.generate_name.status` | string | `succeeded`, `failed` |
| `dify.generate_name.error` | string | Error message (if failed) |
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
#### `dify.prompt_generation.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.prompt_generation.duration` | float | Duration (seconds) |
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
| `dify.prompt_generation.error` | string | Error message (if failed) |
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
#### `dify.app.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
#### `dify.app.updated`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.updated"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
#### `dify.app.deleted`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.deleted"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
#### `dify.feedback.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.feedback.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
| `dify.feedback.content` | string | Feedback text (content-gated) |
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
#### `dify.telemetry.rehydration_failed`
Diagnostic event for telemetry system health monitoring.
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.telemetry.error` | string | Error message |
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
| `dify.telemetry.correlation_id` | string | Correlation ID |
## Content-Gated Attributes
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
| Attribute | Signal |
|-----------|--------|
| `dify.workflow.inputs` | `dify.workflow.run` |
| `dify.workflow.outputs` | `dify.workflow.run` |
| `dify.workflow.query` | `dify.workflow.run` |
| `dify.node.inputs` | `dify.node.execution` |
| `dify.node.outputs` | `dify.node.execution` |
| `dify.node.process_data` | `dify.node.execution` |
| `dify.message.inputs` | `dify.message.run` |
| `dify.message.outputs` | `dify.message.run` |
| `dify.tool.inputs` | `dify.tool.execution` |
| `dify.tool.outputs` | `dify.tool.execution` |
| `dify.tool.parameters` | `dify.tool.execution` |
| `dify.tool.config` | `dify.tool.execution` |
| `dify.moderation.query` | `dify.moderation.check` |
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
| `dify.retrieval.query` | `dify.dataset.retrieval` |
| `dify.dataset.documents` | `dify.dataset.retrieval` |
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
| `dify.feedback.content` | `dify.feedback.created` |
## Appendix
### Operation Types
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
### Node Types
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
### Workflow Statuses
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
### Payload Types
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
### Null Value Behavior
**Spans:** Attributes with `null` values are omitted.
**Logs:** Attributes with `null` values appear as `null` in JSON.
**Content-Gated:** Replaced with reference strings, not set to `null`.

View File

@@ -0,0 +1,121 @@
# Dify Enterprise Telemetry
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
## Overview
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
### Signal Architecture
```mermaid
graph TD
A[Workflow Run] -->|Span| B(dify.workflow.run)
A -->|Log| C(dify.workflow.run detail)
B ---|trace_id| C
D[Node Execution] -->|Span| E(dify.node.execution)
D -->|Log| F(dify.node.execution detail)
E ---|span_id| F
G[Message/Tool/etc] -->|Log| H(dify.* event)
G -->|Metric| I(dify.* counter/histogram)
```
## Configuration
The Enterprise OTEL exporter is configured via environment variables.
| Variable | Description | Default |
|----------|-------------|---------|
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `false` |
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
## Correlation Model
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
### ID Generation Rules
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
- `span_id`: Derived from the source ID using the lower 64 bits of `UUID(source_id)`
### Scenario A: Simple Workflow
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
```
trace_id = UUID(workflow_run_id)
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
```
### Scenario B: Nested Sub-Workflow
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id.
```
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
│ ├── dify.node.execution - "Start Node"
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
│ │ ├── dify.node.execution - "Inner Start"
│ │ └── dify.node.execution - "Inner End"
│ └── dify.node.execution - "End Node"
```
**Key attributes for nested workflows:**
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.app.id` = outer `app_id`
### Scenario C: Draft Node Execution
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
```
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
└── dify.node.execution.draft (span_id = hash(node_execution_id))
```
**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace.
## Content Gating
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
**Reference String Format:**
```
ref:{id_type}={uuid}
```
**Examples:**
```
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
ref:message_id=770e8400-e29b-41d4-a716-446655440002
```
To retrieve actual content when gating is enabled, query the Dify database using the provided UUID.
## Reference
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).

View File

View File

@@ -0,0 +1,73 @@
"""Telemetry gateway contracts and data structures.
This module defines the envelope format for telemetry events and the routing
configuration that determines how each event type is processed.
"""
from __future__ import annotations
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict
class TelemetryCase(StrEnum):
    """Enumeration of all known telemetry event cases.

    Member values are the stable string identifiers carried in
    ``TelemetryEnvelope.case``; renaming a value changes the wire format.
    """

    # Workflow / node lifecycle events.
    WORKFLOW_RUN = "workflow_run"
    NODE_EXECUTION = "node_execution"
    DRAFT_NODE_EXECUTION = "draft_node_execution"
    # Conversation-level and retrieval events.
    MESSAGE_RUN = "message_run"
    TOOL_EXECUTION = "tool_execution"
    MODERATION_CHECK = "moderation_check"
    SUGGESTED_QUESTION = "suggested_question"
    DATASET_RETRIEVAL = "dataset_retrieval"
    GENERATE_NAME = "generate_name"
    PROMPT_GENERATION = "prompt_generation"
    # App lifecycle and feedback events.
    APP_CREATED = "app_created"
    APP_UPDATED = "app_updated"
    APP_DELETED = "app_deleted"
    FEEDBACK_CREATED = "feedback_created"
class SignalType(StrEnum):
    """Signal routing type for telemetry cases."""

    # Routed as an OTEL trace/span (workflow and node executions).
    TRACE = "trace"
    # Routed as metrics plus a structured log; no span is emitted.
    METRIC_LOG = "metric_log"
class CaseRoute(BaseModel):
    """Routing configuration for a telemetry case.

    Attributes:
        signal_type: The type of signal (trace or metric_log).
        ce_eligible: Whether this case is eligible for community edition tracing.
    """

    signal_type: SignalType
    ce_eligible: bool
class TelemetryEnvelope(BaseModel):
    """Envelope for telemetry events.

    Attributes:
        case: The telemetry case type.
        tenant_id: The tenant identifier.
        event_id: Unique event identifier for deduplication.
        payload: The main event payload (inline for small payloads,
            empty when offloaded to storage via ``payload_ref``).
        metadata: Optional metadata dictionary. When the gateway
            offloads a large payload to object storage, this contains
            ``{"payload_ref": "<storage_key>"}``.
    """

    # Reject unknown fields; keep enum members (not raw strings) on the model.
    model_config = ConfigDict(extra="forbid", use_enum_values=False)

    case: TelemetryCase
    tenant_id: str
    event_id: str
    payload: dict[str, Any]
    metadata: dict[str, Any] | None = None

View File

@@ -0,0 +1,89 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
def enqueue_draft_node_execution_trace(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
    user_id: str,
) -> None:
    """Emit a draft-node-execution telemetry event for *execution*.

    Builds the flattened node-execution payload and hands it to the
    telemetry emitter; fire-and-forget, returns nothing.
    """
    context = TelemetryContext(
        tenant_id=execution.tenant_id,
        user_id=user_id,
        app_id=execution.app_id,
    )
    payload = {
        "node_execution_data": _build_node_execution_data(
            execution=execution,
            outputs=outputs,
            workflow_execution_id=workflow_execution_id,
        )
    }
    telemetry_emit(
        TelemetryEvent(
            name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
            context=context,
            payload=payload,
        )
    )
def _build_node_execution_data(
    *,
    execution: WorkflowNodeExecutionModel,
    outputs: Mapping[str, Any] | None,
    workflow_execution_id: str | None,
) -> dict[str, Any]:
    """Flatten *execution* into the telemetry payload dict.

    *outputs*, when given, takes precedence over the persisted outputs on
    the execution row; the execution id falls back through
    workflow_execution_id -> workflow_run_id -> execution.id.
    """
    metadata = execution.execution_metadata_dict
    node_outputs = execution.outputs_dict if outputs is None else outputs
    execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
    process_data = execution.process_data_dict or {}

    # Token breakdown lives under outputs["usage"] (populated by the LLM node).
    usage: Mapping[str, Any] = {}
    if isinstance(node_outputs, Mapping):
        candidate = node_outputs.get("usage")
        if isinstance(candidate, Mapping):
            usage = candidate

    # TOOL_INFO is a dict for tool nodes; anything else yields no tool name.
    tool_info = metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
    tool_name = tool_info.get("tool_name") if isinstance(tool_info, dict) else None

    return {
        "workflow_id": execution.workflow_id,
        "workflow_execution_id": execution_id,
        "tenant_id": execution.tenant_id,
        "app_id": execution.app_id,
        "node_execution_id": execution.id,
        "node_id": execution.node_id,
        "node_type": execution.node_type,
        "title": execution.title,
        "status": execution.status,
        "error": execution.error,
        "elapsed_time": execution.elapsed_time,
        "index": execution.index,
        "predecessor_node_id": execution.predecessor_node_id,
        "created_at": execution.created_at,
        "finished_at": execution.finished_at,
        "total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
        "total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
        "currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
        "model_provider": process_data.get("model_provider"),
        "model_name": process_data.get("model_name"),
        "prompt_tokens": usage.get("prompt_tokens"),
        "completion_tokens": usage.get("completion_tokens"),
        "tool_name": tool_name,
        "iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
        "iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
        "loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
        "loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
        "parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
        "node_inputs": execution.inputs_dict,
        "node_outputs": node_outputs,
        "process_data": execution.process_data_dict,
    }

View File

@@ -0,0 +1,966 @@
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
Only requires a matching ``trace(trace_info)`` method signature.
Signal strategy:
- **Traces (spans)**: workflow run, node execution, draft node execution only.
- **Metrics + structured logs**: all other event types.
Token metric labels (unified structure):
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
same label set for consistent filtering and aggregation:
- tenant_id: Tenant identifier
- app_id: Application identifier
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
- model_provider: LLM provider name (empty string if not applicable)
- model_name: LLM model name (empty string if not applicable)
- node_type: Workflow node type (empty string if not node_execution)
This unified structure allows filtering by operation_type to separate:
- Workflow-level aggregates (operation_type=workflow)
- Individual node executions (operation_type=node_execution)
- Direct message calls (operation_type=message)
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
Without this, tokens are double-counted when querying totals (workflow totals include
node totals, since workflow.total_tokens is the sum of all node tokens).
"""
from __future__ import annotations
import json
import logging
from typing import Any, cast
from opentelemetry.util.types import AttributeValue
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
OperationType,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from enterprise.telemetry.entities import (
EnterpriseTelemetryCounter,
EnterpriseTelemetryEvent,
EnterpriseTelemetryHistogram,
EnterpriseTelemetrySpan,
TokenMetricLabels,
)
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
logger = logging.getLogger(__name__)
class EnterpriseOtelTrace:
"""Duck-typed enterprise trace handler.
``*_trace`` methods emit spans (workflow/node only) or structured logs
(all other events), plus metrics at 100 % accuracy.
"""
def __init__(self) -> None:
    """Bind the process-wide enterprise exporter; fail fast when absent."""
    # Imported lazily to avoid a module-level import cycle with the extension.
    from extensions.ext_enterprise_telemetry import get_enterprise_exporter

    if (exporter := get_enterprise_exporter()) is None:
        raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
    self._exporter = exporter
def trace(self, trace_info: BaseTraceInfo) -> None:
    """Dispatch *trace_info* to the handler matching its concrete type.

    NOTE(review): ``DraftNodeExecutionTrace`` is tested before
    ``WorkflowNodeTraceInfo`` — presumably the draft type subclasses the
    node type, so the order of these isinstance checks matters; confirm
    before reordering.

    Raises:
        AssertionError: if ``trace_info`` is of a type with no handler
            branch (a programming error, not a user-input error).
    """
    if isinstance(trace_info, WorkflowTraceInfo):
        self._workflow_trace(trace_info)
    elif isinstance(trace_info, MessageTraceInfo):
        self._message_trace(trace_info)
    elif isinstance(trace_info, ToolTraceInfo):
        self._tool_trace(trace_info)
    elif isinstance(trace_info, DraftNodeExecutionTrace):
        self._draft_node_execution_trace(trace_info)
    elif isinstance(trace_info, WorkflowNodeTraceInfo):
        self._node_execution_trace(trace_info)
    elif isinstance(trace_info, ModerationTraceInfo):
        self._moderation_trace(trace_info)
    elif isinstance(trace_info, SuggestedQuestionTraceInfo):
        self._suggested_question_trace(trace_info)
    elif isinstance(trace_info, DatasetRetrievalTraceInfo):
        self._dataset_retrieval_trace(trace_info)
    elif isinstance(trace_info, GenerateNameTraceInfo):
        self._generate_name_trace(trace_info)
    elif isinstance(trace_info, PromptGenerationTraceInfo):
        self._prompt_generation_trace(trace_info)
    else:
        # Fixed typo in the message ("statment" -> "statement").
        raise AssertionError("this statement should be unreachable")
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
metadata = self._metadata(trace_info)
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
return {
"dify.trace_id": trace_info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"dify.message.id": trace_info.message_id,
}
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
    """Return the metadata mapping carried by *trace_info*."""
    return trace_info.metadata
def _context_ids(
self,
trace_info: BaseTraceInfo,
metadata: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
return tenant_id, app_id, user_id
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
return dict(values)
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
if isinstance(value, str):
return value
if isinstance(value, dict):
return cast(dict[str, Any], value)
if isinstance(value, list):
items: list[object] = []
for item in cast(list[object], value):
items.append(item)
return items
return None
def _content_or_ref(self, value: Any, ref: str) -> Any:
if self._exporter.include_content:
return self._maybe_json(value)
return ref
def _maybe_json(self, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)
# ------------------------------------------------------------------
# SPAN-emitting handlers (workflow, node execution, draft node)
# ------------------------------------------------------------------
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
    """Emit the full signal set for one workflow run.

    Produces a slim span (identity/status/timing), a rich companion log
    (span attrs plus names and content-gated inputs/outputs), and the
    token/request/duration/error metrics.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.workflow.status": info.workflow_run_status,
        "dify.workflow.error": info.error,
        "dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
        "dify.invoke_from": metadata.get("triggered_from"),
        "dify.conversation.id": info.conversation_id,
        "dify.message.id": info.message_id,
        "dify.invoked_by": info.invoked_by,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.user.id": user_id,
    }
    # Nested-workflow linkage: resolved_parent_context yields the trace
    # correlation override and the parent span-id source; metadata may also
    # carry the caller's parent trace context for dify.parent.* attributes.
    trace_correlation_override, parent_span_id_source = info.resolved_parent_context
    parent_ctx = metadata.get("parent_trace_context")
    if isinstance(parent_ctx, dict):
        parent_ctx_dict = cast(dict[str, Any], parent_ctx)
        span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
        span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
        span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
        span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
    self._exporter.export_span(
        EnterpriseTelemetrySpan.WORKFLOW_RUN,
        span_attrs,
        correlation_id=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override,
        parent_span_id_source=parent_span_id_source,
    )
    # -- Companion log: ALL attrs (span + detail) for full picture --
    log_attrs: dict[str, Any] = {**span_attrs}
    log_attrs.update(
        {
            "dify.app.name": metadata.get("app_name"),
            "dify.workspace.name": metadata.get("workspace_name"),
            "gen_ai.user.id": user_id,
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.workflow.version": info.workflow_run_version,
        }
    )
    # Content-gated payloads: replaced by the ref string when the exporter's
    # include_content flag is off (see _content_or_ref).
    ref = f"ref:workflow_run_id={info.workflow_run_id}"
    log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
    log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
    log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
    emit_telemetry_log(
        event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.workflow_run_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    # -- Metrics --
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    # operation_type=workflow keeps the workflow aggregate separable from
    # node-level tokens so totals are not double counted (module docstring).
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.WORKFLOW,
        model_provider="",
        model_name="",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    if info.prompt_tokens is not None and info.prompt_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
    if info.completion_tokens is not None and info.completion_tokens > 0:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
        )
    invoke_from = metadata.get("triggered_from", "")
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="workflow",
            status=info.workflow_run_status,
            invoke_from=invoke_from,
        ),
    )
    # Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
    # to 0 in the DB and can be stale if the Celery write races with the trace task.
    # start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
    if info.start_time and info.end_time:
        workflow_duration = (info.end_time - info.start_time).total_seconds()
    elif info.workflow_run_elapsed_time:
        workflow_duration = float(info.workflow_run_elapsed_time)
    else:
        workflow_duration = 0.0
    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
        workflow_duration,
        self._labels(
            **labels,
            status=info.workflow_run_status,
        ),
    )
    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="workflow",
            ),
        )
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
    """Span + log + metrics for a node executed inside a workflow run."""
    self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
    """Span + log + metrics for an isolated (draft/debugger) node run.

    Draft runs correlate on the node execution id (their own trace) rather
    than on a workflow run id; the workflow_run_id is only used as the
    trace-correlation override.
    """
    self._emit_node_execution_trace(
        info,
        EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
        "draft_node",
        correlation_id_override=info.node_execution_id,
        trace_correlation_override_param=info.workflow_run_id,
    )
def _emit_node_execution_trace(
    self,
    info: WorkflowNodeTraceInfo,
    span_name: EnterpriseTelemetrySpan,
    request_type: str,
    correlation_id_override: str | None = None,
    trace_correlation_override_param: str | None = None,
) -> None:
    """Shared span/log/metrics emission for regular and draft node runs.

    Args:
        info: Node trace payload.
        span_name: Span kind to emit (node vs draft node).
        request_type: Value of the ``type`` label on request/error counters.
        correlation_id_override: Correlation id used instead of the workflow
            run id (draft runs pass the node execution id).
        trace_correlation_override_param: Explicit trace override; falls
            back to the value from ``info.resolved_parent_context``.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    # -- Span attrs: identity + structure + status + timing + gen_ai scalars --
    span_attrs: dict[str, Any] = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "dify.app_id": app_id,
        "dify.workflow.id": info.workflow_id,
        "dify.workflow.run_id": info.workflow_run_id,
        "dify.message.id": info.message_id,
        "dify.conversation.id": metadata.get("conversation_id"),
        "dify.node.execution_id": info.node_execution_id,
        "dify.node.id": info.node_id,
        "dify.node.type": info.node_type,
        "dify.node.title": info.title,
        "dify.node.status": info.status,
        "dify.node.error": info.error,
        "dify.node.elapsed_time": info.elapsed_time,
        "dify.node.index": info.index,
        "dify.node.predecessor_node_id": info.predecessor_node_id,
        "dify.node.iteration_id": info.iteration_id,
        "dify.node.loop_id": info.loop_id,
        "dify.node.parallel_id": info.parallel_id,
        "dify.node.invoked_by": info.invoked_by,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "gen_ai.request.model": info.model_name,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.user.id": user_id,
    }
    # Explicit caller-supplied override wins; otherwise use the override
    # resolved from the trace info's parent context.
    resolved_override, _ = info.resolved_parent_context
    trace_correlation_override = trace_correlation_override_param or resolved_override
    effective_correlation_id = correlation_id_override or info.workflow_run_id
    self._exporter.export_span(
        span_name,
        span_attrs,
        correlation_id=effective_correlation_id,
        span_id_source=info.node_execution_id,
        start_time=info.start_time,
        end_time=info.end_time,
        trace_correlation_override=trace_correlation_override,
    )
    # -- Companion log: ALL attrs (span + detail) --
    log_attrs: dict[str, Any] = {**span_attrs}
    log_attrs.update(
        {
            "dify.app.name": metadata.get("app_name"),
            "dify.workspace.name": metadata.get("workspace_name"),
            "dify.invoke_from": metadata.get("invoke_from"),
            "gen_ai.user.id": user_id,
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.node.total_price": info.total_price,
            "dify.node.currency": info.currency,
            "gen_ai.provider.name": info.model_provider,
            "gen_ai.request.model": info.model_name,
            "gen_ai.tool.name": info.tool_name,
            "dify.node.iteration_index": info.iteration_index,
            "dify.node.loop_index": info.loop_index,
            "dify.plugin.name": metadata.get("plugin_name"),
            "dify.credential.name": metadata.get("credential_name"),
            "dify.credential.id": metadata.get("credential_id"),
            "dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
            "dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
        }
    )
    # Content-gated payloads: replaced by the ref string when include_content is off.
    ref = f"ref:node_execution_id={info.node_execution_id}"
    log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
    log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
    log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
    emit_telemetry_log(
        event_name=span_name.value,
        attributes=log_attrs,
        signal="span_detail",
        trace_id_source=info.workflow_run_id,
        span_id_source=info.node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    # -- Metrics --
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        node_type=info.node_type,
        model_provider=info.model_provider or "",
    )
    if info.total_tokens:
        # operation_type=node_execution keeps node tokens separable from the
        # workflow-level aggregate (module docstring on double counting).
        token_labels = TokenMetricLabels(
            tenant_id=tenant_id or "",
            app_id=app_id or "",
            operation_type=OperationType.NODE_EXECUTION,
            model_provider=info.model_provider or "",
            model_name=info.model_name or "",
            node_type=info.node_type,
        ).to_dict()
        self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
        if info.prompt_tokens is not None and info.prompt_tokens > 0:
            self._exporter.increment_counter(
                EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
            )
        if info.completion_tokens is not None and info.completion_tokens > 0:
            self._exporter.increment_counter(
                EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
            )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type=request_type,
            status=info.status,
            model_name=info.model_name or "",
        ),
    )
    duration_labels = dict(labels)
    duration_labels["model_name"] = info.model_name or ""
    # plugin_name label only added for node types where it is meaningful.
    plugin_name = metadata.get("plugin_name")
    if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
        duration_labels["plugin_name"] = plugin_name
    self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type=request_type,
                model_name=info.model_name or "",
            ),
        )
# ------------------------------------------------------------------
# METRIC-ONLY handlers (structured log + counters/histograms)
# ------------------------------------------------------------------
def _message_trace(self, info: MessageTraceInfo) -> None:
    """Structured log + metrics for a direct message run (no span emitted)."""
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs.update(
        {
            "dify.invoke_from": metadata.get("from_source"),
            "dify.conversation.id": metadata.get("conversation_id"),
            "dify.conversation.mode": info.conversation_mode,
            "gen_ai.provider.name": metadata.get("ls_provider"),
            "gen_ai.request.model": metadata.get("ls_model_name"),
            "gen_ai.usage.input_tokens": info.message_tokens,
            "gen_ai.usage.output_tokens": info.answer_tokens,
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.message.status": metadata.get("status"),
            "dify.message.error": info.error,
            "dify.message.from_source": metadata.get("from_source"),
            "dify.message.from_end_user_id": metadata.get("from_end_user_id"),
            "dify.message.from_account_id": metadata.get("from_account_id"),
            "dify.streaming": info.is_streaming_request,
            "dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
            "dify.message.streaming_duration": info.llm_streaming_time_to_generate,
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    # Wall-clock duration only when both timestamps are present.
    if info.start_time and info.end_time:
        attrs["dify.message.duration"] = (info.end_time - info.start_time).total_seconds()
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # Content-gated payloads: replaced by the ref string when include_content is off.
    ref = f"ref:message_id={info.message_id}"
    inputs = self._safe_payload_value(info.inputs)
    outputs = self._safe_payload_value(info.outputs)
    attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
    attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
    # Correlate to the workflow trace when one exists, else to the message id.
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
        attributes=attrs,
        trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        model_provider=metadata.get("ls_provider") or "",
        model_name=metadata.get("ls_model_name") or "",
    )
    # operation_type=message distinguishes direct calls from workflow tokens.
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=OperationType.MESSAGE,
        model_provider=metadata.get("ls_provider") or "",
        model_name=metadata.get("ls_model_name") or "",
        node_type="",
    ).to_dict()
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    if info.message_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
    if info.answer_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)
    invoke_from = metadata.get("from_source", "")
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="message",
            status=metadata.get("status", ""),
            invoke_from=invoke_from,
        ),
    )
    if info.start_time and info.end_time:
        duration = (info.end_time - info.start_time).total_seconds()
        self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
    if info.gen_ai_server_time_to_first_token is not None:
        self._exporter.record_histogram(
            EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
        )
    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="message",
            ),
        )
def _tool_trace(self, info: ToolTraceInfo) -> None:
    """Structured log + metrics for a tool invocation (no span emitted)."""
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs.update(
        {
            "dify.tool.name": info.tool_name,
            "dify.tool.duration": float(info.time_cost),
            # Status derived solely from the presence of an error.
            "dify.tool.status": "failed" if info.error else "succeeded",
            "dify.tool.error": info.error,
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # Content-gated payloads: replaced by the ref string when include_content is off.
    ref = f"ref:message_id={info.message_id}"
    attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
    attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
    attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
    attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        tool_name=info.tool_name,
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="tool",
        ),
    )
    self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="tool",
            ),
        )
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
    """Structured log + request counter for a moderation check (no span)."""
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs.update(
        {
            "dify.moderation.flagged": info.flagged,
            "dify.moderation.action": info.action,
            "dify.moderation.preset_response": info.preset_response,
            # Defaults to "input" when the caller does not say which side was checked.
            "dify.moderation.type": metadata.get("moderation_type", "input"),
            "dify.moderation.categories": self._maybe_json(metadata.get("moderation_categories", [])),
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # Content-gated query text: replaced by the ref string when include_content is off.
    attrs["dify.moderation.query"] = self._content_or_ref(
        info.query,
        f"ref:message_id={info.message_id}",
    )
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="moderation",
        ),
    )
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
    """Emit a metric-only event and a request counter for suggested-question generation."""
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    # Duration is only computable when both timestamps are present.
    duration: float | None = None
    if info.start_time is not None and info.end_time is not None:
        duration = (info.end_time - info.start_time).total_seconds()
    # An error can be reported either on the info object itself or inside
    # its metadata mapping; info.error takes precedence.
    error = info.error or (info.metadata.get("error") if info.metadata else None)
    status = "failed" if error else (info.status or "succeeded")
    attrs.update(
        {
            "gen_ai.usage.total_tokens": info.total_tokens,
            "dify.suggested_question.status": status,
            "dify.suggested_question.error": error,
            "dify.suggested_question.duration": duration,
            "gen_ai.provider.name": info.model_provider,
            "gen_ai.request.model": info.model_id,
            "dify.suggested_question.count": len(info.suggested_question),
            "dify.workflow.run_id": metadata.get("workflow_run_id"),
        }
    )
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # Questions are exported inline or as a reference marker, per the
    # exporter's content policy (see _content_or_ref).
    attrs["dify.suggested_question.questions"] = self._content_or_ref(
        info.suggested_question,
        f"ref:message_id={info.message_id}",
    )
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    # One request counted per generation, labelled with the model used.
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="suggested_question",
            model_provider=info.model_provider or "",
            model_name=info.model_id or "",
        ),
    )
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
    """Emit a metric-only event and counters for one dataset-retrieval operation.

    Builds OTEL attributes (status, duration, dataset/document summaries,
    embedding and rerank model info), emits the DATASET_RETRIEVAL event,
    then increments one request counter plus one retrieval counter per
    distinct dataset that contributed documents.
    """
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    attrs = self._common_attrs(info)
    attrs["dify.retrieval.error"] = info.error
    attrs["dify.retrieval.status"] = "failed" if info.error else "succeeded"
    if info.start_time and info.end_time:
        attrs["dify.retrieval.duration"] = (info.end_time - info.start_time).total_seconds()
    attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # info.documents is untyped upstream; keep only dict entries.
    documents_any: Any = info.documents
    documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
    docs: list[dict[str, Any]] = [
        cast(dict[str, Any], entry) for entry in documents_list if isinstance(entry, dict)
    ]
    dataset_ids: list[str] = []
    dataset_names: list[str] = []
    structured_docs: list[dict[str, Any]] = []
    for doc in docs:
        meta_raw = doc.get("metadata")
        meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
        did = meta.get("dataset_id")
        dname = meta.get("dataset_name")
        # Deduplicate while preserving first-seen order.
        if did and did not in dataset_ids:
            dataset_ids.append(did)
        if dname and dname not in dataset_names:
            dataset_names.append(dname)
        structured_docs.append(
            {
                "dataset_id": did,
                "document_id": meta.get("document_id"),
                "segment_id": meta.get("segment_id"),
                "score": meta.get("score"),
            }
        )
    attrs["dify.dataset.id"] = self._maybe_json(dataset_ids)
    attrs["dify.dataset.name"] = self._maybe_json(dataset_names)
    attrs["dify.retrieval.document_count"] = len(docs)
    embedding_models_raw: Any = metadata.get("embedding_models")
    embedding_models: dict[str, Any] = (
        cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
    )
    if embedding_models:
        providers: list[str] = []
        models: list[str] = []
        for ds_info in embedding_models.values():
            if isinstance(ds_info, dict):
                ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
                p = ds_info_dict.get("embedding_model_provider", "")
                m = ds_info_dict.get("embedding_model", "")
                if p and p not in providers:
                    providers.append(p)
                if m and m not in models:
                    models.append(m)
        attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
        attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
    # Rerank model applies to the whole retrieval (identical for every
    # dataset), so resolve it once and reuse below for attrs and counters
    # instead of re-reading metadata on each loop iteration.
    rerank_provider = metadata.get("rerank_model_provider", "")
    rerank_model = metadata.get("rerank_model_name", "")
    if rerank_provider or rerank_model:
        attrs["dify.retrieval.rerank_provider"] = rerank_provider
        attrs["dify.retrieval.rerank_model"] = rerank_model
    ref = f"ref:message_id={info.message_id}"
    retrieval_inputs = self._safe_payload_value(info.inputs)
    attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
    attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
        attributes=attrs,
        trace_id_source=metadata.get("workflow_run_id") or (str(info.message_id) if info.message_id else None),
        span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
        tenant_id=tenant_id,
        user_id=user_id,
    )
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="dataset_retrieval",
        ),
    )
    for did in dataset_ids:
        # Embedding model is dataset-specific. Guard against malformed
        # (non-dict) entries — the extraction loop above applies the same
        # isinstance check — instead of assuming .get() is available.
        ds_embedding_raw = embedding_models.get(did)
        ds_embedding: dict[str, Any] = ds_embedding_raw if isinstance(ds_embedding_raw, dict) else {}
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
            1,
            self._labels(
                **labels,
                dataset_id=did,
                embedding_model_provider=ds_embedding.get("embedding_model_provider", ""),
                embedding_model=ds_embedding.get("embedding_model", ""),
                rerank_model_provider=rerank_provider,
                rerank_model=rerank_model,
            ),
        )
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
    """Record one conversation-name generation as an event plus a request counter."""
    meta = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, meta)
    exec_id = meta.get("node_execution_id")
    elapsed: float | None = None
    if info.start_time is not None and info.end_time is not None:
        elapsed = (info.end_time - info.start_time).total_seconds()
    failure: str | None = meta.get("error") if meta else None
    ref = f"ref:conversation_id={info.conversation_id}"
    attrs = self._common_attrs(info)
    attrs["dify.conversation.id"] = info.conversation_id
    if exec_id:
        attrs["dify.node.execution_id"] = exec_id
    attrs["dify.generate_name.duration"] = elapsed
    attrs["dify.generate_name.status"] = "failed" if failure else "succeeded"
    attrs["dify.generate_name.error"] = failure
    # Inputs/outputs are exported inline or as reference markers, per policy.
    attrs["dify.generate_name.inputs"] = self._content_or_ref(self._safe_payload_value(info.inputs), ref)
    attrs["dify.generate_name.outputs"] = self._content_or_ref(self._safe_payload_value(info.outputs), ref)
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=exec_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    base_labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
    )
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(**base_labels, type="generate_name"),
    )
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
    """Emit an event plus token/request/error counters and a duration histogram
    for one prompt-generation call (rule/code generation, etc.)."""
    metadata = self._metadata(info)
    tenant_id, app_id, user_id = self._context_ids(info, metadata)
    # NOTE: unlike the other trace handlers, attrs is built from scratch
    # here rather than via self._common_attrs(info).
    attrs = {
        "dify.trace_id": info.resolved_trace_id,
        "dify.tenant_id": tenant_id,
        "gen_ai.user.id": user_id,
        "dify.app_id": app_id or "",
        "dify.app.name": metadata.get("app_name"),
        "dify.workspace.name": metadata.get("workspace_name"),
        "dify.prompt_generation.operation_type": info.operation_type,
        "gen_ai.provider.name": info.model_provider,
        "gen_ai.request.model": info.model_name,
        "gen_ai.usage.input_tokens": info.prompt_tokens,
        "gen_ai.usage.output_tokens": info.completion_tokens,
        "gen_ai.usage.total_tokens": info.total_tokens,
        "dify.prompt_generation.duration": info.latency,
        "dify.prompt_generation.status": "failed" if info.error else "succeeded",
        "dify.prompt_generation.error": info.error,
    }
    node_execution_id = metadata.get("node_execution_id")
    if node_execution_id:
        attrs["dify.node.execution_id"] = node_execution_id
    # Price attributes only when billing info was recorded for this call.
    if info.total_price is not None:
        attrs["dify.prompt_generation.total_price"] = info.total_price
        attrs["dify.prompt_generation.currency"] = info.currency
    ref = f"ref:trace_id={info.trace_id}"
    outputs = self._safe_payload_value(info.outputs)
    # Instruction/output are exported inline or as reference markers.
    attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
    attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
    emit_metric_only_event(
        event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
        attributes=attrs,
        trace_id_source=info.resolved_trace_id,
        span_id_source=node_execution_id,
        tenant_id=tenant_id,
        user_id=user_id,
    )
    # TokenMetricLabels enforces the unified label set shared by all
    # dify.token.* counters; node_type is empty for non-node operations.
    token_labels = TokenMetricLabels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
        node_type="",
    ).to_dict()
    labels = self._labels(
        tenant_id=tenant_id or "",
        app_id=app_id or "",
        operation_type=info.operation_type,
        model_provider=info.model_provider,
        model_name=info.model_name,
    )
    self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
    # Input/output counters are only emitted when non-zero.
    if info.prompt_tokens > 0:
        self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
    if info.completion_tokens > 0:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
        )
    prompt_status = "failed" if info.error else "succeeded"
    self._exporter.increment_counter(
        EnterpriseTelemetryCounter.REQUESTS,
        1,
        self._labels(
            **labels,
            type="prompt_generation",
            status=prompt_status,
        ),
    )
    self._exporter.record_histogram(
        EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
        info.latency,
        labels,
    )
    if info.error:
        self._exporter.increment_counter(
            EnterpriseTelemetryCounter.ERRORS,
            1,
            self._labels(
                **labels,
                type="prompt_generation",
            ),
        )

View File

@@ -0,0 +1,121 @@
from enum import StrEnum
from typing import cast
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel, ConfigDict
class EnterpriseTelemetrySpan(StrEnum):
    """Span operation names used when exporting enterprise OTEL traces."""

    WORKFLOW_RUN = "dify.workflow.run"
    NODE_EXECUTION = "dify.node.execution"
    DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
class EnterpriseTelemetryEvent(StrEnum):
    """Event names for enterprise telemetry logs."""

    # App lifecycle
    APP_CREATED = "dify.app.created"
    APP_UPDATED = "dify.app.updated"
    APP_DELETED = "dify.app.deleted"
    # User feedback
    FEEDBACK_CREATED = "dify.feedback.created"
    # Execution / generation events
    WORKFLOW_RUN = "dify.workflow.run"
    MESSAGE_RUN = "dify.message.run"
    TOOL_EXECUTION = "dify.tool.execution"
    MODERATION_CHECK = "dify.moderation.check"
    SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
    DATASET_RETRIEVAL = "dify.dataset.retrieval"
    GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
    PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
    # Telemetry-pipeline self-diagnostics
    REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
class EnterpriseTelemetryCounter(StrEnum):
    """Logical counter keys, mapped to concrete OTEL counter instruments
    by the enterprise exporter."""

    TOKENS = "tokens"
    INPUT_TOKENS = "input_tokens"
    OUTPUT_TOKENS = "output_tokens"
    REQUESTS = "requests"
    ERRORS = "errors"
    FEEDBACK = "feedback"
    DATASET_RETRIEVALS = "dataset_retrievals"
    APP_CREATED = "app_created"
    APP_UPDATED = "app_updated"
    APP_DELETED = "app_deleted"
class EnterpriseTelemetryHistogram(StrEnum):
    """Logical histogram keys, mapped to concrete OTEL histogram instruments
    (all recorded in seconds) by the enterprise exporter."""

    WORKFLOW_DURATION = "workflow_duration"
    NODE_DURATION = "node_duration"
    MESSAGE_DURATION = "message_duration"
    MESSAGE_TTFT = "message_ttft"
    TOOL_DURATION = "tool_duration"
    PROMPT_GENERATION_DURATION = "prompt_generation_duration"
class TokenMetricLabels(BaseModel):
    """Canonical label set shared by every ``dify.token.*`` metric.

    Each token counter (``dify.tokens.input``, ``dify.tokens.output``,
    ``dify.tokens.total``) must be emitted with exactly these labels so
    queries can filter and aggregate consistently.

    ``operation_type`` names the source of the usage (workflow |
    node_execution | message | rule_generate | code_generate |
    structured_output | instruction_modify) and is what separates
    workflow-level aggregates from node-level detail — workflow totals
    already include node tokens, so filtering on it avoids double counting.
    ``model_provider``, ``model_name`` and ``node_type`` are empty strings
    whenever they do not apply, keeping label cardinality identical across
    operation types.

    Example::

        labels = TokenMetricLabels(
            tenant_id="tenant-123",
            app_id="app-456",
            operation_type=OperationType.WORKFLOW,
            model_provider="",
            model_name="",
            node_type="",
        )
        exporter.increment_counter(
            EnterpriseTelemetryCounter.INPUT_TOKENS,
            100,
            labels.to_dict(),
        )
    """

    tenant_id: str
    app_id: str
    operation_type: str
    model_provider: str
    model_name: str
    node_type: str

    model_config = ConfigDict(extra="forbid", frozen=True)

    def to_dict(self) -> dict[str, AttributeValue]:
        """Render the labels as a plain attribute mapping for metric emission."""
        field_names = (
            "tenant_id",
            "app_id",
            "operation_type",
            "model_provider",
            "model_name",
            "node_type",
        )
        return cast(dict[str, AttributeValue], {name: getattr(self, name) for name in field_names})
__all__ = [
"EnterpriseTelemetryCounter",
"EnterpriseTelemetryEvent",
"EnterpriseTelemetryHistogram",
"EnterpriseTelemetrySpan",
"TokenMetricLabels",
]

View File

@@ -0,0 +1,72 @@
"""Blinker signal handlers for enterprise telemetry.
Registered at import time via ``@signal.connect`` decorators.
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
which handles routing, EE-gating, and dispatch.
All handlers are best-effort: exceptions are caught and logged so that
telemetry failures never break user-facing operations.
"""
from __future__ import annotations
import logging
from events.app_event import app_was_created, app_was_deleted, app_was_updated
logger = logging.getLogger(__name__)
__all__ = [
"_handle_app_created",
"_handle_app_deleted",
"_handle_app_updated",
]
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
    """Best-effort emission of an APP_CREATED telemetry event; never raises."""
    try:
        # Imported lazily inside the try so a missing/broken enterprise
        # module degrades to a warning instead of breaking app creation.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        payload = {
            "app_id": getattr(sender, "id", None),
            "mode": getattr(sender, "mode", None),
        }
        gateway_emit(case=TelemetryCase.APP_CREATED, context={"tenant_id": tenant}, payload=payload)
    except Exception:
        logger.warning("Failed to emit app_created telemetry", exc_info=True)
@app_was_updated.connect
def _handle_app_updated(sender: object, **kwargs: object) -> None:
    """Best-effort emission of an APP_UPDATED telemetry event; never raises."""
    try:
        # Lazy imports keep telemetry failures from breaking app updates.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        gateway_emit(
            case=TelemetryCase.APP_UPDATED,
            context={"tenant_id": tenant},
            payload={"app_id": getattr(sender, "id", None)},
        )
    except Exception:
        logger.warning("Failed to emit app_updated telemetry", exc_info=True)
@app_was_deleted.connect
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
    """Best-effort emission of an APP_DELETED telemetry event; never raises."""
    try:
        # Lazy imports keep telemetry failures from breaking app deletion.
        from core.telemetry.gateway import emit as gateway_emit
        from enterprise.telemetry.contracts import TelemetryCase

        tenant = str(getattr(sender, "tenant_id", "") or "")
        gateway_emit(
            case=TelemetryCase.APP_DELETED,
            context={"tenant_id": tenant},
            payload={"app_id": getattr(sender, "id", None)},
        )
    except Exception:
        logger.warning("Failed to emit app_deleted telemetry", exc_info=True)

View File

@@ -0,0 +1,283 @@
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
independent from ext_otel.py infrastructure).
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
"""
import logging
import socket
import uuid
from datetime import UTC, datetime
from typing import Any, cast
from opentelemetry import trace
from opentelemetry.baggage import get_all
from opentelemetry.baggage.propagation import W3CBaggagePropagator
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
from configs import dify_config
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_correlation_id,
set_span_id_source,
)
logger = logging.getLogger(__name__)
def is_enterprise_telemetry_enabled() -> bool:
    """True only when both the enterprise edition and its telemetry flag are enabled."""
    enterprise_on = bool(dify_config.ENTERPRISE_ENABLED)
    telemetry_on = bool(dify_config.ENTERPRISE_TELEMETRY_ENABLED)
    return enterprise_on and telemetry_on
def _parse_otlp_headers(raw: str) -> dict[str, str]:
    """Parse a W3C-baggage-formatted header string into a str -> str mapping,
    discarding any non-string baggage values."""
    extracted = W3CBaggagePropagator().extract({"baggage": raw})
    headers: dict[str, str] = {}
    for key, value in get_all(extracted).items():
        if isinstance(value, str):
            headers[key] = value
    return headers
def _datetime_to_ns(dt: datetime) -> int:
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
# Ensure we always interpret naive datetimes as UTC instead of local time.
if dt.tzinfo is None:
dt = dt.replace(tzinfo=UTC)
else:
dt = dt.astimezone(UTC)
return int(dt.timestamp() * 1_000_000_000)
class _ExporterFactory:
    """Builds OTLP trace/metric exporters for either the gRPC or HTTP protocol.

    gRPC exporters take the endpoint as-is; HTTP exporters need the
    signal-specific path (``/v1/traces`` or ``/v1/metrics``) appended.
    """

    def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
        self._protocol = protocol
        self._endpoint = endpoint
        self._headers = headers
        # gRPC wants header tuples; HTTP wants a dict; both want None when empty.
        self._grpc_headers = tuple(headers.items()) if headers else None
        self._http_headers = headers or None
        self._insecure = insecure

    def _http_endpoint(self, signal_path: str) -> str:
        """Append the per-signal path to the base endpoint (empty stays empty)."""
        return f"{self._endpoint}{signal_path}" if self._endpoint else ""

    def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
        """Return a span exporter matching the configured protocol."""
        if self._protocol == "grpc":
            return GRPCSpanExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPSpanExporter(endpoint=self._http_endpoint("/v1/traces") or None, headers=self._http_headers)

    def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
        """Return a metric exporter matching the configured protocol."""
        if self._protocol == "grpc":
            return GRPCMetricExporter(
                endpoint=self._endpoint or None,
                headers=self._grpc_headers,
                insecure=self._insecure,
            )
        return HTTPMetricExporter(endpoint=self._http_endpoint("/v1/metrics") or None, headers=self._http_headers)
class EnterpriseExporter:
    """Shared OTEL exporter for all enterprise telemetry.

    ``export_span`` creates spans with optional real timestamps, deterministic
    span/trace IDs, and cross-workflow parent linking.
    ``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy
    (metrics are not subject to the trace sampler).
    """

    def __init__(self, config: object) -> None:
        """Build dedicated tracer/meter providers from ``ENTERPRISE_*`` attributes.

        ``config`` is read via ``getattr`` with defaults, so any object that
        carries the expected attribute names works.
        """
        endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
        headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
        protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
        service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
        sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
        # Public flag consulted by callers deciding whether to attach raw content.
        self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
        api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")
        # Auto-detect TLS: https:// uses secure, everything else is insecure
        insecure = not endpoint.startswith("https://")
        resource = Resource(
            attributes={
                ResourceAttributes.SERVICE_NAME: service_name,
                ResourceAttributes.HOST_NAME: socket.gethostname(),
            }
        )
        # Parent-based sampling keeps whole traces together at the given ratio;
        # the custom id generator derives IDs from contextvars when bound.
        sampler = ParentBasedTraceIdRatio(sampling_rate)
        id_generator = CorrelationIdGenerator()
        self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
        headers = _parse_otlp_headers(headers_raw)
        if api_key:
            if "authorization" in headers:
                logger.warning(
                    "ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
                    "'authorization'; the API key will take precedence."
                )
            headers["authorization"] = f"Bearer {api_key}"
        factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)
        trace_exporter = factory.create_trace_exporter()
        # BatchSpanProcessor exports asynchronously in the background.
        self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
        self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
        metric_exporter = factory.create_metric_exporter()
        self._meter_provider = MeterProvider(
            resource=resource,
            metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
        )
        meter = self._meter_provider.get_meter("dify.enterprise")
        # Instruments are created once here; emit-time lookups are dict gets.
        self._counters = {
            EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
            EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
            EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
            EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
            EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
            EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
            EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
                "dify.dataset.retrievals.total", unit="{retrieval}"
            ),
            EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
            EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
        }
        self._histograms = {
            EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
            EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
            EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
                "dify.message.time_to_first_token", unit="s"
            ),
            EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
            EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
                "dify.prompt_generation.duration", unit="s"
            ),
        }

    def export_span(
        self,
        name: str,
        attributes: dict[str, Any],
        correlation_id: str | None = None,
        span_id_source: str | None = None,
        start_time: datetime | None = None,
        end_time: datetime | None = None,
        trace_correlation_override: str | None = None,
        parent_span_id_source: str | None = None,
    ) -> None:
        """Export an OTEL span with optional deterministic IDs and real timestamps.

        Args:
            name: Span operation name.
            attributes: Span attributes dict.
            correlation_id: Source for trace_id derivation (groups spans in one trace).
            span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
            start_time: Real span start time. When None, uses current time.
            end_time: Real span end time. When None, span ends immediately.
            trace_correlation_override: Override trace_id source (for cross-workflow linking).
                When set, trace_id is derived from this instead of ``correlation_id``.
            parent_span_id_source: Override parent span_id source (for cross-workflow linking).
                When set, parent span_id is derived from this value. When None and
                ``correlation_id`` is set, parent is the workflow root span.
        """
        effective_trace_correlation = trace_correlation_override or correlation_id
        # Bind contextvars consumed by CorrelationIdGenerator for this export;
        # cleared in the finally block so later spans are unaffected.
        set_correlation_id(effective_trace_correlation)
        set_span_id_source(span_id_source)
        try:
            parent_context: Context | None = None
            # A span is the "root" of its correlation group when span_id_source == correlation_id
            # (i.e. a workflow root span). All other spans are children.
            if parent_span_id_source:
                # Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
                parent_span_id = compute_deterministic_span_id(parent_span_id_source)
                try:
                    parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
                except (ValueError, AttributeError):
                    logger.warning(
                        "Invalid trace correlation UUID for cross-workflow link: %s, span=%s",
                        effective_trace_correlation,
                        name,
                    )
                    parent_trace_id = 0
                if parent_trace_id:
                    # Remote parent context: linked without requiring the
                    # parent span object to exist in this process.
                    parent_span_context = SpanContext(
                        trace_id=parent_trace_id,
                        span_id=parent_span_id,
                        is_remote=True,
                        trace_flags=TraceFlags(TraceFlags.SAMPLED),
                    )
                    parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
            elif correlation_id and correlation_id != span_id_source:
                # Child span: parent is the correlation-group root (workflow root span)
                parent_span_id = compute_deterministic_span_id(correlation_id)
                try:
                    parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
                except (ValueError, AttributeError):
                    logger.warning(
                        "Invalid trace correlation UUID for child span link: %s, span=%s",
                        effective_trace_correlation or correlation_id,
                        name,
                    )
                    parent_trace_id = 0
                if parent_trace_id:
                    parent_span_context = SpanContext(
                        trace_id=parent_trace_id,
                        span_id=parent_span_id,
                        is_remote=True,
                        trace_flags=TraceFlags(TraceFlags.SAMPLED),
                    )
                    parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
            span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
            # When a real end_time is supplied we end the span manually below.
            span_end_on_exit = end_time is None
            with self._tracer.start_as_current_span(
                name,
                context=parent_context,
                start_time=span_start_time,
                end_on_exit=span_end_on_exit,
            ) as span:
                # None attribute values are dropped (OTEL rejects them).
                for key, value in attributes.items():
                    if value is not None:
                        span.set_attribute(key, value)
                if end_time is not None:
                    span.end(end_time=_datetime_to_ns(end_time))
        except Exception:
            # Telemetry must never break the caller; log and move on.
            logger.exception("Failed to export span %s", name)
        finally:
            set_correlation_id(None)
            set_span_id_source(None)

    def increment_counter(
        self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
    ) -> None:
        """Add ``value`` to the named counter; unknown names are silently ignored."""
        counter = self._counters.get(name)
        if counter:
            counter.add(value, cast(Attributes, labels))

    def record_histogram(
        self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
    ) -> None:
        """Record ``value`` into the named histogram; unknown names are silently ignored."""
        histogram = self._histograms.get(name)
        if histogram:
            histogram.record(value, cast(Attributes, labels))

    def shutdown(self) -> None:
        """Flush and shut down both the tracer and meter providers."""
        self._tracer_provider.shutdown()
        self._meter_provider.shutdown()

View File

@@ -0,0 +1,75 @@
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
When a span_id_source is set, the span_id is derived deterministically
from that value, enabling any span to reference another as parent
without depending on span creation order.
"""
import random
import uuid
from contextvars import ContextVar
from opentelemetry.sdk.trace.id_generator import IdGenerator
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
def set_correlation_id(correlation_id: str | None) -> None:
    """Bind (or clear, with None) the correlation id used for trace_id derivation
    in the current context."""
    _correlation_id_context.set(correlation_id)
def get_correlation_id() -> str | None:
    """Return the correlation id bound to the current context, if any."""
    return _correlation_id_context.get()
def set_span_id_source(source_id: str | None) -> None:
    """Set the source for deterministic span_id generation.

    When set, ``generate_span_id()`` derives the span_id from this value
    (lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
    root spans or ``node_execution_id`` for node spans. Pass None to clear
    and restore random span_id generation.
    """
    _span_id_source_context.set(source_id)
def compute_deterministic_span_id(source_id: str) -> int:
    """Derive a stable, non-zero span_id from a UUID string.

    The span_id is the low 64 bits of the UUID's integer value; a result
    of zero is mapped to 1 because OTEL forbids span_id == 0.
    """
    low_bits = uuid.UUID(source_id).int % (1 << 64)
    return low_bits or 1
class CorrelationIdGenerator(IdGenerator):
    """Derives OTEL ids from contextvars where possible, else falls back to random.

    - trace_id comes from the bound correlation_id, so every span in one
      correlation group shares a trace.
    - span_id comes from the bound span_id_source when present, enabling
      deterministic parent-child linking regardless of span creation order;
      otherwise it is random.
    """

    def generate_trace_id(self) -> int:
        bound = _correlation_id_context.get()
        if bound:
            try:
                return uuid.UUID(bound).int
            except (ValueError, AttributeError):
                pass  # malformed correlation id — fall back to random
        return random.getrandbits(128)

    def generate_span_id(self) -> int:
        source = _span_id_source_context.get()
        if source:
            try:
                return compute_deterministic_span_id(source)
            except (ValueError, AttributeError):
                pass  # malformed source id — fall back to random
        # OTEL forbids span_id == 0, so redraw until non-zero.
        while True:
            candidate = random.getrandbits(64)
            if candidate:
                return candidate

View File

@@ -0,0 +1,421 @@
"""Enterprise metric/log event handler.
This module processes metric and log telemetry events after they've been
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
idempotency checking, and payload rehydration.
"""
from __future__ import annotations
import json
import logging
from datetime import UTC, datetime
from typing import Any
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
logger = logging.getLogger(__name__)
class EnterpriseMetricHandler:
"""Handler for enterprise metric and log telemetry events.
Processes envelopes from the enterprise_telemetry queue, routing each
case to the appropriate handler method. Implements idempotency checking
and payload rehydration with fallback.
"""
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
    """Increment a diagnostic counter for operational monitoring.

    Args:
        counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
        labels: Optional labels for the counter.
    """
    # NOTE(review): despite the name, this currently only emits a debug log —
    # the exporter is fetched and None-checked but no counter is incremented.
    # Presumably a placeholder until handler-level instruments exist; confirm.
    try:
        from extensions.ext_enterprise_telemetry import get_enterprise_exporter

        exporter = get_enterprise_exporter()
        # Bail out quietly when enterprise telemetry is not initialized.
        if not exporter:
            return
        full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
        logger.debug(
            "Diagnostic counter: %s, labels=%s",
            full_counter_name,
            labels or {},
        )
    except Exception:
        # Diagnostics are best-effort and must never break event handling.
        logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
def handle(self, envelope: TelemetryEnvelope) -> None:
"""Main entry point for processing telemetry envelopes.
Args:
envelope: The telemetry envelope to process.
"""
# Check for duplicate events
if self._is_duplicate(envelope):
logger.debug(
"Skipping duplicate event: tenant_id=%s, event_id=%s",
envelope.tenant_id,
envelope.event_id,
)
self._increment_diagnostic_counter("deduped_total")
return
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.
Uses Redis with TTL for deduplication. Returns True if duplicate,
False if first time seeing this event.
Args:
envelope: The telemetry envelope to check.
Returns:
True if this event_id has been seen before, False otherwise.
"""
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
try:
# Atomic set-if-not-exists with 1h TTL
# Returns True if key was set (first time), None if already exists (duplicate)
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
return was_set is None
except Exception:
# Fail open: if Redis is unavailable, process the event
# (prefer occasional duplicate over lost data)
logger.warning(
"Redis unavailable for deduplication check, processing event anyway: %s",
envelope.event_id,
exc_info=True,
)
return False
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
"""Rehydrate payload from storage reference or inline data.
If the envelope payload is empty and metadata contains a
``payload_ref``, the full payload is loaded from object storage
(where the gateway wrote it as JSON). When both the inline
payload and storage resolution fail, a degraded-event marker
is emitted so the gap is observable.
Args:
envelope: The telemetry envelope containing payload data.
Returns:
The rehydrated payload dictionary, or ``{}`` on total failure.
"""
payload = envelope.payload
# Resolve from object storage when the gateway offloaded a large payload.
if not payload and envelope.metadata:
payload_ref = envelope.metadata.get("payload_ref")
if payload_ref:
try:
payload_bytes = storage.load(payload_ref)
payload = json.loads(payload_bytes.decode("utf-8"))
logger.debug("Loaded payload from storage: key=%s", payload_ref)
except Exception:
logger.warning(
"Failed to load payload from storage: key=%s, event_id=%s",
payload_ref,
envelope.event_id,
exc_info=True,
)
if not payload:
# Storage resolution failed or no data available — emit degraded event.
logger.error(
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
envelope.event_id,
envelope.tenant_id,
envelope.case,
)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
attributes={
"tenant_id": envelope.tenant_id,
"dify.telemetry.error": f"Payload rehydration failed for event_id={envelope.event_id}",
"dify.telemetry.payload_type": envelope.case,
"dify.telemetry.correlation_id": envelope.event_id,
},
tenant_id=envelope.tenant_id,
)
self._increment_diagnostic_counter("rehydration_failed_total")
return {}
return payload
# Stub methods for each metric/log case
# These will be implemented in later tasks with actual emission logic
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle app created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.mode": payload.get("mode"),
"dify.app.created_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_CREATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"mode": str(payload.get("mode", "")),
},
)
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
"""Handle app updated event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.updated_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_UPDATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
"""Handle app deleted event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app_id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.deleted_at": datetime.now(UTC).isoformat(),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_DELETED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_DELETED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle feedback created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
include_content = exporter.include_content
attrs: dict = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app_id": payload.get("app_id"),
"dify.conversation.id": payload.get("conversation_id"),
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
"dify.feedback.rating": payload.get("rating"),
"dify.feedback.from_source": payload.get("from_source"),
"dify.feedback.created_at": datetime.now(UTC).isoformat(),
}
if include_content:
attrs["dify.feedback.content"] = payload.get("content")
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
user_id=str(user_id or ""),
)
exporter.increment_counter(
EnterpriseTelemetryCounter.FEEDBACK,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"rating": str(payload.get("rating", "")),
},
)
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
"""Handle message run event.
Intentionally a no-op: metrics and structured logs for message runs are
emitted directly by EnterpriseOtelTrace._message_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
"""Handle tool execution event.
Intentionally a no-op: metrics and structured logs for tool executions
are emitted directly by EnterpriseOtelTrace._tool_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
"""Handle moderation check event.
Intentionally a no-op: metrics and structured logs for moderation checks
are emitted directly by EnterpriseOtelTrace._moderation_trace at trace time,
not through the metric handler queue path.
"""
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
"""Handle suggested question event.
Intentionally a no-op: metrics and structured logs for suggested questions
are emitted directly by EnterpriseOtelTrace._suggested_question_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
"""Handle dataset retrieval event.
Intentionally a no-op: metrics and structured logs for dataset retrievals
are emitted directly by EnterpriseOtelTrace._dataset_retrieval_trace at
trace time, not through the metric handler queue path.
"""
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
"""Handle generate name event.
Intentionally a no-op: metrics and structured logs for generate name
operations are emitted directly by EnterpriseOtelTrace._generate_name_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
"""Handle prompt generation event.
Intentionally a no-op: metrics and structured logs for prompt generation
operations are emitted directly by EnterpriseOtelTrace._prompt_generation_trace
at trace time, not through the metric handler queue path.
"""
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)

View File

@@ -0,0 +1,122 @@
"""Structured-log emitter for enterprise telemetry events.
Emits structured JSON log lines correlated with OTEL traces via trace_id.
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
"""
from __future__ import annotations
import logging
import uuid
from functools import lru_cache
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
logger = logging.getLogger("dify.telemetry")
@lru_cache(maxsize=4096)
def compute_trace_id_hex(uuid_str: str | None) -> str:
    """Convert a business UUID string to a 32-hex OTEL-compatible trace_id.

    Already-canonical 32-hex strings pass through untouched; anything else
    is parsed as a UUID and rendered as zero-padded hex. Returns the empty
    string when *uuid_str* is ``None``, empty, or unparseable.
    """
    if not uuid_str:
        return ""
    candidate = uuid_str.strip().lower()
    if len(candidate) == 32 and set(candidate) <= set("0123456789abcdef"):
        return candidate
    try:
        parsed = uuid.UUID(candidate)
    except (ValueError, AttributeError):
        return ""
    return format(parsed.int, "032x")
@lru_cache(maxsize=4096)
def compute_span_id_hex(uuid_str: str | None) -> str:
    """Derive a 16-hex OTEL span_id from *uuid_str*.

    Canonical 16-hex strings pass through untouched; anything else is fed to
    the deterministic span-id derivation. Returns the empty string when the
    input is missing or derivation fails.
    """
    if not uuid_str:
        return ""
    candidate = uuid_str.strip().lower()
    if len(candidate) == 16 and set(candidate) <= set("0123456789abcdef"):
        return candidate
    try:
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        return format(compute_deterministic_span_id(candidate), "016x")
    except (ValueError, AttributeError):
        return ""
def emit_telemetry_log(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    signal: str = "metric_only",
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Emit a structured log line for a telemetry event.

    Parameters
    ----------
    event_name:
        Canonical event name, e.g. ``"dify.workflow.run"``.
    attributes:
        All event-specific attributes (already built by the caller); they
        take precedence over the injected ``dify.event.*`` keys.
    signal:
        ``"metric_only"`` for events with no span, ``"span_detail"``
        for detail logs accompanying a slim span.
    trace_id_source:
        A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
        trace_id for cross-signal correlation.
    span_id_source:
        A UUID string used to derive a 16-hex span_id for correlation.
    tenant_id:
        Tenant identifier (for the ``IdentityContextFilter``).
    user_id:
        User identifier (for the ``IdentityContextFilter``).
    """
    # Skip all attribute building when INFO logging is disabled.
    if not logger.isEnabledFor(logging.INFO):
        return

    event_attrs = {"dify.event.name": event_name, "dify.event.signal": signal}
    event_attrs.update(attributes)
    record_extra: dict[str, Any] = {"attributes": event_attrs}

    # Attach correlation/identity fields only when they resolve to a value.
    for field, value in (
        ("trace_id", compute_trace_id_hex(trace_id_source)),
        ("span_id", compute_span_id_hex(span_id_source)),
        ("tenant_id", tenant_id),
        ("user_id", user_id),
    ):
        if value:
            record_extra[field] = value

    logger.info("telemetry.%s", signal, extra=record_extra)
def emit_metric_only_event(
    *,
    event_name: str | EnterpriseTelemetryEvent,
    attributes: dict[str, Any],
    trace_id_source: str | None = None,
    span_id_source: str | None = None,
    tenant_id: str | None = None,
    user_id: str | None = None,
) -> None:
    """Emit a telemetry log for an event that has no accompanying span.

    Thin convenience wrapper around ``emit_telemetry_log`` that pins
    ``signal`` to ``"metric_only"``; all other parameters are forwarded
    unchanged and carry the same semantics.
    """
    emit_telemetry_log(
        event_name=event_name,
        attributes=attributes,
        signal="metric_only",
        trace_id_source=trace_id_source,
        span_id_source=span_id_source,
        tenant_id=tenant_id,
        user_id=user_id,
    )

View File

@@ -11,3 +11,9 @@ app_published_workflow_was_updated = signal("app-published-workflow-was-updated"
# sender: app, kwargs: synced_draft_workflow
app_draft_workflow_was_synced = signal("app-draft-workflow-was-synced")
# sender: app
app_was_updated = signal("app-was-updated")
# sender: app
app_was_deleted = signal("app-was-deleted")

View File

@@ -204,6 +204,8 @@ def init_app(app: DifyApp) -> Celery:
"schedule": timedelta(minutes=dify_config.API_TOKEN_LAST_USED_UPDATE_INTERVAL),
}
if dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED:
imports.append("tasks.enterprise_telemetry_task")
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@@ -0,0 +1,50 @@
"""Flask extension for enterprise telemetry lifecycle management.
Initializes the EnterpriseExporter singleton during ``create_app()``
(single-threaded), registers blinker event handlers, and hooks atexit
for graceful shutdown.
Skipped entirely when either ``ENTERPRISE_ENABLED`` or ``ENTERPRISE_TELEMETRY_ENABLED``
is false (``is_enabled()`` gate).
"""
from __future__ import annotations
import atexit
import logging
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from dify_app import DifyApp
from enterprise.telemetry.exporter import EnterpriseExporter
logger = logging.getLogger(__name__)
_exporter: EnterpriseExporter | None = None
def is_enabled() -> bool:
    """Return True when both enterprise mode and enterprise telemetry are on."""
    enterprise_on = dify_config.ENTERPRISE_ENABLED
    telemetry_on = dify_config.ENTERPRISE_TELEMETRY_ENABLED
    return bool(enterprise_on and telemetry_on)
def init_app(app: DifyApp) -> None:
    """Create the EnterpriseExporter singleton and wire its lifecycle hooks.

    No-op when the enterprise telemetry gate (``is_enabled``) is off.
    Registers the exporter's shutdown with ``atexit`` and imports the event
    handler module so its blinker subscriptions take effect.
    """
    global _exporter
    if not is_enabled():
        return

    from enterprise.telemetry.exporter import EnterpriseExporter

    exporter = EnterpriseExporter(dify_config)
    _exporter = exporter
    atexit.register(exporter.shutdown)

    # Importing the module runs its @signal.connect decorators.
    import enterprise.telemetry.event_handlers  # noqa: F401 # type: ignore[reportUnusedImport]

    logger.info("Enterprise telemetry initialized")
def get_enterprise_exporter() -> EnterpriseExporter | None:
    """Return the process-wide exporter singleton, or None when telemetry
    is disabled or ``init_app`` has not run."""
    return _exporter

View File

@@ -78,16 +78,24 @@ def init_app(app: DifyApp):
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
if protocol == "grpc":
# Auto-detect TLS: https:// uses secure, everything else is insecure
endpoint = dify_config.OTLP_BASE_ENDPOINT
insecure = not endpoint.startswith("https://")
# Header field names must consist of lowercase letters, check RFC7540
grpc_headers = (
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else ()
)
exporter = GRPCSpanExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
# Header field names must consist of lowercase letters, check RFC7540
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
endpoint=endpoint,
headers=grpc_headers,
insecure=insecure,
)
metric_exporter = GRPCMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
endpoint=endpoint,
headers=grpc_headers,
insecure=insecure,
)
else:
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None

View File

@@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
@@ -17,4 +17,5 @@ __all__ = [
"RetrievalNodeOTelParser",
"ToolNodeOTelParser",
"safe_json_dumps",
"should_include_content",
]

View File

@@ -1,5 +1,10 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
Content gating: ``should_include_content()`` controls whether content-bearing
span attributes (inputs, outputs, prompts, completions, documents) are written.
Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
"""
import json
@@ -9,6 +14,7 @@ from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from configs import dify_config
from dify_graph.enums import BuiltinNodeTypes
from dify_graph.file.models import File
from dify_graph.graph_events import GraphNodeEventBase
@@ -17,6 +23,16 @@ from dify_graph.variables import Segment
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
def should_include_content() -> bool:
    """Return True if content should be written to spans.

    CE (ENTERPRISE_ENABLED=False): always True — no behaviour change.
    In EE, the ENTERPRISE_INCLUDE_CONTENT flag decides.
    """
    if dify_config.ENTERPRISE_ENABLED:
        return dify_config.ENTERPRISE_INCLUDE_CONTENT
    return True
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
"""
Safely serialize objects to JSON, handling non-serializable types.
@@ -101,10 +117,11 @@ class DefaultNodeOTelParser:
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if should_include_content():
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if error:
span.record_exception(error)

View File

@@ -21,3 +21,15 @@ class DifySpanAttributes:
INVOKE_FROM = "dify.invoke_from"
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
INVOKED_BY = "dify.invoked_by"
"""Invoked by, e.g. end_user, account, user."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens (prompt tokens) used."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens (completion tokens) generated."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens used."""

View File

@@ -109,6 +109,15 @@ core/trigger/debug/event_selectors.py
core/trigger/entities/entities.py
core/trigger/provider.py
core/workflow/workflow_entry.py
enterprise/telemetry/contracts.py
enterprise/telemetry/draft_trace.py
enterprise/telemetry/enterprise_trace.py
enterprise/telemetry/entities/__init__.py
enterprise/telemetry/event_handlers.py
enterprise/telemetry/exporter.py
enterprise/telemetry/id_generator.py
enterprise/telemetry/metric_handler.py
enterprise/telemetry/telemetry_log.py
dify_graph/entities/workflow_execution.py
dify_graph/file/file_manager.py
dify_graph/graph_engine/error_handler.py

View File

@@ -14,7 +14,7 @@ from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from dify_graph.model_runtime.entities.model_entities import ModelPropertyKey, ModelType
from dify_graph.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from events.app_event import app_was_created
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
@@ -272,6 +272,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_name(self, app: App, name: str) -> App:
@@ -287,6 +289,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_icon(self, app: App, icon: str, icon_background: str) -> App:
@@ -304,6 +308,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_site_status(self, app: App, enable_site: bool) -> App:
@@ -321,6 +327,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def update_app_api_status(self, app: App, enable_api: bool) -> App:
@@ -339,6 +347,8 @@ class AppService:
app.updated_at = naive_utc_now()
db.session.commit()
app_was_updated.send(app)
return app
def delete_app(self, app: App):
@@ -346,6 +356,8 @@ class AppService:
Delete app
:param app: App instance
"""
app_was_deleted.send(app)
db.session.delete(app)
db.session.commit()

View File

@@ -53,6 +53,7 @@ from dify_graph.repositories.workflow_node_execution_repository import OrderConf
from dify_graph.runtime import VariablePool
from dify_graph.system_variable import SystemVariable
from dify_graph.variables.variables import VariableBase
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
@@ -571,6 +572,13 @@ class RagPipelineService:
outputs=workflow_node_execution.outputs,
)
session.commit()
if workflow_node_execution_db_model is not None:
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,
workflow_execution_id=None,
user_id=account.id,
)
return workflow_node_execution_db_model
def run_datasource_workflow_node(
@@ -1334,6 +1342,12 @@ class RagPipelineService:
outputs=workflow_node_execution.outputs,
)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution_db_model,
outputs=workflow_node_execution.outputs,
workflow_execution_id=None,
user_id=current_user.id,
)
return workflow_node_execution_db_model
def get_recommended_plugins(self, type: str) -> dict:

View File

@@ -49,6 +49,7 @@ from dify_graph.variable_loader import load_into_variable_pool
from dify_graph.variables import VariableBase
from dify_graph.variables.input_entities import VariableEntityType
from dify_graph.variables.variables import Variable
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@@ -841,6 +842,13 @@ class WorkflowService:
draft_var_saver.save(process_data=node_execution.process_data, outputs=outputs)
session.commit()
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
outputs=outputs,
workflow_execution_id=None,
user_id=account.id,
)
return workflow_node_execution
def get_human_input_form_preview(

View File

@@ -0,0 +1,52 @@
"""Celery worker for enterprise metric/log telemetry events.
This module defines the Celery task that processes telemetry envelopes
from the enterprise_telemetry queue. It deserializes envelopes and
dispatches them to the EnterpriseMetricHandler.
"""
import json
import logging
from celery import shared_task
from enterprise.telemetry.contracts import TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
logger = logging.getLogger(__name__)
@shared_task(queue="enterprise_telemetry")
def process_enterprise_telemetry(envelope_json: str) -> None:
    """Process enterprise metric/log telemetry envelope.

    Enqueued by the TelemetryGateway for metric/log-only events: decodes the
    JSON envelope and hands it to the EnterpriseMetricHandler.

    Best-effort processing: any failure is logged and the event dropped,
    so user requests never fail due to telemetry issues.

    Args:
        envelope_json: JSON-serialized TelemetryEnvelope.
    """
    try:
        envelope = TelemetryEnvelope.model_validate(json.loads(envelope_json))
        EnterpriseMetricHandler().handle(envelope)
        logger.debug(
            "Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
            envelope.tenant_id,
            envelope.event_id,
            envelope.case,
        )
    except Exception:
        # Best-effort: log and drop on error, never fail user request
        logger.warning(
            "Failed to process enterprise telemetry envelope, dropping event",
            exc_info=True,
        )

View File

@@ -39,17 +39,36 @@ def process_trace_tasks(file_info):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
logger.exception("Processing trace tasks failed, app_id: %s", app_id)
failed_key = f"{OPS_TRACE_FAILED_KEY}_{app_id}"
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)

View File

@@ -0,0 +1,554 @@
"""Unit tests for lookup helper functions in core.ops.ops_trace_manager.
Covers:
- _lookup_app_and_workspace_names
- _lookup_credential_name
- _lookup_llm_credential_info
- TraceTask._get_user_id_from_metadata
"""
from unittest.mock import MagicMock, patch
import pytest
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_db_and_session_patches(scalar_side_effect=None, scalar_return_value=None):
"""Return (mock_db, cm, session) ready to patch 'core.ops.ops_trace_manager.db'
and 'core.ops.ops_trace_manager.Session'.
Provide either scalar_side_effect (list, for multiple calls) or
scalar_return_value (single value).
"""
mock_db = MagicMock()
mock_db.engine = MagicMock()
session = MagicMock()
if scalar_side_effect is not None:
session.scalar.side_effect = scalar_side_effect
else:
session.scalar.return_value = scalar_return_value
cm = MagicMock()
cm.__enter__ = MagicMock(return_value=session)
cm.__exit__ = MagicMock(return_value=False)
return mock_db, cm, session
# ---------------------------------------------------------------------------
# _lookup_app_and_workspace_names
# ---------------------------------------------------------------------------
class TestLookupAppAndWorkspaceNames:
    """Tests for _lookup_app_and_workspace_names(app_id, tenant_id)."""

    def test_both_found(self):
        """Both records exist -> their names come back as a pair."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, _ = _make_db_and_session_patches(scalar_side_effect=["MyApp", "MyWorkspace"])
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names("app-123", "tenant-456")
        assert result == ("MyApp", "MyWorkspace")

    def test_app_only_found(self):
        """Missing tenant record -> workspace name degrades to ''."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, _ = _make_db_and_session_patches(scalar_side_effect=["MyApp", None])
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names("app-123", "tenant-456")
        assert result == ("MyApp", "")

    def test_tenant_only_found(self):
        """Missing app record -> app name degrades to ''."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, _ = _make_db_and_session_patches(scalar_side_effect=[None, "MyWorkspace"])
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names("app-123", "tenant-456")
        assert result == ("", "MyWorkspace")

    def test_neither_found(self):
        """Both lookups return None -> ('', '')."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, _ = _make_db_and_session_patches(scalar_side_effect=[None, None])
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names("app-123", "tenant-456")
        assert result == ("", "")

    def test_none_inputs_skips_db(self):
        """Both IDs None -> short-circuits to ('', '') without opening a session."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock = MagicMock()
        session_factory = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", session_factory),
        ):
            result = _lookup_app_and_workspace_names(None, None)
        session_factory.assert_not_called()
        assert result == ("", "")

    def test_app_id_none_only_queries_tenant(self):
        """app_id None -> exactly one scalar query, resolving only the workspace."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, sess = _make_db_and_session_patches(scalar_return_value="OnlyWorkspace")
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names(None, "tenant-456")
        assert result == ("", "OnlyWorkspace")
        assert sess.scalar.call_count == 1

    def test_tenant_id_none_only_queries_app(self):
        """tenant_id None -> exactly one scalar query, resolving only the app."""
        from core.ops.ops_trace_manager import _lookup_app_and_workspace_names

        db_mock, ctx, sess = _make_db_and_session_patches(scalar_return_value="OnlyApp")
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            result = _lookup_app_and_workspace_names("app-123", None)
        assert result == ("OnlyApp", "")
        assert sess.scalar.call_count == 1
# ---------------------------------------------------------------------------
# _lookup_credential_name
# ---------------------------------------------------------------------------
class TestLookupCredentialName:
    """Tests for _lookup_credential_name(credential_id, provider_type)."""

    @pytest.mark.parametrize("provider_type", ["builtin", "plugin", "api", "workflow", "mcp"])
    def test_known_provider_types_return_name(self, provider_type):
        """Every recognised provider_type issues one DB query and yields the name."""
        from core.ops.ops_trace_manager import _lookup_credential_name

        db_mock, ctx, sess = _make_db_and_session_patches(scalar_return_value="CredentialA")
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            assert _lookup_credential_name("cred-123", provider_type) == "CredentialA"
        sess.scalar.assert_called_once()

    def test_credential_not_found_returns_empty_string(self):
        """DB yields None for the credential_id -> ''."""
        from core.ops.ops_trace_manager import _lookup_credential_name

        db_mock, ctx, _ = _make_db_and_session_patches(scalar_return_value=None)
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", return_value=ctx),
        ):
            assert _lookup_credential_name("cred-999", "api") == ""

    def test_invalid_provider_type_returns_empty_string_without_db(self):
        """Unrecognised provider_type -> '' with no session ever opened."""
        from core.ops.ops_trace_manager import _lookup_credential_name

        db_mock = MagicMock()
        session_factory = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", session_factory),
        ):
            outcome = _lookup_credential_name("cred-123", "unknown_type")
        session_factory.assert_not_called()
        assert outcome == ""

    def test_none_credential_id_returns_empty_string_without_db(self):
        """credential_id None -> '' with no session ever opened."""
        from core.ops.ops_trace_manager import _lookup_credential_name

        db_mock = MagicMock()
        session_factory = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", session_factory),
        ):
            outcome = _lookup_credential_name(None, "api")
        session_factory.assert_not_called()
        assert outcome == ""

    def test_none_provider_type_returns_empty_string_without_db(self):
        """provider_type None -> '' with no session ever opened."""
        from core.ops.ops_trace_manager import _lookup_credential_name

        db_mock = MagicMock()
        session_factory = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", db_mock),
            patch("core.ops.ops_trace_manager.Session", session_factory),
        ):
            outcome = _lookup_credential_name("cred-123", None)
        session_factory.assert_not_called()
        assert outcome == ""

    def test_builtin_and_plugin_map_to_same_model(self):
        """'builtin' and 'plugin' both resolve to BuiltinToolProvider."""
        from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
        from models.tools import BuiltinToolProvider

        assert _PROVIDER_TYPE_TO_MODEL["builtin"] is _PROVIDER_TYPE_TO_MODEL["plugin"] is BuiltinToolProvider

    def test_api_maps_to_api_tool_provider(self):
        """'api' resolves to ApiToolProvider."""
        from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
        from models.tools import ApiToolProvider

        assert _PROVIDER_TYPE_TO_MODEL["api"] is ApiToolProvider

    def test_workflow_maps_to_workflow_tool_provider(self):
        """'workflow' resolves to WorkflowToolProvider."""
        from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
        from models.tools import WorkflowToolProvider

        assert _PROVIDER_TYPE_TO_MODEL["workflow"] is WorkflowToolProvider

    def test_mcp_maps_to_mcp_tool_provider(self):
        """'mcp' resolves to MCPToolProvider."""
        from core.ops.ops_trace_manager import _PROVIDER_TYPE_TO_MODEL
        from models.tools import MCPToolProvider

        assert _PROVIDER_TYPE_TO_MODEL["mcp"] is MCPToolProvider
# ---------------------------------------------------------------------------
# _lookup_llm_credential_info
# ---------------------------------------------------------------------------
class TestLookupLlmCredentialInfo:
    """Tests for _lookup_llm_credential_info(tenant_id, provider, model, model_type).

    NOTE(review): each test encodes the exact order of ``session.scalar`` calls via
    ``scalar_side_effect`` — the sequences must stay in sync with the lookup's
    internal query order (Provider, then ProviderModel, then the credential-name
    sub-query).
    """

    def _provider_record(self, credential_id: str | None = None) -> MagicMock:
        # Minimal stand-in for a Provider row; only .credential_id is read.
        record = MagicMock()
        record.credential_id = credential_id
        return record

    def _model_record(self, credential_id: str | None = None) -> MagicMock:
        # Minimal stand-in for a ProviderModel row; only .credential_id is read.
        record = MagicMock()
        record.credential_id = credential_id
        return record

    def test_model_level_credential_found(self):
        """Returns model-level credential_id and name when ProviderModel has a credential."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id=None)
        model_record = self._model_record(credential_id="model-cred-id")
        # scalar calls: (1) Provider, (2) ProviderModel, (3) ProviderModelCredential.credential_name
        mock_db, cm, _session = _make_db_and_session_patches(
            scalar_side_effect=[provider_record, model_record, "ModelCredName"]
        )
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id == "model-cred-id"
        assert cred_name == "ModelCredName"

    def test_provider_level_fallback_when_no_model_credential(self):
        """Falls back to provider-level credential when ProviderModel has no credential_id."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id="prov-cred-id")
        model_record = self._model_record(credential_id=None)
        # scalar calls: (1) Provider, (2) ProviderModel (no cred), (3) ProviderCredential.credential_name
        mock_db, cm, _session = _make_db_and_session_patches(
            scalar_side_effect=[provider_record, model_record, "ProvCredName"]
        )
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id == "prov-cred-id"
        assert cred_name == "ProvCredName"

    def test_provider_level_fallback_when_no_model_record(self):
        """Falls back to provider-level credential when no ProviderModel row exists."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id="prov-cred-id")
        # scalar calls: (1) Provider, (2) ProviderModel → None, (3) ProviderCredential.credential_name
        mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, None, "ProvCredName"])
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id == "prov-cred-id"
        assert cred_name == "ProvCredName"

    def test_no_model_arg_uses_provider_level_only(self):
        """When model is None, skips ProviderModel query and uses provider credential."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id="prov-cred-id")
        # scalar calls: (1) Provider, (2) ProviderCredential.credential_name — no ProviderModel
        mock_db, cm, session = _make_db_and_session_patches(scalar_side_effect=[provider_record, "ProvCredName"])
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", None)
        assert cred_id == "prov-cred-id"
        assert cred_name == "ProvCredName"
        # Exactly two queries: the ProviderModel lookup must have been skipped.
        assert session.scalar.call_count == 2

    def test_provider_not_found_returns_none_and_empty(self):
        """Returns (None, '') when Provider record does not exist."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        mock_db, cm, _session = _make_db_and_session_patches(scalar_return_value=None)
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id is None
        assert cred_name == ""

    def test_none_tenant_id_returns_none_and_empty_without_db(self):
        """Returns (None, '') immediately when tenant_id is None — no DB access."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        mock_db = MagicMock()
        mock_session_cls = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", mock_session_cls),
        ):
            cred_id, cred_name = _lookup_llm_credential_info(None, "openai", "gpt-4")
        mock_session_cls.assert_not_called()
        assert cred_id is None
        assert cred_name == ""

    def test_none_provider_returns_none_and_empty_without_db(self):
        """Returns (None, '') immediately when provider is None — no DB access."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        mock_db = MagicMock()
        mock_session_cls = MagicMock()
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", mock_session_cls),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", None, "gpt-4")
        mock_session_cls.assert_not_called()
        assert cred_id is None
        assert cred_name == ""

    def test_db_error_on_outer_query_returns_none_and_empty(self):
        """Returns (None, '') and logs a warning when the outer DB query raises."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        mock_db, cm, session = _make_db_and_session_patches()
        # Every scalar call raises, so the very first (Provider) query fails.
        session.scalar.side_effect = Exception("DB connection failed")
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id is None
        assert cred_name == ""

    def test_credential_name_lookup_failure_returns_id_with_empty_name(self):
        """When credential name sub-query fails, returns cred_id but '' for name."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id="prov-cred-id")
        # Provider found, no model record, then name lookup raises
        # (an Exception instance in side_effect makes MagicMock raise it).
        mock_db, cm, _session = _make_db_and_session_patches(
            scalar_side_effect=[provider_record, None, Exception("deleted")]
        )
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id == "prov-cred-id"
        assert cred_name == ""

    def test_no_credential_on_provider_or_model_returns_none_id(self):
        """Returns (None, '') when neither provider nor model has a credential_id."""
        from core.ops.ops_trace_manager import _lookup_llm_credential_info

        provider_record = self._provider_record(credential_id=None)
        model_record = self._model_record(credential_id=None)
        mock_db, cm, _session = _make_db_and_session_patches(scalar_side_effect=[provider_record, model_record])
        with (
            patch("core.ops.ops_trace_manager.db", mock_db),
            patch("core.ops.ops_trace_manager.Session", return_value=cm),
        ):
            cred_id, cred_name = _lookup_llm_credential_info("tenant-1", "openai", "gpt-4")
        assert cred_id is None
        assert cred_name == ""
# ---------------------------------------------------------------------------
# TraceTask._get_user_id_from_metadata
# ---------------------------------------------------------------------------
class TestGetUserIdFromMetadata:
    """Tests for TraceTask._get_user_id_from_metadata(metadata).

    Pure dict logic — no DB access required.
    """

    @pytest.fixture
    def resolve_user(self):
        """Expose the classmethod under test as a plain callable."""
        from core.ops.ops_trace_manager import TraceTask

        return TraceTask._get_user_id_from_metadata

    def test_from_end_user_id_has_highest_priority(self, resolve_user):
        """from_end_user_id wins over every other key."""
        meta = {"from_end_user_id": "eu-abc", "from_account_id": "acc-xyz", "user_id": "u-123"}
        assert resolve_user(meta) == "end_user:eu-abc"

    def test_from_account_id_used_when_no_end_user(self, resolve_user):
        """from_account_id is next in line when from_end_user_id is absent."""
        assert resolve_user({"from_account_id": "acc-xyz", "user_id": "u-123"}) == "account:acc-xyz"

    def test_user_id_used_when_no_end_user_or_account(self, resolve_user):
        """user_id is the last keyed fallback."""
        assert resolve_user({"user_id": "u-123"}) == "user:u-123"

    def test_returns_anonymous_when_all_keys_absent(self, resolve_user):
        """No recognised keys at all -> 'anonymous'."""
        assert resolve_user({}) == "anonymous"

    def test_empty_string_end_user_id_is_skipped(self, resolve_user):
        """A falsy from_end_user_id falls through to from_account_id."""
        assert resolve_user({"from_end_user_id": "", "from_account_id": "acc-xyz"}) == "account:acc-xyz"

    def test_empty_string_account_id_is_skipped(self, resolve_user):
        """Falsy end-user and account IDs fall through to user_id."""
        meta = {"from_end_user_id": "", "from_account_id": "", "user_id": "u-123"}
        assert resolve_user(meta) == "user:u-123"

    def test_empty_string_user_id_falls_through_to_anonymous(self, resolve_user):
        """All three keys present but falsy -> 'anonymous'."""
        meta = {"from_end_user_id": "", "from_account_id": "", "user_id": ""}
        assert resolve_user(meta) == "anonymous"

    def test_only_from_end_user_id_present(self, resolve_user):
        """Minimal case: only from_end_user_id present."""
        assert resolve_user({"from_end_user_id": "eu-only"}) == "end_user:eu-only"

    def test_irrelevant_keys_do_not_interfere(self, resolve_user):
        """Unrelated metadata keys never influence the result."""
        assert resolve_user({"invoke_from": "web", "app_id": "a1"}) == "anonymous"

View File

@@ -86,6 +86,7 @@ def make_message_data(**overrides):
created_at = datetime(2025, 2, 20, 12, 0, 0)
base = {
"id": "msg-id",
"app_id": "app-id",
"conversation_id": "conv-id",
"created_at": created_at,
"updated_at": created_at + timedelta(seconds=3),
@@ -182,6 +183,9 @@ class DummySessionContext:
def __exit__(self, exc_type, exc_val, exc_tb):
return False
def execute(self, *args, **kwargs):
return self
def scalar(self, *args, **kwargs):
if self._index >= len(self._values):
return None
@@ -189,6 +193,12 @@ class DummySessionContext:
self._index += 1
return value
def scalars(self, *args, **kwargs):
return self
def all(self):
return []
@pytest.fixture(autouse=True)
def patch_provider_map(monkeypatch):
@@ -454,7 +464,7 @@ def test_trace_task_message_trace(trace_task_message, mock_db):
def test_trace_task_workflow_trace(workflow_repo_fixture, mock_db):
DummySessionContext.scalar_values = ["wf-app-log", "message-ref"]
execution = SimpleNamespace(id_="run-id")
execution = SimpleNamespace(id_="run-id", total_tokens=0)
task = TraceTask(
trace_type=TraceTaskName.WORKFLOW_TRACE, workflow_execution=execution, conversation_id="conv", user_id="user"
)

View File

@@ -0,0 +1,194 @@
"""Unit tests for TraceQueueManager telemetry guard.
Verifies that TraceQueueManager.add_trace_task() only enqueues tasks when at
least one consumer is active:
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
- A third-party trace instance (Langfuse, etc.) is configured
When neither consumer is active, tasks are silently dropped — the correct behavior, avoiding unnecessary queue work.
"""
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def trace_queue_manager_and_task(monkeypatch):
    """Fixture to provide TraceQueueManager and TraceTask with delayed imports.

    If the real ``core.ops.ops_trace_manager`` module has not been imported yet,
    installs a lightweight stub in ``sys.modules`` (via monkeypatch, so it is
    removed after the test) that mirrors just the guard behavior under test.
    If the real module is already loaded, it is used as-is.
    """
    module_name = "core.ops.ops_trace_manager"
    if module_name not in sys.modules:
        ops_stub = types.ModuleType(module_name)

        class StubTraceTask:
            # Mirrors only the attributes add_trace_task() touches.
            def __init__(self, trace_type):
                self.trace_type = trace_type
                self.app_id = None

        class StubTraceQueueManager:
            def __init__(self, app_id=None):
                self.app_id = app_id
                # Delayed import so tests can patch is_enterprise_telemetry_enabled
                # before the manager is constructed.
                from core.telemetry.gateway import is_enterprise_telemetry_enabled

                self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
                self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)

            def add_trace_task(self, trace_task):
                # The guard under test: enqueue only if some consumer is active.
                if self._enterprise_telemetry_enabled or self.trace_instance:
                    trace_task.app_id = self.app_id
                    # Re-imported at call time so a patched queue is picked up.
                    from core.ops.ops_trace_manager import trace_manager_queue

                    trace_manager_queue.put(trace_task)

        class StubOpsTraceManager:
            @staticmethod
            def get_ops_trace_instance(app_id):
                # Default: no third-party trace instance configured.
                return None

        ops_stub.TraceQueueManager = StubTraceQueueManager
        ops_stub.TraceTask = StubTraceTask
        ops_stub.OpsTraceManager = StubOpsTraceManager
        ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
        monkeypatch.setitem(sys.modules, module_name, ops_stub)

    from core.ops.entities.trace_entity import TraceTaskName

    # Resolve through sys.modules so either the stub or the real module is used.
    ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
    TraceQueueManager = ops_module.TraceQueueManager
    TraceTask = ops_module.TraceTask
    return TraceQueueManager, TraceTask, TraceTaskName
class TestTraceQueueManagerTelemetryGuard:
    """Exercise the consumer guard inside TraceQueueManager.add_trace_task()."""

    def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
        """Core guard: no enterprise telemetry AND no trace instance -> task is dropped."""
        TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
        queue_mock = MagicMock(spec=queue.Queue)
        task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
        with (
            patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
            patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
            patch("core.ops.ops_trace_manager.trace_manager_queue", queue_mock),
        ):
            TraceQueueManager(app_id="test-app-id").add_trace_task(task)
        queue_mock.put.assert_not_called()

    def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
        """Enterprise telemetry on -> task is enqueued regardless of trace_instance."""
        TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
        queue_mock = MagicMock(spec=queue.Queue)
        task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
        with (
            patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
            patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
            patch("core.ops.ops_trace_manager.trace_manager_queue", queue_mock),
        ):
            TraceQueueManager(app_id="test-app-id").add_trace_task(task)
        queue_mock.put.assert_called_once()
        enqueued = queue_mock.put.call_args[0][0]
        assert enqueued.app_id == "test-app-id"

    def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
        """A configured third-party trace instance alone is enough to enqueue."""
        TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
        queue_mock = MagicMock(spec=queue.Queue)
        fake_instance = MagicMock()
        task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
        with (
            patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
            patch(
                "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=fake_instance
            ),
            patch("core.ops.ops_trace_manager.trace_manager_queue", queue_mock),
        ):
            TraceQueueManager(app_id="test-app-id").add_trace_task(task)
        queue_mock.put.assert_called_once()
        enqueued = queue_mock.put.call_args[0][0]
        assert enqueued.app_id == "test-app-id"

    def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
        """Both consumers active -> task is certainly enqueued."""
        TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
        queue_mock = MagicMock(spec=queue.Queue)
        fake_instance = MagicMock()
        task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
        with (
            patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
            patch(
                "core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=fake_instance
            ),
            patch("core.ops.ops_trace_manager.trace_manager_queue", queue_mock),
        ):
            TraceQueueManager(app_id="test-app-id").add_trace_task(task)
        queue_mock.put.assert_called_once()
        enqueued = queue_mock.put.call_args[0][0]
        assert enqueued.app_id == "test-app-id"

    def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
        """The manager stamps its app_id onto the task before putting it on the queue."""
        TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
        queue_mock = MagicMock(spec=queue.Queue)
        task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
        with (
            patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
            patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
            patch("core.ops.ops_trace_manager.trace_manager_queue", queue_mock),
        ):
            TraceQueueManager(app_id="expected-app-id").add_trace_task(task)
        enqueued = queue_mock.put.call_args[0][0]
        assert enqueued.app_id == "expected-app-id"

View File

@@ -0,0 +1,181 @@
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
from __future__ import annotations
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
@pytest.fixture
def telemetry_test_setup(monkeypatch):
    """Install a stub ``core.ops.ops_trace_manager`` and return (emit, queue mock).

    Unlike the real module, the stub's add_trace_task always enqueues — these
    tests exercise emit()'s routing/filtering, not the queue manager's guard.
    """
    module_name = "core.ops.ops_trace_manager"
    ops_stub = types.ModuleType(module_name)

    class StubTraceTask:
        # Records the constructor arguments so tests can inspect trace_type.
        def __init__(self, trace_type, **kwargs):
            self.trace_type = trace_type
            self.app_id = None
            self.kwargs = kwargs

    class StubTraceQueueManager:
        def __init__(self, app_id=None, user_id=None):
            self.app_id = app_id
            self.user_id = user_id
            self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)

        def add_trace_task(self, trace_task):
            # Unconditionally enqueue; re-import at call time so the mock queue
            # installed on the stub module is the one that receives the task.
            trace_task.app_id = self.app_id
            from core.ops.ops_trace_manager import trace_manager_queue

            trace_manager_queue.put(trace_task)

    class StubOpsTraceManager:
        @staticmethod
        def get_ops_trace_instance(app_id):
            # No third-party trace instance in these tests.
            return None

    ops_stub.TraceQueueManager = StubTraceQueueManager
    ops_stub.TraceTask = StubTraceTask
    ops_stub.OpsTraceManager = StubOpsTraceManager
    ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
    # monkeypatch restores sys.modules after the test.
    monkeypatch.setitem(sys.modules, module_name, ops_stub)

    from core.telemetry import emit

    return emit, ops_stub.trace_manager_queue
class TestTelemetryEmit:
    """Routing and enterprise-only filtering behavior of core.telemetry.emit()."""

    @staticmethod
    def _event(name, payload):
        """Build a TelemetryEvent carrying the canonical test context."""
        return TelemetryEvent(
            name=name,
            context=TelemetryContext(
                tenant_id="test-tenant",
                user_id="test-user",
                app_id="test-app",
            ),
            payload=payload,
        )

    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_emit_enterprise_trace_creates_trace_task(self, mock_ee, telemetry_test_setup):
        send, q = telemetry_test_setup
        send(self._event(TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, {"key": "value"}))
        q.put.assert_called_once()
        assert q.put.call_args[0][0].trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE

    def test_emit_community_trace_enqueued(self, telemetry_test_setup):
        send, q = telemetry_test_setup
        send(self._event(TraceTaskName.WORKFLOW_TRACE, {}))
        q.put.assert_called_once()

    def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
        send, q = telemetry_test_setup
        send(self._event(TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, {}))
        q.put.assert_not_called()

    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, mock_ee, telemetry_test_setup):
        send, q = telemetry_test_setup
        for trace_name in (
            TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
            TraceTaskName.NODE_EXECUTION_TRACE,
            TraceTaskName.PROMPT_GENERATION_TRACE,
        ):
            q.reset_mock()
            send(self._event(trace_name, {}))
            q.put.assert_called_once()
            assert q.put.call_args[0][0].trace_type == trace_name

    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_emit_passes_name_directly_to_trace_task(self, mock_ee, telemetry_test_setup):
        send, q = telemetry_test_setup
        send(self._event(TraceTaskName.DRAFT_NODE_EXECUTION_TRACE, {"extra": "data"}))
        q.put.assert_called_once()
        routed = q.put.call_args[0][0]
        assert routed.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
        assert isinstance(routed.trace_type, TraceTaskName)

    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_emit_with_provided_trace_manager(self, mock_ee, telemetry_test_setup):
        send, _q = telemetry_test_setup
        manager = MagicMock()
        manager.add_trace_task = MagicMock()
        send(self._event(TraceTaskName.NODE_EXECUTION_TRACE, {}), trace_manager=manager)
        manager.add_trace_task.assert_called_once()
        assert manager.add_trace_task.call_args[0][0].trace_type == TraceTaskName.NODE_EXECUTION_TRACE

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled
from enterprise.telemetry.contracts import TelemetryCase
class TestTelemetryCoreExports:
    """Smoke checks on the public surface of core.telemetry.gateway."""

    def test_is_enterprise_telemetry_enabled_exported(self) -> None:
        from core.telemetry.gateway import is_enterprise_telemetry_enabled as reimported

        assert callable(reimported)
@pytest.fixture
def mock_ops_trace_manager():
    """Temporarily replace the ops trace manager and trace entity modules with mocks."""
    ops_module_stub = MagicMock()
    task_cls = MagicMock()
    task_cls.return_value = MagicMock()
    ops_module_stub.TraceTask = task_cls
    ops_module_stub.TraceQueueManager = MagicMock()

    entity_stub = MagicMock()
    task_name_stub = MagicMock()
    task_name_stub.return_value = "workflow"
    entity_stub.TraceTaskName = task_name_stub

    # One patch.dict covering both module entries is equivalent to two nested ones.
    with patch.dict(
        sys.modules,
        {
            "core.ops.ops_trace_manager": ops_module_stub,
            "core.ops.entities.trace_entity": entity_stub,
        },
    ):
        yield ops_module_stub, entity_stub
class TestGatewayIntegrationTraceRouting:
    """emit() routing decisions for trace-style TelemetryCase values."""

    @pytest.fixture
    def mock_trace_manager(self) -> MagicMock:
        return MagicMock()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_ce_eligible_trace_routed_to_trace_manager(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
            emit(
                TelemetryCase.WORKFLOW_RUN,
                {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"},
                {"workflow_run_id": "run-abc"},
                mock_trace_manager,
            )
        mock_trace_manager.add_trace_task.assert_called_once()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_ce_eligible_trace_routed_when_ee_disabled(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
            emit(
                TelemetryCase.WORKFLOW_RUN,
                {"app_id": "app-123", "user_id": "user-456"},
                {"workflow_run_id": "run-abc"},
                mock_trace_manager,
            )
        mock_trace_manager.add_trace_task.assert_called_once()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_enterprise_only_trace_dropped_when_ee_disabled(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
            emit(
                TelemetryCase.NODE_EXECUTION,
                {"app_id": "app-123", "user_id": "user-456"},
                {"node_id": "node-abc"},
                mock_trace_manager,
            )
        mock_trace_manager.add_trace_task.assert_not_called()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_enterprise_only_trace_routed_when_ee_enabled(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
            emit(
                TelemetryCase.NODE_EXECUTION,
                {"app_id": "app-123", "user_id": "user-456"},
                {"node_id": "node-abc"},
                mock_trace_manager,
            )
        mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationMetricRouting:
    """Verify metric/log-case routing and trace routing with EE enabled."""

    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_metric_case_routes_to_celery_task(
        self,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """A metric/log case is serialized into an envelope and queued via Celery."""
        from enterprise.telemetry.contracts import TelemetryEnvelope

        with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
            emit(
                TelemetryCase.APP_CREATED,
                {"tenant_id": "tenant-123"},
                {"app_id": "app-abc", "name": "My App"},
            )
        mock_delay.assert_called_once()
        raw_envelope = mock_delay.call_args[0][0]
        decoded = TelemetryEnvelope.model_validate_json(raw_envelope)
        assert decoded.case == TelemetryCase.APP_CREATED
        assert decoded.tenant_id == "tenant-123"
        assert decoded.payload["app_id"] == "app-abc"

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_tool_execution_trace_routed(
        self,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """A tool-execution trace is handed to the trace manager."""
        manager = MagicMock()
        emit(
            TelemetryCase.TOOL_EXECUTION,
            {"tenant_id": "tenant-123", "app_id": "app-123"},
            {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"},
            manager,
        )
        manager.add_trace_task.assert_called_once()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_moderation_check_trace_routed(
        self,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """A moderation-check trace is handed to the trace manager."""
        manager = MagicMock()
        emit(
            TelemetryCase.MODERATION_CHECK,
            {"tenant_id": "tenant-123", "app_id": "app-123"},
            {"message_id": "msg-123", "moderation_result": {"flagged": False}},
            manager,
        )
        manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationCEEligibility:
    """Check which telemetry cases are CE-eligible when EE telemetry is disabled."""

    # Patch target for the probe that reports whether enterprise telemetry is on.
    _EE_FLAG = "core.telemetry.gateway.is_enterprise_telemetry_enabled"

    @pytest.fixture
    def mock_trace_manager(self) -> MagicMock:
        """A stand-in trace manager whose ``add_trace_task`` calls we inspect."""
        return MagicMock()

    def _emit_with_ee_disabled(self, case, ctx, body, manager) -> None:
        """Emit *case* while the EE-enabled probe reports False."""
        with patch(self._EE_FLAG, return_value=False):
            emit(case, ctx, body, manager)

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_workflow_run_is_ce_eligible(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        """WORKFLOW_RUN is routed even without enterprise telemetry."""
        self._emit_with_ee_disabled(
            TelemetryCase.WORKFLOW_RUN,
            {"app_id": "app-123", "user_id": "user-456"},
            {"workflow_run_id": "run-abc"},
            mock_trace_manager,
        )
        mock_trace_manager.add_trace_task.assert_called_once()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_message_run_is_ce_eligible(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        """MESSAGE_RUN is routed even without enterprise telemetry."""
        self._emit_with_ee_disabled(
            TelemetryCase.MESSAGE_RUN,
            {"app_id": "app-123", "user_id": "user-456"},
            {"message_id": "msg-abc", "conversation_id": "conv-123"},
            mock_trace_manager,
        )
        mock_trace_manager.add_trace_task.assert_called_once()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_node_execution_not_ce_eligible(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        """NODE_EXECUTION is dropped without enterprise telemetry."""
        self._emit_with_ee_disabled(
            TelemetryCase.NODE_EXECUTION,
            {"app_id": "app-123", "user_id": "user-456"},
            {"node_id": "node-abc"},
            mock_trace_manager,
        )
        mock_trace_manager.add_trace_task.assert_not_called()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_draft_node_execution_not_ce_eligible(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        """DRAFT_NODE_EXECUTION is dropped without enterprise telemetry."""
        self._emit_with_ee_disabled(
            TelemetryCase.DRAFT_NODE_EXECUTION,
            {"app_id": "app-123", "user_id": "user-456"},
            {"node_execution_data": {}},
            mock_trace_manager,
        )
        mock_trace_manager.add_trace_task.assert_not_called()

    @pytest.mark.usefixtures("mock_ops_trace_manager")
    def test_prompt_generation_not_ce_eligible(
        self,
        mock_trace_manager: MagicMock,
    ) -> None:
        """PROMPT_GENERATION is dropped without enterprise telemetry."""
        self._emit_with_ee_disabled(
            TelemetryCase.PROMPT_GENERATION,
            {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"},
            {"operation_type": "generate", "instruction": "test"},
            mock_trace_manager,
        )
        mock_trace_manager.add_trace_task.assert_not_called()
class TestIsEnterpriseTelemetryEnabled:
    """Behaviour of the enterprise-telemetry feature probe."""

    def test_returns_false_when_exporter_import_fails(self) -> None:
        """The probe reports disabled when the exporter module cannot be imported."""
        # Mapping the module name to None in sys.modules makes ``import`` raise
        # ImportError, simulating a CE build without the enterprise exporter.
        with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
            assert is_enterprise_telemetry_enabled() is False

    def test_function_is_callable(self) -> None:
        """The probe is exposed as a plain callable."""
        assert callable(is_enterprise_telemetry_enabled)

View File

@@ -0,0 +1,230 @@
"""Unit tests for telemetry gateway contracts."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from core.telemetry.gateway import CASE_ROUTING
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
class TestTelemetryCase:
    """Tests for the TelemetryCase enum."""

    def test_all_cases_defined(self) -> None:
        """All 14 telemetry case names exist — no more, no fewer."""
        expected_names = {
            "WORKFLOW_RUN",
            "NODE_EXECUTION",
            "DRAFT_NODE_EXECUTION",
            "MESSAGE_RUN",
            "TOOL_EXECUTION",
            "MODERATION_CHECK",
            "SUGGESTED_QUESTION",
            "DATASET_RETRIEVAL",
            "GENERATE_NAME",
            "PROMPT_GENERATION",
            "APP_CREATED",
            "APP_UPDATED",
            "APP_DELETED",
            "FEEDBACK_CREATED",
        }
        assert {member.name for member in TelemetryCase} == expected_names

    def test_case_values(self) -> None:
        """Each case maps to its snake_case string value."""
        expected_values = {
            TelemetryCase.WORKFLOW_RUN: "workflow_run",
            TelemetryCase.NODE_EXECUTION: "node_execution",
            TelemetryCase.DRAFT_NODE_EXECUTION: "draft_node_execution",
            TelemetryCase.MESSAGE_RUN: "message_run",
            TelemetryCase.TOOL_EXECUTION: "tool_execution",
            TelemetryCase.MODERATION_CHECK: "moderation_check",
            TelemetryCase.SUGGESTED_QUESTION: "suggested_question",
            TelemetryCase.DATASET_RETRIEVAL: "dataset_retrieval",
            TelemetryCase.GENERATE_NAME: "generate_name",
            TelemetryCase.PROMPT_GENERATION: "prompt_generation",
            TelemetryCase.APP_CREATED: "app_created",
            TelemetryCase.APP_UPDATED: "app_updated",
            TelemetryCase.APP_DELETED: "app_deleted",
            TelemetryCase.FEEDBACK_CREATED: "feedback_created",
        }
        for member, value in expected_values.items():
            assert member.value == value
class TestCaseRoute:
    """Tests for the CaseRoute model."""

    def test_valid_trace_route(self) -> None:
        """A trace route keeps its signal type and CE eligibility."""
        trace_route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
        assert trace_route.signal_type == SignalType.TRACE
        assert trace_route.ce_eligible is True

    def test_valid_metric_log_route(self) -> None:
        """A metric/log route keeps its signal type and CE eligibility."""
        metric_route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
        assert metric_route.signal_type == SignalType.METRIC_LOG
        assert metric_route.ce_eligible is False

    def test_invalid_signal_type(self) -> None:
        """An unknown signal_type string fails validation."""
        with pytest.raises(ValidationError):
            CaseRoute(signal_type="invalid", ce_eligible=True)
class TestTelemetryEnvelope:
    """Tests for the TelemetryEnvelope model."""

    # Baseline keyword arguments for a structurally complete envelope; the
    # missing-field tests below drop one key at a time from this mapping.
    _BASE_KWARGS = {
        "case": TelemetryCase.WORKFLOW_RUN,
        "tenant_id": "tenant-123",
        "event_id": "event-456",
        "payload": {"key": "value"},
    }

    def _without(self, field: str) -> dict:
        """Return the baseline kwargs with *field* removed."""
        return {k: v for k, v in self._BASE_KWARGS.items() if k != field}

    def test_valid_envelope_minimal(self) -> None:
        """All required fields alone produce a valid envelope."""
        envelope = TelemetryEnvelope(**self._BASE_KWARGS)
        assert envelope.case == TelemetryCase.WORKFLOW_RUN
        assert envelope.tenant_id == "tenant-123"
        assert envelope.event_id == "event-456"
        assert envelope.payload == {"key": "value"}
        assert envelope.metadata is None

    def test_valid_envelope_full(self) -> None:
        """Optional metadata is stored verbatim alongside required fields."""
        meta = {"payload_ref": "telemetry/tenant-789/event-012.json"}
        envelope = TelemetryEnvelope(
            case=TelemetryCase.MESSAGE_RUN,
            tenant_id="tenant-789",
            event_id="event-012",
            payload={"message": "hello"},
            metadata=meta,
        )
        assert envelope.case == TelemetryCase.MESSAGE_RUN
        assert envelope.tenant_id == "tenant-789"
        assert envelope.event_id == "event-012"
        assert envelope.payload == {"message": "hello"}
        assert envelope.metadata == meta

    def test_missing_required_case(self) -> None:
        """Omitting ``case`` raises a validation error."""
        with pytest.raises(ValidationError):
            TelemetryEnvelope(**self._without("case"))

    def test_missing_required_tenant_id(self) -> None:
        """Omitting ``tenant_id`` raises a validation error."""
        with pytest.raises(ValidationError):
            TelemetryEnvelope(**self._without("tenant_id"))

    def test_missing_required_event_id(self) -> None:
        """Omitting ``event_id`` raises a validation error."""
        with pytest.raises(ValidationError):
            TelemetryEnvelope(**self._without("event_id"))

    def test_missing_required_payload(self) -> None:
        """Omitting ``payload`` raises a validation error."""
        with pytest.raises(ValidationError):
            TelemetryEnvelope(**self._without("payload"))

    def test_metadata_none(self) -> None:
        """Metadata may be passed explicitly as None."""
        envelope = TelemetryEnvelope(metadata=None, **self._BASE_KWARGS)
        assert envelope.metadata is None
class TestCaseRouting:
    """Tests for the CASE_ROUTING table."""

    def test_all_cases_routed(self) -> None:
        """Every one of the 14 cases has a routing entry."""
        assert len(CASE_ROUTING) == 14
        unrouted = [case for case in TelemetryCase if case not in CASE_ROUTING]
        assert unrouted == []

    def test_trace_ce_eligible_cases(self) -> None:
        """WORKFLOW_RUN and MESSAGE_RUN are CE-eligible trace cases."""
        for case in (TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN):
            entry = CASE_ROUTING[case]
            assert entry.signal_type == SignalType.TRACE
            assert entry.ce_eligible is True

    def test_trace_enterprise_only_cases(self) -> None:
        """Node, draft-node and prompt-generation traces are enterprise-only."""
        enterprise_only = (
            TelemetryCase.NODE_EXECUTION,
            TelemetryCase.DRAFT_NODE_EXECUTION,
            TelemetryCase.PROMPT_GENERATION,
        )
        for case in enterprise_only:
            entry = CASE_ROUTING[case]
            assert entry.signal_type == SignalType.TRACE
            assert entry.ce_eligible is False

    def test_metric_log_cases(self) -> None:
        """App-lifecycle and feedback cases are metric/log-only."""
        metric_only = (
            TelemetryCase.APP_CREATED,
            TelemetryCase.APP_UPDATED,
            TelemetryCase.APP_DELETED,
            TelemetryCase.FEEDBACK_CREATED,
        )
        for case in metric_only:
            entry = CASE_ROUTING[case]
            assert entry.signal_type == SignalType.METRIC_LOG
            assert entry.ce_eligible is False

    def test_routing_table_completeness(self) -> None:
        """The trace/metric partition covers every case with correct types."""
        trace_cases = {
            TelemetryCase.WORKFLOW_RUN,
            TelemetryCase.MESSAGE_RUN,
            TelemetryCase.NODE_EXECUTION,
            TelemetryCase.DRAFT_NODE_EXECUTION,
            TelemetryCase.PROMPT_GENERATION,
            TelemetryCase.TOOL_EXECUTION,
            TelemetryCase.MODERATION_CHECK,
            TelemetryCase.SUGGESTED_QUESTION,
            TelemetryCase.DATASET_RETRIEVAL,
            TelemetryCase.GENERATE_NAME,
        }
        metric_log_cases = {
            TelemetryCase.APP_CREATED,
            TelemetryCase.APP_UPDATED,
            TelemetryCase.APP_DELETED,
            TelemetryCase.FEEDBACK_CREATED,
        }
        union = trace_cases | metric_log_cases
        assert len(union) == 14
        assert union == set(TelemetryCase)
        for case in trace_cases:
            assert CASE_ROUTING[case].signal_type == SignalType.TRACE
        for case in metric_log_cases:
            assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG

View File

@@ -0,0 +1,519 @@
"""Unit tests for enterprise/telemetry/draft_trace.py."""
from __future__ import annotations
from datetime import UTC, datetime
from unittest.mock import MagicMock, patch
from dify_graph.enums import WorkflowNodeExecutionMetadataKey
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_execution(**overrides) -> MagicMock:
    """Return a minimal WorkflowNodeExecutionModel mock.

    Every attribute listed below gets a sensible default; callers may replace
    any of them via keyword overrides. Keys not in the table are ignored,
    leaving MagicMock's automatic attribute behaviour for anything else.
    """
    defaults = {
        "tenant_id": "tenant-1",
        "app_id": "app-1",
        "workflow_id": "wf-1",
        "id": "exec-1",
        "node_id": "node-1",
        "node_type": "llm",
        "title": "My LLM Node",
        "status": "succeeded",
        "error": None,
        "elapsed_time": 1.5,
        "index": 1,
        "predecessor_node_id": None,
        "created_at": datetime(2024, 1, 1, tzinfo=UTC),
        "finished_at": datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC),
        "workflow_run_id": "run-1",
        "inputs_dict": {"prompt": "hello"},
        "outputs_dict": {"answer": "world"},
        "process_data_dict": {},
        "execution_metadata_dict": {},
    }
    execution = MagicMock()
    for attr, default in defaults.items():
        setattr(execution, attr, overrides.get(attr, default))
    return execution
# ---------------------------------------------------------------------------
# _build_node_execution_data
# ---------------------------------------------------------------------------
class TestBuildNodeExecutionData:
    """Field-mapping behaviour of ``_build_node_execution_data``."""

    @staticmethod
    def _build(execution, outputs=None, workflow_execution_id=None):
        """Invoke the builder under test with keyword arguments."""
        from enterprise.telemetry.draft_trace import _build_node_execution_data

        return _build_node_execution_data(
            execution=execution,
            outputs=outputs,
            workflow_execution_id=workflow_execution_id,
        )

    def test_basic_fields_populated(self) -> None:
        """Identity, status and timing fields are copied from the execution."""
        data = self._build(_make_execution(), workflow_execution_id="run-override")
        assert data["workflow_id"] == "wf-1"
        assert data["tenant_id"] == "tenant-1"
        assert data["app_id"] == "app-1"
        assert data["node_execution_id"] == "exec-1"
        assert data["node_id"] == "node-1"
        assert data["node_type"] == "llm"
        assert data["title"] == "My LLM Node"
        assert data["status"] == "succeeded"
        assert data["error"] is None
        assert data["elapsed_time"] == 1.5
        assert data["index"] == 1

    def test_workflow_execution_id_prefers_parameter(self) -> None:
        """An explicit workflow_execution_id wins over the model's run id."""
        data = self._build(
            _make_execution(workflow_run_id="run-from-model"),
            workflow_execution_id="explicit-run",
        )
        assert data["workflow_execution_id"] == "explicit-run"

    def test_workflow_execution_id_falls_back_to_run_id(self) -> None:
        """Without an explicit id, the model's workflow_run_id is used."""
        data = self._build(_make_execution(workflow_run_id="run-from-model"))
        assert data["workflow_execution_id"] == "run-from-model"

    def test_workflow_execution_id_falls_back_to_execution_id(self) -> None:
        """With neither explicit id nor run id, the execution id is used."""
        data = self._build(_make_execution(workflow_run_id=None, id="exec-fallback"))
        assert data["workflow_execution_id"] == "exec-fallback"

    def test_outputs_param_overrides_execution_outputs(self) -> None:
        """An explicit ``outputs`` argument replaces the model's outputs."""
        data = self._build(
            _make_execution(outputs_dict={"from_model": True}),
            outputs={"from_param": True},
        )
        assert data["node_outputs"] == {"from_param": True}

    def test_outputs_none_uses_execution_outputs_dict(self) -> None:
        """Without explicit outputs, the model's outputs_dict is used."""
        data = self._build(_make_execution(outputs_dict={"from_model": True}))
        assert data["node_outputs"] == {"from_model": True}

    def test_metadata_token_fields_default_to_zero(self) -> None:
        """Token totals default to zero/None when metadata is empty."""
        data = self._build(_make_execution(execution_metadata_dict={}))
        assert data["total_tokens"] == 0
        assert data["total_price"] == 0.0
        assert data["currency"] is None

    def test_metadata_token_fields_populated_from_metadata(self) -> None:
        """Token totals come from execution metadata when present."""
        meta = {
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 200,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.05,
            WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
        }
        data = self._build(_make_execution(execution_metadata_dict=meta))
        assert data["total_tokens"] == 200
        assert data["total_price"] == 0.05
        assert data["currency"] == "USD"

    def test_tool_name_extracted_from_tool_info_dict(self) -> None:
        """tool_name is read from a dict-valued TOOL_INFO entry."""
        meta = {
            WorkflowNodeExecutionMetadataKey.TOOL_INFO: {"tool_name": "web_search"},
        }
        data = self._build(_make_execution(execution_metadata_dict=meta))
        assert data["tool_name"] == "web_search"

    def test_tool_name_is_none_when_tool_info_not_dict(self) -> None:
        """A non-dict TOOL_INFO entry yields no tool name."""
        meta = {WorkflowNodeExecutionMetadataKey.TOOL_INFO: "not-a-dict"}
        data = self._build(_make_execution(execution_metadata_dict=meta))
        assert data["tool_name"] is None

    def test_tool_name_is_none_when_tool_info_absent(self) -> None:
        """Missing TOOL_INFO yields no tool name."""
        data = self._build(_make_execution(execution_metadata_dict={}))
        assert data["tool_name"] is None

    def test_iteration_and_loop_fields(self) -> None:
        """Iteration/loop/parallel metadata is copied through verbatim."""
        meta = {
            WorkflowNodeExecutionMetadataKey.ITERATION_ID: "iter-1",
            WorkflowNodeExecutionMetadataKey.ITERATION_INDEX: 3,
            WorkflowNodeExecutionMetadataKey.LOOP_ID: "loop-1",
            WorkflowNodeExecutionMetadataKey.LOOP_INDEX: 2,
            WorkflowNodeExecutionMetadataKey.PARALLEL_ID: "par-1",
        }
        data = self._build(_make_execution(execution_metadata_dict=meta))
        assert data["iteration_id"] == "iter-1"
        assert data["iteration_index"] == 3
        assert data["loop_id"] == "loop-1"
        assert data["loop_index"] == 2
        assert data["parallel_id"] == "par-1"

    def test_node_inputs_and_process_data_included(self) -> None:
        """Input and process-data dicts are copied through."""
        data = self._build(
            _make_execution(inputs_dict={"q": "test"}, process_data_dict={"step": 1})
        )
        assert data["node_inputs"] == {"q": "test"}
        assert data["process_data"] == {"step": 1}
# ---------------------------------------------------------------------------
# enqueue_draft_node_execution_trace
# ---------------------------------------------------------------------------
class TestEnqueueDraftNodeExecutionTrace:
    """Behaviour of ``enqueue_draft_node_execution_trace``."""

    @patch("enterprise.telemetry.draft_trace.telemetry_emit")
    def test_emits_telemetry_event(self, mock_emit: MagicMock) -> None:
        """Exactly one event is emitted with the expected name and context."""
        from core.telemetry import TelemetryEvent, TraceTaskName
        from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace

        enqueue_draft_node_execution_trace(
            execution=_make_execution(),
            outputs={"result": "ok"},
            workflow_execution_id="run-x",
            user_id="user-1",
        )

        mock_emit.assert_called_once()
        emitted: TelemetryEvent = mock_emit.call_args[0][0]
        assert emitted.name == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
        assert emitted.context.tenant_id == "tenant-1"
        assert emitted.context.user_id == "user-1"
        assert emitted.context.app_id == "app-1"

    @patch("enterprise.telemetry.draft_trace.telemetry_emit")
    def test_payload_contains_node_execution_data(self, mock_emit: MagicMock) -> None:
        """The event payload carries the built node-execution data."""
        from core.telemetry import TelemetryEvent
        from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace

        enqueue_draft_node_execution_trace(
            execution=_make_execution(),
            outputs=None,
            workflow_execution_id=None,
            user_id="user-2",
        )

        emitted: TelemetryEvent = mock_emit.call_args[0][0]
        built = emitted.payload["node_execution_data"]
        assert built["workflow_id"] == "wf-1"
        assert built["node_type"] == "llm"
        assert built["status"] == "succeeded"

    @patch("enterprise.telemetry.draft_trace.telemetry_emit")
    def test_outputs_forwarded_to_build(self, mock_emit: MagicMock) -> None:
        """Explicit outputs take precedence in the emitted payload."""
        from core.telemetry import TelemetryEvent
        from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace

        enqueue_draft_node_execution_trace(
            execution=_make_execution(outputs_dict={"default": True}),
            outputs={"explicit": True},
            workflow_execution_id=None,
            user_id="user-3",
        )

        emitted: TelemetryEvent = mock_emit.call_args[0][0]
        assert emitted.payload["node_execution_data"]["node_outputs"] == {"explicit": True}

    @patch("enterprise.telemetry.draft_trace.telemetry_emit")
    def test_none_outputs_uses_execution_outputs(self, mock_emit: MagicMock) -> None:
        """Without explicit outputs, the model's outputs_dict is emitted."""
        from core.telemetry import TelemetryEvent
        from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace

        enqueue_draft_node_execution_trace(
            execution=_make_execution(outputs_dict={"from_model": "yes"}),
            outputs=None,
            workflow_execution_id=None,
            user_id="user-4",
        )

        emitted: TelemetryEvent = mock_emit.call_args[0][0]
        assert emitted.payload["node_execution_data"]["node_outputs"] == {"from_model": "yes"}
# ---------------------------------------------------------------------------
# End-to-end token/model data flow: _build_node_execution_data →
# ops_trace_manager.draft_node_execution_trace → DraftNodeExecutionTrace
# ---------------------------------------------------------------------------
def _make_llm_execution() -> MagicMock:
    """Return a WorkflowNodeExecutionModel mock that mimics a real LLM node.

    The field values match what dify_graph/nodes/llm/node.py produces:
    - process_data_dict contains model_provider, model_name, and usage
    - outputs_dict contains usage with prompt/completion breakdown
    - execution_metadata_dict contains total_tokens/total_price/currency
    """
    return _make_execution(
        tenant_id="tenant-flow",
        app_id="app-flow",
        workflow_id="wf-flow",
        id="exec-flow",
        node_id="node-llm",
        node_type="llm",
        title="GPT-4o Node",
        status="succeeded",
        elapsed_time=2.3,
        # No run id on the model: without an explicit workflow_execution_id,
        # the builder falls back to the execution id (see fallback tests above).
        workflow_run_id=None,
        process_data_dict={
            "model_mode": "chat",
            "model_provider": "openai",
            "model_name": "gpt-4o",
            "prompts": [{"role": "user", "text": "hello"}],
            # Full usage breakdown, including per-unit pricing fields.
            "usage": {
                "prompt_tokens": 50,
                "prompt_unit_price": 0.00001,
                "prompt_price_unit": 0.001,
                "prompt_price": 0.0005,
                "completion_tokens": 30,
                "completion_unit_price": 0.00003,
                "completion_price_unit": 0.001,
                "completion_price": 0.0009,
                "total_tokens": 80,
                "total_price": 0.0014,
                "currency": "USD",
                "latency": 2.3,
            },
            "finish_reason": "stop",
        },
        outputs_dict={
            "text": "world",
            # Condensed usage mirror as it appears on node outputs.
            "usage": {
                "prompt_tokens": 50,
                "completion_tokens": 30,
                "total_tokens": 80,
                "total_price": 0.0014,
                "currency": "USD",
            },
            "finish_reason": "stop",
        },
        execution_metadata_dict={
            WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: 80,
            WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: 0.0014,
            WorkflowNodeExecutionMetadataKey.CURRENCY: "USD",
        },
    )
class TestDraftTraceTokenDataFlow:
    """End-to-end test: verify all token and model fields survive from
    _build_node_execution_data through ops_trace_manager.draft_node_execution_trace
    to the DraftNodeExecutionTrace that enterprise_trace.py consumes.
    """

    def test_all_token_and_model_fields_reach_trace_info(self) -> None:
        """Simulate the full draft trace data flow for an LLM node and
        assert every token/model field that enterprise_trace._emit_node_execution_trace
        reads is populated correctly on the resulting DraftNodeExecutionTrace."""
        from enterprise.telemetry.draft_trace import _build_node_execution_data

        execution = _make_llm_execution()
        node_data = _build_node_execution_data(
            execution=execution,
            outputs=None,
            workflow_execution_id="run-flow",
        )
        # Simulate what ops_trace_manager.draft_node_execution_trace does:
        # it calls node_execution_trace(node_execution_data=node_data) which
        # reads top-level keys from node_data. Verify all expected keys exist.
        expected_keys = {
            # Token fields — read by enterprise_trace._emit_node_execution_trace
            "total_tokens",
            "total_price",
            "currency",
            "prompt_tokens",
            "completion_tokens",
            # Model fields — read for span attrs and metric labels
            "model_provider",
            "model_name",
            # Node identity — read for span attrs
            "node_type",
            "node_execution_id",
            "node_id",
            "title",
            "status",
            "error",
            "elapsed_time",
            # Workflow context
            "workflow_id",
            "workflow_execution_id",
            "tenant_id",
            "app_id",
            # Structure fields
            "index",
            "predecessor_node_id",
            "iteration_id",
            "iteration_index",
            "loop_id",
            "loop_index",
            "parallel_id",
            # Tool field
            "tool_name",
            # Content fields
            "node_inputs",
            "node_outputs",
            "process_data",
            # Timestamps
            "created_at",
            "finished_at",
        }
        # Exact key-set equality catches both missing and unexpected fields.
        assert set(node_data.keys()) == expected_keys
        # Verify token/model values are correct (not None/zero when data exists)
        assert node_data["total_tokens"] == 80
        assert node_data["total_price"] == 0.0014
        assert node_data["currency"] == "USD"
        assert node_data["prompt_tokens"] == 50
        assert node_data["completion_tokens"] == 30
        assert node_data["model_provider"] == "openai"
        assert node_data["model_name"] == "gpt-4o"
        assert node_data["node_type"] == "llm"

    def test_non_llm_node_has_none_for_model_and_token_breakdown(self) -> None:
        """For non-LLM nodes (e.g. code, IF), model and token breakdown
        should be None, but total_tokens from metadata should still work."""
        from enterprise.telemetry.draft_trace import _build_node_execution_data

        execution = _make_execution(
            node_type="code",
            process_data_dict={"code": "print('hi')"},
            outputs_dict={"result": "hi"},
            execution_metadata_dict={},
        )
        result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
        assert result["model_provider"] is None
        assert result["model_name"] is None
        assert result["prompt_tokens"] is None
        assert result["completion_tokens"] is None
        # Empty metadata: the total falls back to 0 rather than None.
        assert result["total_tokens"] == 0

    def test_none_process_data_and_none_outputs(self) -> None:
        """Both process_data_dict and outputs_dict are None — exercises
        the `or {}` fallback and isinstance guard together."""
        from enterprise.telemetry.draft_trace import _build_node_execution_data

        execution = _make_execution(process_data_dict=None, outputs_dict=None)
        result = _build_node_execution_data(execution=execution, outputs=None, workflow_execution_id=None)
        assert result["model_provider"] is None
        assert result["model_name"] is None
        assert result["prompt_tokens"] is None
        assert result["completion_tokens"] is None

    def test_node_data_feeds_into_draft_node_execution_trace(self) -> None:
        """Verify the node_data dict can be consumed by
        ops_trace_manager.draft_node_execution_trace without error and
        produces a DraftNodeExecutionTrace with correct token/model fields."""
        from enterprise.telemetry.draft_trace import _build_node_execution_data

        execution = _make_llm_execution()
        node_data = _build_node_execution_data(
            execution=execution,
            outputs=None,
            workflow_execution_id="run-e2e",
        )
        # Directly construct DraftNodeExecutionTrace the way
        # ops_trace_manager.node_execution_trace does (lines 1315-1350),
        # skipping DB lookups by providing minimal metadata.
        from core.ops.entities.trace_entity import DraftNodeExecutionTrace

        trace_info = DraftNodeExecutionTrace(
            workflow_id=node_data.get("workflow_id", ""),
            workflow_run_id=node_data.get("workflow_execution_id", ""),
            tenant_id=node_data.get("tenant_id", ""),
            node_execution_id=node_data.get("node_execution_id", ""),
            node_id=node_data.get("node_id", ""),
            node_type=node_data.get("node_type", ""),
            title=node_data.get("title", ""),
            status=node_data.get("status", ""),
            error=node_data.get("error"),
            elapsed_time=node_data.get("elapsed_time", 0.0),
            index=node_data.get("index", 0),
            predecessor_node_id=node_data.get("predecessor_node_id"),
            total_tokens=node_data.get("total_tokens", 0),
            total_price=node_data.get("total_price", 0.0),
            currency=node_data.get("currency"),
            model_provider=node_data.get("model_provider"),
            model_name=node_data.get("model_name"),
            prompt_tokens=node_data.get("prompt_tokens"),
            completion_tokens=node_data.get("completion_tokens"),
            tool_name=node_data.get("tool_name"),
            iteration_id=node_data.get("iteration_id"),
            iteration_index=node_data.get("iteration_index"),
            loop_id=node_data.get("loop_id"),
            loop_index=node_data.get("loop_index"),
            parallel_id=node_data.get("parallel_id"),
            node_inputs=node_data.get("node_inputs"),
            node_outputs=node_data.get("node_outputs"),
            process_data=node_data.get("process_data"),
            start_time=node_data.get("created_at"),
            end_time=node_data.get("finished_at"),
            metadata={},
        )
        # These are the fields enterprise_trace._emit_node_execution_trace reads
        assert trace_info.total_tokens == 80
        assert trace_info.prompt_tokens == 50
        assert trace_info.completion_tokens == 30
        assert trace_info.model_provider == "openai"
        assert trace_info.model_name == "gpt-4o"
        assert trace_info.node_type == "llm"
        assert trace_info.total_price == 0.0014
        assert trace_info.currency == "USD"

File diff suppressed because it is too large Load Diff

View File

@@ -0,0 +1,54 @@
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry import event_handlers
from enterprise.telemetry.contracts import TelemetryCase
@pytest.fixture
def mock_gateway_emit():
    """Patch the telemetry gateway's ``emit`` and yield the resulting mock."""
    with patch("core.telemetry.gateway.emit") as patched_emit:
        yield patched_emit
def test_handle_app_created_calls_task(mock_gateway_emit):
    """_handle_app_created forwards app identity and mode to the gateway."""
    app = MagicMock(id="app-123", tenant_id="tenant-456", mode="chat")

    event_handlers._handle_app_created(app)

    mock_gateway_emit.assert_called_once_with(
        case=TelemetryCase.APP_CREATED,
        context={"tenant_id": "tenant-456"},
        payload={"app_id": "app-123", "mode": "chat"},
    )
def test_handle_app_created_no_exporter(mock_gateway_emit):
    """Gateway handles exporter availability internally; handler always calls gateway."""
    app = MagicMock(id="app-123", tenant_id="tenant-456")

    event_handlers._handle_app_created(app)

    mock_gateway_emit.assert_called_once()
def test_handlers_create_valid_envelopes(mock_gateway_emit):
    """Verify handlers pass the correct TelemetryCase and payload structure."""
    app = MagicMock(id="app-123", tenant_id="tenant-456", mode="chat")

    event_handlers._handle_app_created(app)

    kwargs = mock_gateway_emit.call_args.kwargs
    assert kwargs["case"] == TelemetryCase.APP_CREATED
    assert kwargs["context"]["tenant_id"] == "tenant-456"
    assert kwargs["payload"]["app_id"] == "app-123"
    assert kwargs["payload"]["mode"] == "chat"

View File

@@ -0,0 +1,628 @@
"""Unit tests for EnterpriseExporter and _ExporterFactory."""
from __future__ import annotations
from datetime import UTC, datetime
from types import SimpleNamespace
from unittest.mock import MagicMock, patch
from configs.enterprise import EnterpriseTelemetryConfig
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.exporter import EnterpriseExporter, _datetime_to_ns, _parse_otlp_headers
def test_config_api_key_default_empty():
    """A freshly constructed config defaults ENTERPRISE_OTLP_API_KEY to ""."""
    cfg = EnterpriseTelemetryConfig()
    assert cfg.ENTERPRISE_OTLP_API_KEY == ""
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_only_injects_bearer_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """An API key on its own must yield a Bearer authorization header."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "test-secret-key",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    span_call = mock_span_exporter.call_args
    assert span_call is not None
    header_pairs = span_call.kwargs.get("headers")
    assert header_pairs is not None
    assert ("authorization", "Bearer test-secret-key") in header_pairs
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_empty_api_key_no_auth_header(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """Without an API key no authorization header may be injected."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    assert mock_span_exporter.call_args is not None
    header_pairs = mock_span_exporter.call_args.kwargs.get("headers")
    # Headers may legitimately be absent; if present they must carry no auth.
    for key, _value in header_pairs or ():
        assert key != "authorization"
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_and_custom_headers_merge(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """Custom OTLP headers must survive alongside the injected Bearer header."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "x-custom=foo",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "test-key",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    assert mock_span_exporter.call_args is not None
    header_pairs = mock_span_exporter.call_args.kwargs.get("headers")
    assert header_pairs is not None
    assert ("authorization", "Bearer test-key") in header_pairs
    assert ("x-custom", "foo") in header_pairs
@patch("enterprise.telemetry.exporter.logger")
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_overrides_conflicting_header(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock, mock_logger: MagicMock
) -> None:
    """API key must replace a conflicting authorization header and log a warning.

    The raw OTLP header string configures ``authorization=Basic+old``; with an
    API key also set, the exporter must keep only the Bearer credential.
    """
    mock_config = SimpleNamespace(
        ENTERPRISE_OTLP_ENDPOINT="https://collector.example.com",
        ENTERPRISE_OTLP_HEADERS="authorization=Basic+old",
        ENTERPRISE_OTLP_PROTOCOL="grpc",
        ENTERPRISE_SERVICE_NAME="dify",
        ENTERPRISE_OTEL_SAMPLING_RATE=1.0,
        ENTERPRISE_INCLUDE_CONTENT=True,
        ENTERPRISE_OTLP_API_KEY="test-key",
    )
    EnterpriseExporter(mock_config)
    # Verify Bearer header takes precedence.
    assert mock_span_exporter.call_args is not None
    headers = mock_span_exporter.call_args.kwargs.get("headers")
    assert headers is not None
    # The ONLY authorization value must be the Bearer credential.  (The prior
    # assertion checked that ("authorization", "Basic old") was absent, which
    # was vacuous: the configured raw value is "Basic+old" and percent-decoding
    # does not map "+" to a space, so "Basic old" could never appear anyway.)
    auth_values = [value for key, value in headers if key == "authorization"]
    assert auth_values == ["Bearer test-key"]
    # Verify the override warning was logged exactly once.
    mock_logger.warning.assert_called_once()
    assert mock_logger.warning.call_args is not None
    warning_message = mock_logger.warning.call_args[0][0]
    assert "ENTERPRISE_OTLP_API_KEY is set" in warning_message
    assert "authorization" in warning_message
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_https_endpoint_uses_secure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """An https:// endpoint must enable TLS (insecure=False) on both gRPC exporters."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "test-key",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args is not None
        assert exporter_mock.call_args.kwargs["insecure"] is False
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_http_endpoint_uses_insecure_grpc(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """A plain http:// endpoint must disable TLS (insecure=True) on both gRPC exporters."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "http://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args is not None
        assert exporter_mock.call_args.kwargs["insecure"] is True
@patch("enterprise.telemetry.exporter.HTTPSpanExporter")
@patch("enterprise.telemetry.exporter.HTTPMetricExporter")
def test_insecure_not_passed_to_http_exporters(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """HTTP exporters do not accept an ``insecure`` kwarg; it must not be forwarded."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "http://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "http",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "test-key",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args is not None
        assert "insecure" not in exporter_mock.call_args.kwargs
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_api_key_with_special_chars_preserved(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """Base64-ish characters (+ / =) in the API key must survive unmangled."""
    secret = "abc+def/ghi=jkl=="
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": secret,
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    assert mock_span_exporter.call_args is not None
    header_pairs = mock_span_exporter.call_args.kwargs.get("headers")
    assert header_pairs is not None
    assert ("authorization", f"Bearer {secret}") in header_pairs
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_no_scheme_localhost_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """A scheme-less localhost endpoint defaults to an insecure channel."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "localhost:4317",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args is not None
        assert exporter_mock.call_args.kwargs["insecure"] is True
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_no_scheme_production_uses_insecure(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """Any scheme-less endpoint (not just localhost) defaults to insecure."""
    fields = {
        "ENTERPRISE_OTLP_ENDPOINT": "collector.example.com:4317",
        "ENTERPRISE_OTLP_HEADERS": "",
        "ENTERPRISE_OTLP_PROTOCOL": "grpc",
        "ENTERPRISE_SERVICE_NAME": "dify",
        "ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
        "ENTERPRISE_INCLUDE_CONTENT": True,
        "ENTERPRISE_OTLP_API_KEY": "",
    }
    EnterpriseExporter(SimpleNamespace(**fields))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args is not None
        assert exporter_mock.call_args.kwargs["insecure"] is True
# ---------------------------------------------------------------------------
# _parse_otlp_headers (line 55 — pair without "=" is skipped)
# ---------------------------------------------------------------------------
def test_parse_otlp_headers_empty_returns_empty_dict() -> None:
    """An empty raw header string parses to an empty mapping."""
    parsed = _parse_otlp_headers("")
    assert parsed == {}
def test_parse_otlp_headers_value_may_contain_equals() -> None:
    """Only the first '=' splits key from value; later ones stay in the value."""
    parsed = _parse_otlp_headers("token=abc=def==")
    assert parsed == {"token": "abc=def=="}
def test_parse_otlp_headers_url_encoded() -> None:
    """Percent-encoded values are decoded as UTF-8."""
    parsed = _parse_otlp_headers("key=%E4%BD%A0%E5%A5%BD")
    assert parsed == {"key": "你好"}
# ---------------------------------------------------------------------------
# _datetime_to_ns (lines 64-68)
# ---------------------------------------------------------------------------
def test_datetime_to_ns_naive_treated_as_utc() -> None:
    """A naive datetime is interpreted as UTC."""
    at_midnight = datetime(2024, 1, 1)  # no tzinfo
    assert _datetime_to_ns(at_midnight) == _datetime_to_ns(at_midnight.replace(tzinfo=UTC))
def test_datetime_to_ns_tz_aware_converted_to_utc() -> None:
    """A timezone-aware datetime is normalised to UTC before conversion."""
    import zoneinfo

    eastern_noon = datetime(2024, 6, 1, 12, 0, 0, tzinfo=zoneinfo.ZoneInfo("America/New_York"))  # UTC-4 in summer
    assert _datetime_to_ns(eastern_noon) == _datetime_to_ns(eastern_noon.astimezone(UTC))
def test_datetime_to_ns_returns_integer_nanoseconds() -> None:
    """The result is an int counting nanoseconds since the epoch."""
    ns = _datetime_to_ns(datetime(2024, 1, 1, 0, 0, 1, tzinfo=UTC))
    assert isinstance(ns, int)
    # Any instant in 2024 is well past 1.7e18 ns after the epoch.
    assert ns > 1_700_000_000_000_000_000
# ---------------------------------------------------------------------------
# EnterpriseExporter constructor — include_content property (line 115 / 288-289)
# ---------------------------------------------------------------------------
def _make_grpc_config(**overrides) -> SimpleNamespace:
defaults = {
"ENTERPRISE_OTLP_ENDPOINT": "https://collector.example.com",
"ENTERPRISE_OTLP_HEADERS": "",
"ENTERPRISE_OTLP_PROTOCOL": "grpc",
"ENTERPRISE_SERVICE_NAME": "dify",
"ENTERPRISE_OTEL_SAMPLING_RATE": 1.0,
"ENTERPRISE_INCLUDE_CONTENT": True,
"ENTERPRISE_OTLP_API_KEY": "",
}
defaults.update(overrides)
return SimpleNamespace(**defaults)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_include_content_true_stored_on_exporter(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
    """include_content=True is surfaced as a public attribute."""
    built = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=True))
    assert built.include_content is True
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_include_content_false_stored_on_exporter(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
    """include_content=False is preserved for downstream content gating."""
    built = EnterpriseExporter(_make_grpc_config(ENTERPRISE_INCLUDE_CONTENT=False))
    assert built.include_content is False
# ---------------------------------------------------------------------------
# EnterpriseExporter constructor — gRPC setup (lines 64-68 exporter-init path)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_grpc_exporter_created_with_correct_endpoint(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
    """Both gRPC exporters receive the configured endpoint verbatim."""
    EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT="https://my-collector:4317"))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args.kwargs["endpoint"] == "https://my-collector:4317"
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_grpc_exporter_empty_endpoint_passes_none(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
    """An empty endpoint string is normalised to None for both gRPC exporters."""
    EnterpriseExporter(_make_grpc_config(ENTERPRISE_OTLP_ENDPOINT=""))
    for exporter_mock in (mock_span_exporter, mock_metric_exporter):
        assert exporter_mock.call_args.kwargs["endpoint"] is None
# ---------------------------------------------------------------------------
# EnterpriseExporter.export_span (lines 204-271)
# ---------------------------------------------------------------------------
def _make_exporter_with_mock_tracer() -> tuple[EnterpriseExporter, MagicMock, MagicMock]:
    """Return (exporter, mock_tracer, mock_span) with OTEL internals fully mocked."""
    mock_span = MagicMock()
    # Wire the context-manager protocol explicitly so `with start_as_current_span(...)`
    # yields this exact mock (rather than an auto-created child mock) and does
    # not suppress exceptions (__exit__ -> False).
    mock_span.__enter__ = MagicMock(return_value=mock_span)
    mock_span.__exit__ = MagicMock(return_value=False)
    mock_tracer = MagicMock()
    mock_tracer.start_as_current_span.return_value = mock_span
    # Patch the OTLP exporter classes so constructing EnterpriseExporter never
    # opens real gRPC channels; then swap in the mock tracer post-construction.
    with (
        patch("enterprise.telemetry.exporter.GRPCSpanExporter"),
        patch("enterprise.telemetry.exporter.GRPCMetricExporter"),
    ):
        exporter = EnterpriseExporter(_make_grpc_config())
    exporter._tracer = mock_tracer
    return exporter, mock_tracer, mock_span
@patch("enterprise.telemetry.exporter.set_correlation_id")
@patch("enterprise.telemetry.exporter.set_span_id_source")
def test_export_span_sets_and_clears_context(mock_set_span: MagicMock, mock_set_corr: MagicMock) -> None:
    """export_span installs correlation/span context up front and resets it in finally."""
    exporter, _tracer, _span = _make_exporter_with_mock_tracer()
    exporter.export_span(
        name="test.span",
        attributes={"k": "v"},
        correlation_id="corr-1",
        span_id_source="span-src-1",
    )
    # Installed at entry...
    mock_set_corr.assert_any_call("corr-1")
    mock_set_span.assert_any_call("span-src-1")
    # ...and cleared again as the final call (the finally block).
    mock_set_corr.assert_called_with(None)
    mock_set_span.assert_called_with(None)
def test_export_span_sets_attributes_on_span() -> None:
    """Only attributes with non-None values are written to the span."""
    exporter, _tracer, span = _make_exporter_with_mock_tracer()
    exporter.export_span(
        name="test.span",
        attributes={"key1": "value1", "key2": None, "key3": 42},
    )
    seen_keys = {recorded[0][0] for recorded in span.set_attribute.call_args_list}
    assert "key1" in seen_keys
    assert "key3" in seen_keys
    assert "key2" not in seen_keys
def test_export_span_no_end_time_uses_end_on_exit() -> None:
    """Without an explicit end_time the span must end on context exit."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    exporter.export_span(name="test.span", attributes={})
    assert tracer.start_as_current_span.call_args.kwargs["end_on_exit"] is True
def test_export_span_with_end_time_calls_span_end() -> None:
    """An explicit end_time is converted to nanoseconds and passed to span.end()."""
    exporter, _tracer, span = _make_exporter_with_mock_tracer()
    started = datetime(2024, 1, 1, 0, 0, 0, tzinfo=UTC)
    finished = datetime(2024, 1, 1, 0, 0, 5, tzinfo=UTC)
    exporter.export_span(name="test.span", attributes={}, start_time=started, end_time=finished)
    span.end.assert_called_once()
    assert span.end.call_args.kwargs["end_time"] == _datetime_to_ns(finished)
def test_export_span_with_start_time_passed_to_start_as_current_span() -> None:
    """start_time is converted to nanoseconds before reaching the tracer."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    begun = datetime(2024, 3, 1, 12, 0, 0, tzinfo=UTC)
    exporter.export_span(name="test.span", attributes={}, start_time=begun)
    assert tracer.start_as_current_span.call_args.kwargs["start_time"] == _datetime_to_ns(begun)
def test_export_span_root_span_no_parent_context() -> None:
    """Identical span_id_source and correlation_id mark a root span (no parent context)."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    shared_uid = "123e4567-e89b-12d3-a456-426614174000"
    exporter.export_span(
        name="root.span",
        attributes={},
        correlation_id=shared_uid,
        span_id_source=shared_uid,
    )
    assert tracer.start_as_current_span.call_args.kwargs["context"] is None
def test_export_span_child_span_has_parent_context() -> None:
    """Distinct correlation_id and span_id_source produce a parented child span."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    workflow_uid = "123e4567-e89b-12d3-a456-426614174000"
    node_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
    exporter.export_span(
        name="child.span",
        attributes={},
        correlation_id=workflow_uid,
        span_id_source=node_uid,
    )
    assert tracer.start_as_current_span.call_args.kwargs["context"] is not None
def test_export_span_cross_workflow_parent_context() -> None:
    """Providing parent_span_id_source builds a cross-workflow parent context."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    workflow_uid = "123e4567-e89b-12d3-a456-426614174000"
    upstream_uid = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
    exporter.export_span(
        name="cross.span",
        attributes={},
        correlation_id=workflow_uid,
        parent_span_id_source=upstream_uid,
    )
    assert tracer.start_as_current_span.call_args.kwargs["context"] is not None
@patch("enterprise.telemetry.exporter.logger")
def test_export_span_logs_exception_on_error(mock_logger: MagicMock) -> None:
    """A failure inside the span block is swallowed and logged via logger.exception."""
    exporter, tracer, _span = _make_exporter_with_mock_tracer()
    tracer.start_as_current_span.side_effect = RuntimeError("boom")
    exporter.export_span(name="bad.span", attributes={})  # must not propagate
    mock_logger.exception.assert_called_once()
    # The span name appears as the second positional (lazy %-format) argument.
    assert "bad.span" in mock_logger.exception.call_args[0][1]
@patch("enterprise.telemetry.exporter.logger")
def test_export_span_invalid_trace_correlation_logs_warning(mock_logger: MagicMock) -> None:
    """A non-UUID correlation id combined with a parent source triggers a warning."""
    exporter, _tracer, _span = _make_exporter_with_mock_tracer()
    exporter.export_span(
        name="link.span",
        attributes={},
        correlation_id="not-a-valid-uuid",
        parent_span_id_source="987fbc97-4bed-5078-9f07-9141ba07c9f3",
    )
    assert mock_logger.warning.called
# ---------------------------------------------------------------------------
# EnterpriseExporter.increment_counter (lines 276-278)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_increment_counter_calls_add_on_counter(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """increment_counter forwards the value and labels to the matching instrument."""
    exporter = EnterpriseExporter(_make_grpc_config())
    counter = MagicMock()
    exporter._counters[EnterpriseTelemetryCounter.TOKENS] = counter
    tags = {"tenant_id": "t1", "app_id": "app-1"}
    exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 50, tags)
    counter.add.assert_called_once_with(50, tags)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_increment_counter_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """A missing counter instrument is ignored silently."""
    exporter = EnterpriseExporter(_make_grpc_config())
    exporter._counters.clear()
    # Must complete without raising.
    exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, 5, {})
# ---------------------------------------------------------------------------
# EnterpriseExporter.record_histogram (lines 283-285)
# ---------------------------------------------------------------------------
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_record_histogram_calls_record_on_histogram(
    mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock
) -> None:
    """record_histogram forwards the value and labels to the matching instrument."""
    exporter = EnterpriseExporter(_make_grpc_config())
    histogram = MagicMock()
    exporter._histograms[EnterpriseTelemetryHistogram.WORKFLOW_DURATION] = histogram
    tags = {"tenant_id": "t1"}
    exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 3.14, tags)
    histogram.record.assert_called_once_with(3.14, tags)
@patch("enterprise.telemetry.exporter.GRPCSpanExporter")
@patch("enterprise.telemetry.exporter.GRPCMetricExporter")
def test_record_histogram_unknown_name_is_noop(mock_metric_exporter: MagicMock, mock_span_exporter: MagicMock) -> None:
    """A missing histogram instrument is ignored silently."""
    exporter = EnterpriseExporter(_make_grpc_config())
    exporter._histograms.clear()
    # Must complete without raising.
    exporter.record_histogram(EnterpriseTelemetryHistogram.WORKFLOW_DURATION, 1.0, {})

View File

@@ -0,0 +1,272 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.gateway import (
CASE_ROUTING,
CASE_TO_TRACE_TASK,
PAYLOAD_SIZE_THRESHOLD_BYTES,
emit,
)
from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope
class TestCaseRoutingTable:
    """Static consistency checks over CASE_ROUTING and CASE_TO_TRACE_TASK."""

    def test_all_cases_have_routing(self) -> None:
        """Every TelemetryCase member has a routing entry."""
        unrouted = [case for case in TelemetryCase if case not in CASE_ROUTING]
        assert not unrouted, f"Missing routing for {unrouted}"

    def test_trace_cases(self) -> None:
        """Workflow/message/node/prompt cases route as trace signals."""
        for case in (
            TelemetryCase.WORKFLOW_RUN,
            TelemetryCase.MESSAGE_RUN,
            TelemetryCase.NODE_EXECUTION,
            TelemetryCase.DRAFT_NODE_EXECUTION,
            TelemetryCase.PROMPT_GENERATION,
        ):
            assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace"

    def test_metric_log_cases(self) -> None:
        """App lifecycle and feedback cases route as metric/log signals."""
        for case in (
            TelemetryCase.APP_CREATED,
            TelemetryCase.APP_UPDATED,
            TelemetryCase.APP_DELETED,
            TelemetryCase.FEEDBACK_CREATED,
        ):
            assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log"

    def test_ce_eligible_cases(self) -> None:
        """Community-edition-eligible cases are flagged ce_eligible."""
        for case in (
            TelemetryCase.WORKFLOW_RUN,
            TelemetryCase.MESSAGE_RUN,
            TelemetryCase.TOOL_EXECUTION,
            TelemetryCase.MODERATION_CHECK,
            TelemetryCase.SUGGESTED_QUESTION,
            TelemetryCase.DATASET_RETRIEVAL,
            TelemetryCase.GENERATE_NAME,
        ):
            assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"

    def test_enterprise_only_cases(self) -> None:
        """Node-level and prompt-generation cases are enterprise-only."""
        for case in (
            TelemetryCase.NODE_EXECUTION,
            TelemetryCase.DRAFT_NODE_EXECUTION,
            TelemetryCase.PROMPT_GENERATION,
        ):
            assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"

    def test_trace_cases_have_task_name_mapping(self) -> None:
        """Every trace-routed case also has a TraceTaskName mapping."""
        for case in TelemetryCase:
            if CASE_ROUTING[case].signal_type is SignalType.TRACE:
                assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}"
@pytest.fixture
def mock_ops_trace_manager():
    """Replace the core.ops trace modules in sys.modules so gateway imports hit mocks."""
    mock_module = MagicMock()
    mock_trace_task_class = MagicMock()
    mock_trace_task_class.return_value = MagicMock()
    mock_module.TraceTask = mock_trace_task_class
    mock_module.TraceQueueManager = MagicMock()
    # Separate mock module for the trace entities; TraceTaskName is invoked like
    # an enum constructor, so give it a plain string return value.
    mock_trace_entity = MagicMock()
    mock_trace_task_name = MagicMock()
    mock_trace_task_name.return_value = "workflow"
    mock_trace_entity.TraceTaskName = mock_trace_task_name
    # patch.dict restores sys.modules after the test, so real imports are unaffected.
    with (
        patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
        patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
    ):
        yield mock_module, mock_trace_entity
class TestGatewayTraceRouting:
    """Routing of trace-signal cases through the provided trace manager."""
    @pytest.fixture
    def mock_trace_manager(self) -> MagicMock:
        # Stand-in for the trace queue manager; tests only inspect add_trace_task.
        return MagicMock()
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_trace_case_routes_to_trace_manager(
        self,
        mock_ee_enabled: MagicMock,
        mock_trace_manager: MagicMock,
        mock_ops_trace_manager: tuple[MagicMock, MagicMock],
    ) -> None:
        """A trace case with EE telemetry enabled is enqueued on the trace manager."""
        context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
        payload = {"workflow_run_id": "run-abc"}
        emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
        mock_trace_manager.add_trace_task.assert_called_once()
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
    def test_ce_eligible_trace_enqueued_when_ee_disabled(
        self,
        mock_ee_enabled: MagicMock,
        mock_trace_manager: MagicMock,
        mock_ops_trace_manager: tuple[MagicMock, MagicMock],
    ) -> None:
        """CE-eligible traces still flow to the trace manager when EE is off."""
        context = {"app_id": "app-123", "user_id": "user-456"}
        payload = {"workflow_run_id": "run-abc"}
        emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
        mock_trace_manager.add_trace_task.assert_called_once()
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False)
    def test_enterprise_only_trace_dropped_when_ee_disabled(
        self,
        mock_ee_enabled: MagicMock,
        mock_trace_manager: MagicMock,
        mock_ops_trace_manager: tuple[MagicMock, MagicMock],
    ) -> None:
        """Enterprise-only traces are dropped entirely when EE telemetry is off."""
        context = {"app_id": "app-123", "user_id": "user-456"}
        payload = {"node_id": "node-abc"}
        emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
        mock_trace_manager.add_trace_task.assert_not_called()
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    def test_enterprise_only_trace_enqueued_when_ee_enabled(
        self,
        mock_ee_enabled: MagicMock,
        mock_trace_manager: MagicMock,
        mock_ops_trace_manager: tuple[MagicMock, MagicMock],
    ) -> None:
        """Enterprise-only traces are enqueued once EE telemetry is on."""
        context = {"app_id": "app-123", "user_id": "user-456"}
        payload = {"node_id": "node-abc"}
        emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
        mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayMetricLogRouting:
    """Routing of metric/log-signal cases to the Celery telemetry task."""
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
    def test_metric_case_routes_to_celery_task(
        self,
        mock_delay: MagicMock,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """A metric/log case is serialised into an envelope and sent via .delay()."""
        context = {"tenant_id": "tenant-123"}
        payload = {"app_id": "app-abc", "name": "My App"}
        emit(TelemetryCase.APP_CREATED, context, payload)
        mock_delay.assert_called_once()
        # The envelope travels as a JSON string (first positional argument).
        envelope_json = mock_delay.call_args[0][0]
        envelope = TelemetryEnvelope.model_validate_json(envelope_json)
        assert envelope.case == TelemetryCase.APP_CREATED
        assert envelope.tenant_id == "tenant-123"
        assert envelope.payload["app_id"] == "app-abc"
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
    def test_envelope_has_unique_event_id(
        self,
        mock_delay: MagicMock,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """Two emits of identical content still get distinct event ids."""
        context = {"tenant_id": "tenant-123"}
        payload = {"app_id": "app-abc"}
        emit(TelemetryCase.APP_CREATED, context, payload)
        emit(TelemetryCase.APP_CREATED, context, payload)
        assert mock_delay.call_count == 2
        envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
        envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
        assert envelope1.event_id != envelope2.event_id
class TestGatewayPayloadSizing:
    """Payloads are inlined when small, offloaded to storage when large."""
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
    def test_small_payload_inlined(
        self,
        mock_delay: MagicMock,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """A payload under the size threshold travels inline in the envelope."""
        context = {"tenant_id": "tenant-123"}
        payload = {"key": "small_value"}
        emit(TelemetryCase.APP_CREATED, context, payload)
        envelope_json = mock_delay.call_args[0][0]
        envelope = TelemetryEnvelope.model_validate_json(envelope_json)
        assert envelope.payload == payload
        assert envelope.metadata is None
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    @patch("core.telemetry.gateway.storage")
    @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
    def test_large_payload_stored(
        self,
        mock_delay: MagicMock,
        mock_storage: MagicMock,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """An oversized payload is written to storage; the envelope carries only a reference."""
        context = {"tenant_id": "tenant-123"}
        large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
        payload = {"key": large_value}
        emit(TelemetryCase.APP_CREATED, context, payload)
        mock_storage.save.assert_called_once()
        # Storage keys are namespaced per tenant.
        storage_key = mock_storage.save.call_args[0][0]
        assert storage_key.startswith("telemetry/tenant-123/")
        envelope_json = mock_delay.call_args[0][0]
        envelope = TelemetryEnvelope.model_validate_json(envelope_json)
        assert envelope.payload == {}
        assert envelope.metadata is not None
        assert envelope.metadata["payload_ref"] == storage_key
    @patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
    @patch("core.telemetry.gateway.storage")
    @patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
    def test_large_payload_fallback_on_storage_error(
        self,
        mock_delay: MagicMock,
        mock_storage: MagicMock,
        mock_ee_enabled: MagicMock,
    ) -> None:
        """If storage fails, the oversized payload falls back to inline delivery."""
        mock_storage.save.side_effect = Exception("Storage failure")
        context = {"tenant_id": "tenant-123"}
        large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
        payload = {"key": large_value}
        emit(TelemetryCase.APP_CREATED, context, payload)
        envelope_json = mock_delay.call_args[0][0]
        envelope = TelemetryEnvelope.model_validate_json(envelope_json)
        assert envelope.payload == payload
        assert envelope.metadata is None
class TestTraceTaskNameMapping:
    """Each trace-signal case maps to its exact TraceTaskName member."""

    def test_workflow_run_mapping(self) -> None:
        mapped = CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN]
        assert mapped is TraceTaskName.WORKFLOW_TRACE

    def test_message_run_mapping(self) -> None:
        mapped = CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN]
        assert mapped is TraceTaskName.MESSAGE_TRACE

    def test_node_execution_mapping(self) -> None:
        mapped = CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION]
        assert mapped is TraceTaskName.NODE_EXECUTION_TRACE

    def test_draft_node_execution_mapping(self) -> None:
        mapped = CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION]
        assert mapped is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE

    def test_prompt_generation_mapping(self) -> None:
        mapped = CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION]
        assert mapped is TraceTaskName.PROMPT_GENERATION_TRACE

View File

@@ -0,0 +1,201 @@
"""Unit tests for enterprise/telemetry/id_generator.py."""
from __future__ import annotations
import uuid
from unittest.mock import patch
# ---------------------------------------------------------------------------
# compute_deterministic_span_id
# ---------------------------------------------------------------------------
class TestComputeDeterministicSpanId:
    """compute_deterministic_span_id(): lower-64-bit projection of a UUID string."""

    SAMPLE = "123e4567-e89b-12d3-a456-426614174000"

    def test_returns_lower_64_bits_of_uuid(self) -> None:
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        low_64 = uuid.UUID(self.SAMPLE).int & ((1 << 64) - 1)
        assert compute_deterministic_span_id(self.SAMPLE) == low_64

    def test_non_zero_result_returned_unchanged(self) -> None:
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        # This UUID has non-zero lower 64 bits, so no fallback should kick in.
        assert compute_deterministic_span_id(self.SAMPLE) != 0

    def test_zero_lower_bits_returns_one(self) -> None:
        """When the lower 64 bits of the UUID int are 0, the function must return 1 (OTEL requirement)."""
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        # 1 << 64 has all-zero low bits; wrap that integer in a UUID.
        crafted_uuid = uuid.UUID(int=1 << 64)
        assert compute_deterministic_span_id(str(crafted_uuid)) == 1

    def test_raises_on_invalid_uuid(self) -> None:
        import pytest

        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        with pytest.raises((ValueError, AttributeError)):
            compute_deterministic_span_id("not-a-uuid")

    def test_different_uuids_produce_different_span_ids(self) -> None:
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        other = "987fbc97-4bed-5078-9f07-9141ba07c9f3"
        assert compute_deterministic_span_id(self.SAMPLE) != compute_deterministic_span_id(other)

    def test_deterministic_same_input_same_output(self) -> None:
        from enterprise.telemetry.id_generator import compute_deterministic_span_id

        assert compute_deterministic_span_id(self.SAMPLE) == compute_deterministic_span_id(self.SAMPLE)
# ---------------------------------------------------------------------------
# Context variable helpers
# ---------------------------------------------------------------------------
class TestContextVariableHelpers:
    """Round-trip behavior of the correlation-id and span-id-source contextvars."""

    def test_set_and_get_correlation_id(self) -> None:
        from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
        set_correlation_id("corr-123")
        assert get_correlation_id() == "corr-123"

    def test_clear_correlation_id(self) -> None:
        from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
        set_correlation_id("corr-abc")
        # Passing None acts as a "clear" operation.
        set_correlation_id(None)
        assert get_correlation_id() is None

    def test_correlation_id_default_is_none(self) -> None:
        from enterprise.telemetry.id_generator import get_correlation_id, set_correlation_id
        # Explicitly reset in case an earlier test leaked a value into the context.
        set_correlation_id(None)
        assert get_correlation_id() is None

    def test_set_span_id_source_stored_in_context(self) -> None:
        # Reaches into the private contextvar to confirm the setter writes through.
        from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
        set_span_id_source("span-src-1")
        assert _span_id_source_context.get() == "span-src-1"

    def test_clear_span_id_source(self) -> None:
        from enterprise.telemetry.id_generator import _span_id_source_context, set_span_id_source
        set_span_id_source("span-src-1")
        set_span_id_source(None)
        assert _span_id_source_context.get() is None
# ---------------------------------------------------------------------------
# CorrelationIdGenerator.generate_trace_id
# ---------------------------------------------------------------------------
class TestCorrelationIdGeneratorGenerateTraceId:
    """Trace-id generation: uses the correlation id when valid, random otherwise."""

    def setup_method(self) -> None:
        # Reset the contextvar so earlier tests cannot leak a correlation id.
        from enterprise.telemetry.id_generator import set_correlation_id

        set_correlation_id(None)

    def test_returns_uuid_int_when_correlation_id_set(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id

        uid = "123e4567-e89b-12d3-a456-426614174000"
        set_correlation_id(uid)
        assert CorrelationIdGenerator().generate_trace_id() == uuid.UUID(uid).int

    def test_returns_random_when_no_correlation_id(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id

        set_correlation_id(None)
        # Should fall back to a random non-zero int without raising.
        generated = CorrelationIdGenerator().generate_trace_id()
        assert isinstance(generated, int)
        assert generated > 0

    def test_returns_random_when_correlation_id_is_invalid_uuid(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_correlation_id

        set_correlation_id("not-a-valid-uuid")
        with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=42) as mock_rng:
            generated = CorrelationIdGenerator().generate_trace_id()
        # Invalid correlation id => a full 128-bit random draw is used instead.
        mock_rng.assert_called_once_with(128)
        assert generated == 42
# ---------------------------------------------------------------------------
# CorrelationIdGenerator.generate_span_id
# ---------------------------------------------------------------------------
class TestCorrelationIdGeneratorGenerateSpanId:
    """Span-id generation: deterministic from the source UUID, else random non-zero."""

    def setup_method(self) -> None:
        from enterprise.telemetry.id_generator import set_span_id_source

        set_span_id_source(None)

    def test_uses_deterministic_span_id_when_source_set(self) -> None:
        from enterprise.telemetry.id_generator import (
            CorrelationIdGenerator,
            compute_deterministic_span_id,
            set_span_id_source,
        )

        uid = "123e4567-e89b-12d3-a456-426614174000"
        set_span_id_source(uid)
        assert CorrelationIdGenerator().generate_span_id() == compute_deterministic_span_id(uid)

    def test_returns_random_when_no_source(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source

        set_span_id_source(None)
        generated = CorrelationIdGenerator().generate_span_id()
        assert isinstance(generated, int)
        assert generated != 0

    def test_returns_random_when_source_is_invalid_uuid(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source

        set_span_id_source("not-a-uuid")
        with patch("enterprise.telemetry.id_generator.random.getrandbits", return_value=7):
            assert CorrelationIdGenerator().generate_span_id() == 7

    def test_random_span_id_retried_if_zero(self) -> None:
        """generate_span_id must never return 0 — it retries until non-zero."""
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source

        set_span_id_source(None)
        # First random draw is 0 (invalid for OTEL); the generator must retry and take 99.
        with patch("enterprise.telemetry.id_generator.random.getrandbits", side_effect=[0, 99]):
            assert CorrelationIdGenerator().generate_span_id() == 99

    def test_generate_span_id_always_non_zero(self) -> None:
        from enterprise.telemetry.id_generator import CorrelationIdGenerator, set_span_id_source

        set_span_id_source(None)
        gen = CorrelationIdGenerator()
        assert all(gen.generate_span_id() != 0 for _ in range(20))

View File

@@ -0,0 +1,511 @@
"""Unit tests for EnterpriseMetricHandler."""
import json
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
@pytest.fixture
def mock_redis():
    """Patch the redis client the metric handler uses for deduplication."""
    with patch("enterprise.telemetry.metric_handler.redis_client") as client:
        yield client
@pytest.fixture
def sample_envelope():
    """A minimal APP_CREATED envelope shared by the dispatch/idempotency tests."""
    envelope_kwargs = dict(
        case=TelemetryCase.APP_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-123",
        payload={"app_id": "app-123", "name": "Test App"},
    )
    return TelemetryEnvelope(**envelope_kwargs)
def test_dispatch_app_created(sample_envelope, mock_redis):
    """APP_CREATED envelopes are routed to _on_app_created."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    with patch.object(handler, "_on_app_created") as spy:
        handler.handle(sample_envelope)
    spy.assert_called_once_with(sample_envelope)
def test_dispatch_app_updated(mock_redis):
    """APP_UPDATED envelopes are routed to _on_app_updated."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_UPDATED,
        tenant_id="test-tenant",
        event_id="test-event-456",
        payload={},
    )
    with patch.object(handler, "_on_app_updated") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_app_deleted(mock_redis):
    """APP_DELETED envelopes are routed to _on_app_deleted."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_DELETED,
        tenant_id="test-tenant",
        event_id="test-event-789",
        payload={},
    )
    with patch.object(handler, "_on_app_deleted") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_feedback_created(mock_redis):
    """FEEDBACK_CREATED envelopes are routed to _on_feedback_created."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.FEEDBACK_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-abc",
        payload={},
    )
    with patch.object(handler, "_on_feedback_created") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_message_run(mock_redis):
    """MESSAGE_RUN envelopes are routed to _on_message_run."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.MESSAGE_RUN,
        tenant_id="test-tenant",
        event_id="test-event-msg",
        payload={},
    )
    with patch.object(handler, "_on_message_run") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_tool_execution(mock_redis):
    """TOOL_EXECUTION envelopes are routed to _on_tool_execution."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.TOOL_EXECUTION,
        tenant_id="test-tenant",
        event_id="test-event-tool",
        payload={},
    )
    with patch.object(handler, "_on_tool_execution") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_moderation_check(mock_redis):
    """MODERATION_CHECK envelopes are routed to _on_moderation_check."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.MODERATION_CHECK,
        tenant_id="test-tenant",
        event_id="test-event-mod",
        payload={},
    )
    with patch.object(handler, "_on_moderation_check") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_suggested_question(mock_redis):
    """SUGGESTED_QUESTION envelopes are routed to _on_suggested_question."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.SUGGESTED_QUESTION,
        tenant_id="test-tenant",
        event_id="test-event-sq",
        payload={},
    )
    with patch.object(handler, "_on_suggested_question") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_dataset_retrieval(mock_redis):
    """DATASET_RETRIEVAL envelopes are routed to _on_dataset_retrieval."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.DATASET_RETRIEVAL,
        tenant_id="test-tenant",
        event_id="test-event-ds",
        payload={},
    )
    with patch.object(handler, "_on_dataset_retrieval") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_generate_name(mock_redis):
    """GENERATE_NAME envelopes are routed to _on_generate_name."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.GENERATE_NAME,
        tenant_id="test-tenant",
        event_id="test-event-gn",
        payload={},
    )
    with patch.object(handler, "_on_generate_name") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_dispatch_prompt_generation(mock_redis):
    """PROMPT_GENERATION envelopes are routed to _on_prompt_generation."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    envelope = TelemetryEnvelope(
        case=TelemetryCase.PROMPT_GENERATION,
        tenant_id="test-tenant",
        event_id="test-event-pg",
        payload={},
    )
    with patch.object(handler, "_on_prompt_generation") as spy:
        handler.handle(envelope)
    spy.assert_called_once_with(envelope)
def test_all_known_cases_have_handlers(mock_redis):
    """Smoke test: handle() must not raise for any defined telemetry case."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    for case in TelemetryCase:
        handler.handle(
            TelemetryEnvelope(
                case=case,
                tenant_id="test-tenant",
                event_id=f"test-{case.value}",
                payload={},
            )
        )
def test_idempotency_duplicate(sample_envelope, mock_redis):
    """A duplicate event (redis SET NX returned falsy) must not be processed."""
    mock_redis.set.return_value = None
    handler = EnterpriseMetricHandler()
    with patch.object(handler, "_on_app_created") as spy:
        handler.handle(sample_envelope)
    spy.assert_not_called()
def test_idempotency_first_seen(sample_envelope, mock_redis):
    """First occurrence: the dedup key is set with NX and a one-hour TTL."""
    mock_redis.set.return_value = True
    handler = EnterpriseMetricHandler()
    assert handler._is_duplicate(sample_envelope) is False
    mock_redis.set.assert_called_once_with(
        "telemetry:dedup:test-tenant:test-event-123",
        b"1",
        nx=True,
        ex=3600,
    )
def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog):
    """When redis is unreachable the dedup check fails open and logs a warning."""
    mock_redis.set.side_effect = Exception("Redis unavailable")
    handler = EnterpriseMetricHandler()
    assert handler._is_duplicate(sample_envelope) is False
    assert "Redis unavailable for deduplication check" in caplog.text
def test_rehydration_uses_payload(sample_envelope):
    """Inline payloads are returned as-is by _rehydrate."""
    rehydrated = EnterpriseMetricHandler()._rehydrate(sample_envelope)
    assert rehydrated == {"app_id": "app-123", "name": "Test App"}
def test_rehydration_from_storage():
    """_rehydrate loads the payload from object storage via payload_ref."""
    stored_data = {"app_id": "app-stored", "mode": "workflow"}
    ref = "telemetry/test-tenant/test-event-fb.json"
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-fb",
        payload={},
        metadata={"payload_ref": ref},
    )
    with patch("enterprise.telemetry.metric_handler.storage") as mock_storage:
        mock_storage.load.return_value = json.dumps(stored_data).encode("utf-8")
        rehydrated = EnterpriseMetricHandler()._rehydrate(envelope)
    assert rehydrated == stored_data
    mock_storage.load.assert_called_once_with(ref)
def test_rehydration_storage_failure_emits_degraded_event():
    """Verify _rehydrate emits a degraded (REHYDRATION_FAILED) event when the storage load fails."""
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-fail",
        payload={},
        metadata={"payload_ref": "telemetry/test-tenant/test-event-fail.json"},
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("enterprise.telemetry.metric_handler.storage") as mock_storage,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_storage.load.side_effect = Exception("Storage unavailable")
        payload = handler._rehydrate(envelope)
    from enterprise.telemetry.entities import EnterpriseTelemetryEvent
    # The failed load degrades to an empty payload plus a REHYDRATION_FAILED event
    # carrying the error detail under "dify.telemetry.error".
    assert payload == {}
    mock_emit.assert_called_once()
    call_args = mock_emit.call_args
    assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
    assert "dify.telemetry.error" in call_args[1]["attributes"]
def test_rehydration_emits_degraded_event_on_empty_payload():
    """An empty payload with no payload_ref triggers a REHYDRATION_FAILED event."""
    from enterprise.telemetry.entities import EnterpriseTelemetryEvent

    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-empty",
        payload={},
    )
    with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit:
        payload = EnterpriseMetricHandler()._rehydrate(envelope)
    assert payload == {}
    mock_emit.assert_called_once()
    emit_kwargs = mock_emit.call_args[1]
    assert emit_kwargs["event_name"] == EnterpriseTelemetryEvent.REHYDRATION_FAILED
    assert "dify.telemetry.error" in emit_kwargs["attributes"]
def test_on_app_created_emits_correct_event(mock_redis):
    """_on_app_created emits a metric-only event and bumps the APP_CREATED counter."""
    mock_redis.set.return_value = True
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_CREATED,
        tenant_id="tenant-123",
        event_id="event-456",
        payload={"app_id": "app-789", "mode": "chat"},
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_exporter = MagicMock()
        mock_get_exporter.return_value = mock_exporter
        handler._on_app_created(envelope)
        from enterprise.telemetry.entities import EnterpriseTelemetryEvent
        # Event side: name, tenant, and the expected dify.* attributes.
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args
        assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_CREATED
        assert call_args[1]["tenant_id"] == "tenant-123"
        attrs = call_args[1]["attributes"]
        assert attrs["dify.app_id"] == "app-789"
        assert attrs["dify.tenant_id"] == "tenant-123"
        assert attrs["dify.event.id"] == "event-456"
        assert attrs["dify.app.mode"] == "chat"
        assert "dify.app.created_at" in attrs
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter
        # Metric side: counter incremented exactly once with identifying labels.
        mock_exporter.increment_counter.assert_called_once()
        counter_call = mock_exporter.increment_counter.call_args
        assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_CREATED
        assert counter_call[0][1] == 1
        assert counter_call[0][2]["tenant_id"] == "tenant-123"
        assert counter_call[0][2]["app_id"] == "app-789"
        assert counter_call[0][2]["mode"] == "chat"
def test_on_app_updated_emits_correct_event(mock_redis):
    """_on_app_updated emits a metric-only event and bumps the APP_UPDATED counter."""
    mock_redis.set.return_value = True
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_UPDATED,
        tenant_id="tenant-123",
        event_id="event-456",
        payload={"app_id": "app-789"},
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_exporter = MagicMock()
        mock_get_exporter.return_value = mock_exporter
        handler._on_app_updated(envelope)
        from enterprise.telemetry.entities import EnterpriseTelemetryEvent
        # Event side: name, tenant, and the expected dify.* attributes.
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args
        assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_UPDATED
        assert call_args[1]["tenant_id"] == "tenant-123"
        attrs = call_args[1]["attributes"]
        assert attrs["dify.app_id"] == "app-789"
        assert attrs["dify.tenant_id"] == "tenant-123"
        assert attrs["dify.event.id"] == "event-456"
        assert "dify.app.updated_at" in attrs
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter
        # Metric side: counter incremented exactly once with identifying labels.
        mock_exporter.increment_counter.assert_called_once()
        counter_call = mock_exporter.increment_counter.call_args
        assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_UPDATED
        assert counter_call[0][1] == 1
        assert counter_call[0][2]["tenant_id"] == "tenant-123"
        assert counter_call[0][2]["app_id"] == "app-789"
def test_on_app_deleted_emits_correct_event(mock_redis):
    """_on_app_deleted emits a metric-only event and bumps the APP_DELETED counter."""
    mock_redis.set.return_value = True
    envelope = TelemetryEnvelope(
        case=TelemetryCase.APP_DELETED,
        tenant_id="tenant-123",
        event_id="event-456",
        payload={"app_id": "app-789"},
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_exporter = MagicMock()
        mock_get_exporter.return_value = mock_exporter
        handler._on_app_deleted(envelope)
        from enterprise.telemetry.entities import EnterpriseTelemetryEvent
        # Event side: name, tenant, and the expected dify.* attributes.
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args
        assert call_args[1]["event_name"] == EnterpriseTelemetryEvent.APP_DELETED
        assert call_args[1]["tenant_id"] == "tenant-123"
        attrs = call_args[1]["attributes"]
        assert attrs["dify.app_id"] == "app-789"
        assert attrs["dify.tenant_id"] == "tenant-123"
        assert attrs["dify.event.id"] == "event-456"
        assert "dify.app.deleted_at" in attrs
        from enterprise.telemetry.entities import EnterpriseTelemetryCounter
        # Metric side: counter incremented exactly once with identifying labels.
        mock_exporter.increment_counter.assert_called_once()
        counter_call = mock_exporter.increment_counter.call_args
        assert counter_call[0][0] == EnterpriseTelemetryCounter.APP_DELETED
        assert counter_call[0][1] == 1
        assert counter_call[0][2]["tenant_id"] == "tenant-123"
        assert counter_call[0][2]["app_id"] == "app-789"
def test_on_feedback_created_emits_correct_event(mock_redis):
    """With include_content=True, the feedback event carries the content attribute."""
    mock_redis.set.return_value = True
    envelope = TelemetryEnvelope(
        case=TelemetryCase.FEEDBACK_CREATED,
        tenant_id="tenant-123",
        event_id="event-456",
        payload={
            "message_id": "msg-001",
            "app_id": "app-789",
            "conversation_id": "conv-123",
            "from_end_user_id": "user-456",
            "from_account_id": None,
            "rating": "like",
            "from_source": "api",
            "content": "Great!",
        },
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_exporter = MagicMock()
        # Content export is opt-in; enable it for this test.
        mock_exporter.include_content = True
        mock_get_exporter.return_value = mock_exporter
        handler._on_feedback_created(envelope)
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args
        assert call_args[1]["event_name"] == "dify.feedback.created"
        assert call_args[1]["attributes"]["dify.message.id"] == "msg-001"
        assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!"
        assert "dify.feedback.created_at" in call_args[1]["attributes"]
        assert call_args[1]["tenant_id"] == "tenant-123"
        # The end user id becomes the event's user_id.
        assert call_args[1]["user_id"] == "user-456"
        mock_exporter.increment_counter.assert_called_once()
        counter_args = mock_exporter.increment_counter.call_args
        assert counter_args[0][2]["app_id"] == "app-789"
        assert counter_args[0][2]["rating"] == "like"
def test_on_feedback_created_without_content(mock_redis):
    """With include_content=False, the feedback content must NOT be exported."""
    mock_redis.set.return_value = True
    envelope = TelemetryEnvelope(
        case=TelemetryCase.FEEDBACK_CREATED,
        tenant_id="tenant-123",
        event_id="event-456",
        payload={
            "message_id": "msg-001",
            "app_id": "app-789",
            "conversation_id": "conv-123",
            "from_end_user_id": "user-456",
            "from_account_id": None,
            "rating": "like",
            "from_source": "api",
            # Content is present in the payload but should be redacted on export.
            "content": "Great!",
        },
    )
    handler = EnterpriseMetricHandler()
    with (
        patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
        patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
    ):
        mock_exporter = MagicMock()
        mock_exporter.include_content = False
        mock_get_exporter.return_value = mock_exporter
        handler._on_feedback_created(envelope)
        mock_emit.assert_called_once()
        call_args = mock_emit.call_args
        assert "dify.feedback.content" not in call_args[1]["attributes"]

View File

@@ -0,0 +1,327 @@
"""Unit tests for enterprise/telemetry/telemetry_log.py."""
from __future__ import annotations
import uuid
from unittest.mock import MagicMock, patch
# ---------------------------------------------------------------------------
# compute_trace_id_hex
# ---------------------------------------------------------------------------
class TestComputeTraceIdHex:
    """compute_trace_id_hex(): normalization of UUID/hex strings to a 32-char trace id."""

    HEX_DIGITS = set("0123456789abcdef")

    def setup_method(self) -> None:
        # The helper is lru_cached at module level; clear it so tests stay independent.
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        compute_trace_id_hex.cache_clear()

    def test_none_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        assert compute_trace_id_hex(None) == ""

    def test_empty_string_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        assert compute_trace_id_hex("") == ""

    def test_already_32_hex_chars_returned_as_is(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        already_hex = "a" * 32
        assert compute_trace_id_hex(already_hex) == already_hex

    def test_valid_uuid_string_converted_to_32_hex(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        uid = "123e4567-e89b-12d3-a456-426614174000"
        converted = compute_trace_id_hex(uid)
        assert len(converted) == 32
        assert set(converted) <= self.HEX_DIGITS
        # Round-trip: parsing the hex back must give the UUID's integer value.
        assert int(converted, 16) == uuid.UUID(uid).int

    def test_invalid_string_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        assert compute_trace_id_hex("not-a-uuid") == ""

    def test_whitespace_stripped(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        padded = " 123e4567-e89b-12d3-a456-426614174000 "
        assert len(compute_trace_id_hex(padded)) == 32

    def test_uppercase_uuid_accepted(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        upper = "123E4567-E89B-12D3-A456-426614174000"
        assert len(compute_trace_id_hex(upper)) == 32

    def test_result_is_cached(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_trace_id_hex

        uid = "123e4567-e89b-12d3-a456-426614174000"
        first = compute_trace_id_hex(uid)
        second = compute_trace_id_hex(uid)
        assert first == second
        assert compute_trace_id_hex.cache_info().hits >= 1
# ---------------------------------------------------------------------------
# compute_span_id_hex
# ---------------------------------------------------------------------------
class TestComputeSpanIdHex:
    """compute_span_id_hex(): normalization of UUID/hex strings to a 16-char span id."""

    HEX_DIGITS = set("0123456789abcdef")

    def setup_method(self) -> None:
        # Clear the module-level lru_cache between tests.
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        compute_span_id_hex.cache_clear()

    def test_none_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        assert compute_span_id_hex(None) == ""

    def test_empty_string_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        assert compute_span_id_hex("") == ""

    def test_already_16_hex_chars_returned_as_is(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        already_hex = "abcdef0123456789"
        assert compute_span_id_hex(already_hex) == already_hex

    def test_valid_uuid_produces_16_hex_span_id(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        converted = compute_span_id_hex("123e4567-e89b-12d3-a456-426614174000")
        assert len(converted) == 16
        assert set(converted) <= self.HEX_DIGITS

    def test_invalid_string_returns_empty(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        assert compute_span_id_hex("not-a-uuid-at-all!") == ""

    def test_result_is_cached(self) -> None:
        from enterprise.telemetry.telemetry_log import compute_span_id_hex

        uid = "123e4567-e89b-12d3-a456-426614174000"
        compute_span_id_hex(uid)
        compute_span_id_hex(uid)
        assert compute_span_id_hex.cache_info().hits >= 1
# ---------------------------------------------------------------------------
# emit_telemetry_log
# ---------------------------------------------------------------------------
class TestEmitTelemetryLog:
    """emit_telemetry_log(): structured telemetry records emitted via the module logger."""

    def setup_method(self) -> None:
        # The hex helpers are lru_cached at module level; clear them between tests.
        from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
        compute_trace_id_hex.cache_clear()
        compute_span_id_hex.cache_clear()

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_logs_info_with_event_name_and_signal(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(
            event_name="dify.workflow.run",
            attributes={"tenant_id": "t1"},
            signal="metric_only",
        )
        mock_logger.info.assert_called_once()
        args, kwargs = mock_logger.info.call_args
        # Lazy %-style logging: template plus the signal as the format argument.
        assert args[0] == "telemetry.%s"
        assert args[1] == "metric_only"
        extra = kwargs["extra"]
        assert extra["attributes"]["dify.event.name"] == "dify.workflow.run"
        assert extra["attributes"]["dify.event.signal"] == "metric_only"

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_no_log_when_info_disabled(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        # isEnabledFor gate: no record should be emitted when INFO is off.
        mock_logger.isEnabledFor.return_value = False
        emit_telemetry_log(event_name="dify.workflow.run", attributes={})
        mock_logger.info.assert_not_called()

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_trace_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        uid = "123e4567-e89b-12d3-a456-426614174000"
        emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source=uid)
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert "trace_id" in extra
        assert len(extra["trace_id"]) == 32

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_trace_id_absent_when_invalid_source(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(event_name="test.event", attributes={}, trace_id_source="bad-id")
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert "trace_id" not in extra

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_span_id_added_to_extra_when_valid_uuid(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        uid = "123e4567-e89b-12d3-a456-426614174000"
        emit_telemetry_log(event_name="test.event", attributes={}, span_id_source=uid)
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert "span_id" in extra
        assert len(extra["span_id"]) == 16

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_tenant_id_added_when_provided(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(event_name="test.event", attributes={}, tenant_id="tenant-99")
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert extra["tenant_id"] == "tenant-99"

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_user_id_added_when_provided(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(event_name="test.event", attributes={}, user_id="user-42")
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert extra["user_id"] == "user-42"

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_tenant_and_user_id_absent_when_not_provided(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(event_name="test.event", attributes={})
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert "tenant_id" not in extra
        assert "user_id" not in extra

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_caller_attributes_merged_into_attrs(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(
            event_name="dify.node.run",
            attributes={"node_type": "code", "elapsed": 0.5},
        )
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert extra["attributes"]["node_type"] == "code"
        assert extra["attributes"]["elapsed"] == 0.5

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_signal_span_detail_forwarded(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_telemetry_log
        mock_logger.isEnabledFor.return_value = True
        emit_telemetry_log(event_name="test.event", attributes={}, signal="span_detail")
        args = mock_logger.info.call_args[0]
        assert args[1] == "span_detail"
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert extra["attributes"]["dify.event.signal"] == "span_detail"
# ---------------------------------------------------------------------------
# emit_metric_only_event
# ---------------------------------------------------------------------------
class TestEmitMetricOnlyEvent:
    """emit_metric_only_event(): thin wrapper that forces the metric_only signal."""

    def setup_method(self) -> None:
        # Clear the module-level lru_caches between tests.
        from enterprise.telemetry.telemetry_log import compute_span_id_hex, compute_trace_id_hex
        compute_trace_id_hex.cache_clear()
        compute_span_id_hex.cache_clear()

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_delegates_to_emit_telemetry_log_with_metric_only_signal(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_metric_only_event
        mock_logger.isEnabledFor.return_value = True
        emit_metric_only_event(
            event_name="dify.app.created",
            attributes={"app_id": "app-1"},
            tenant_id="t1",
            user_id="u1",
        )
        mock_logger.info.assert_called_once()
        extra = mock_logger.info.call_args.kwargs["extra"]
        # The wrapper always stamps signal=metric_only and forwards everything else.
        assert extra["attributes"]["dify.event.signal"] == "metric_only"
        assert extra["attributes"]["dify.event.name"] == "dify.app.created"
        assert extra["attributes"]["app_id"] == "app-1"
        assert extra["tenant_id"] == "t1"
        assert extra["user_id"] == "u1"

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_trace_and_span_ids_passed_through(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_metric_only_event
        mock_logger.isEnabledFor.return_value = True
        uid = "123e4567-e89b-12d3-a456-426614174000"
        emit_metric_only_event(
            event_name="dify.workflow.run",
            attributes={},
            trace_id_source=uid,
            span_id_source=uid,
        )
        extra = mock_logger.info.call_args.kwargs["extra"]
        assert "trace_id" in extra
        assert "span_id" in extra

    @patch("enterprise.telemetry.telemetry_log.logger")
    def test_no_log_emitted_when_logger_disabled(self, mock_logger: MagicMock) -> None:
        from enterprise.telemetry.telemetry_log import emit_metric_only_event
        mock_logger.isEnabledFor.return_value = False
        emit_metric_only_event(event_name="dify.workflow.run", attributes={})
        mock_logger.info.assert_not_called()

View File

@@ -0,0 +1,206 @@
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def mock_db():
    """Patch the ``db`` object used by app_service, exposing a MagicMock session."""
    with patch("services.app_service.db") as patched_db:
        patched_db.session = MagicMock()
        yield patched_db
@pytest.fixture
def _mock_deps():
    """Stub out billing/feature/enterprise services and the async delete task."""
    patchers = [
        patch("services.app_service.BillingService"),
        patch("services.app_service.FeatureService"),
        patch("services.app_service.EnterpriseService"),
        patch("services.app_service.remove_app_and_related_data_task"),
    ]
    for patcher in patchers:
        patcher.start()
    try:
        yield
    finally:
        # Undo in reverse order, mirroring nested context-manager exit semantics.
        for patcher in reversed(patchers):
            patcher.stop()
@pytest.fixture
def app_model():
    """A MagicMock standing in for an App record, with the fields AppService reads."""
    fake_app = MagicMock()
    field_values = {
        "id": "app-123",
        "tenant_id": "tenant-456",
        "name": "Old Name",
        "icon_type": "emoji",
        "icon": "🤖",
        "icon_background": "#fff",
        "enable_site": False,
        "enable_api": False,
    }
    for field, value in field_values.items():
        setattr(fake_app, field, value)
    return fake_app
def _make_collector(target: list):
def handler(sender, **kw):
target.append(sender)
return handler
@pytest.mark.usefixtures("mock_db", "_mock_deps")
class TestAppWasDeletedSignal:
    """Verify delete_app emits app_was_deleted, and before the DB delete happens."""

    def test_sends_signal(self, app_model):
        """delete_app must send the signal exactly once with the app as sender."""
        from events.app_event import app_was_deleted
        from services.app_service import AppService

        senders = []
        collector = _make_collector(senders)
        app_was_deleted.connect(collector)
        try:
            AppService().delete_app(app_model)
        finally:
            app_was_deleted.disconnect(collector)

        assert senders == [app_model]

    def test_signal_fires_before_db_delete(self, app_model, mock_db):
        """The signal must be dispatched before the row is removed from the session."""
        from events.app_event import app_was_deleted
        from services.app_service import AppService

        order: list[str] = []

        # Keep a named reference: blinker holds handlers weakly.
        def on_deleted(sender, **kwargs):
            order.append("signal")

        app_was_deleted.connect(on_deleted)
        mock_db.session.delete.side_effect = lambda _app: order.append("db_delete")
        try:
            AppService().delete_app(app_model)
        finally:
            app_was_deleted.disconnect(on_deleted)

        assert order.index("signal") < order.index("db_delete")
@pytest.mark.usefixtures("mock_db")
class TestAppWasUpdatedSignal:
    """Verify each AppService update path emits (or skips) app_was_updated correctly."""

    @staticmethod
    def _record_senders(action):
        """Run *action* while app_was_updated is connected; return the senders seen."""
        from events.app_event import app_was_updated

        senders = []

        # Named handler kept in scope for the duration: blinker holds weak refs.
        def on_updated(sender, **kwargs):
            senders.append(sender)

        app_was_updated.connect(on_updated)
        try:
            action()
        finally:
            app_was_updated.disconnect(on_updated)
        return senders

    def test_update_app(self, app_model):
        from services.app_service import AppService

        payload = {
            "name": "New",
            "description": "Desc",
            "icon_type": "emoji",
            "icon": "🤖",
            "icon_background": "#fff",
            "use_icon_as_answer_icon": False,
            "max_active_requests": 0,
        }
        with patch("services.app_service.current_user", MagicMock(id="user-1")):
            senders = self._record_senders(lambda: AppService().update_app(app_model, payload))

        assert senders == [app_model]

    def test_update_app_name(self, app_model):
        from services.app_service import AppService

        with patch("services.app_service.current_user", MagicMock(id="user-1")):
            senders = self._record_senders(
                lambda: AppService().update_app_name(app_model, "New Name")
            )

        assert senders == [app_model]

    def test_update_app_icon(self, app_model):
        from services.app_service import AppService

        with patch("services.app_service.current_user", MagicMock(id="user-1")):
            senders = self._record_senders(
                lambda: AppService().update_app_icon(app_model, "🎉", "#000")
            )

        assert senders == [app_model]

    def test_update_app_site_status_sends_when_changed(self, app_model):
        from services.app_service import AppService

        app_model.enable_site = False
        with patch("services.app_service.current_user", MagicMock(id="user-1")):
            senders = self._record_senders(
                lambda: AppService().update_app_site_status(app_model, True)
            )

        assert senders == [app_model]

    def test_update_app_site_status_skips_when_unchanged(self, app_model):
        from services.app_service import AppService

        app_model.enable_site = True
        senders = self._record_senders(
            lambda: AppService().update_app_site_status(app_model, True)
        )

        assert senders == []

    def test_update_app_api_status_sends_when_changed(self, app_model):
        from services.app_service import AppService

        app_model.enable_api = False
        with patch("services.app_service.current_user", MagicMock(id="user-1")):
            senders = self._record_senders(
                lambda: AppService().update_app_api_status(app_model, True)
            )

        assert senders == [app_model]

    def test_update_app_api_status_skips_when_unchanged(self, app_model):
        from services.app_service import AppService

        app_model.enable_api = True
        senders = self._record_senders(
            lambda: AppService().update_app_api_status(app_model, True)
        )

        assert senders == []

View File

@@ -0,0 +1,69 @@
"""Unit tests for enterprise telemetry Celery task."""
import json
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
@pytest.fixture
def sample_envelope_json():
    """A serialized APP_CREATED envelope used as the Celery task payload."""
    return TelemetryEnvelope(
        case=TelemetryCase.APP_CREATED,
        tenant_id="test-tenant",
        event_id="test-event-123",
        payload={"app_id": "app-123"},
    ).model_dump_json()
def test_process_enterprise_telemetry_success(sample_envelope_json):
    """A valid envelope is deserialized and handed to EnterpriseMetricHandler.handle."""
    with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as handler_cls:
        handler_instance = handler_cls.return_value

        process_enterprise_telemetry(sample_envelope_json)

        handler_instance.handle.assert_called_once()
        (envelope,) = handler_instance.handle.call_args[0]
        assert isinstance(envelope, TelemetryEnvelope)
        assert envelope.case == TelemetryCase.APP_CREATED
        assert envelope.tenant_id == "test-tenant"
        assert envelope.event_id == "test-event-123"
def test_process_enterprise_telemetry_invalid_json(caplog):
    """Malformed JSON must be swallowed and logged, never raised to the worker."""
    process_enterprise_telemetry("not valid json")

    assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog):
    """A failure inside the handler must be caught and logged, not propagated."""
    with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as handler_cls:
        handler_cls.return_value.handle.side_effect = Exception("Handler error")

        process_enterprise_telemetry(sample_envelope_json)

    assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_validation_error(caplog):
    """An envelope that fails model validation is logged, not raised."""
    bad_envelope = json.dumps(
        {
            "case": "INVALID_CASE",
            "tenant_id": "test-tenant",
            "event_id": "test-event",
            "payload": {},
        }
    )

    process_enterprise_telemetry(bad_envelope)

    assert "Failed to process enterprise telemetry envelope" in caplog.text

View File

@@ -3,13 +3,10 @@ import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { act, render, waitFor } from '@testing-library/react'
import {
BLUR_COMMAND,
COMMAND_PRIORITY_EDITOR,
FOCUS_COMMAND,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import OnBlurBlock from '../on-blur-or-focus-block'
import { CaptureEditorPlugin } from '../test-utils'
import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block'
const renderOnBlurBlock = (props?: {
onBlur?: () => void
@@ -75,7 +72,7 @@ describe('OnBlurBlock', () => {
expect(onFocus).toHaveBeenCalledTimes(1)
})
it('should call onBlur and dispatch escape after delay when blur target is not var-search-input', async () => {
it('should call onBlur when blur target is not var-search-input', async () => {
const onBlur = vi.fn()
const { getEditor } = renderOnBlurBlock({ onBlur })
@@ -85,14 +82,6 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
vi.useFakeTimers()
const onEscape = vi.fn(() => true)
const unregister = editor!.registerCommand(
KEY_ESCAPE_COMMAND,
onEscape,
COMMAND_PRIORITY_EDITOR,
)
let handled = false
act(() => {
@@ -101,18 +90,9 @@ describe('OnBlurBlock', () => {
expect(handled).toBe(true)
expect(onBlur).toHaveBeenCalledTimes(1)
expect(onEscape).not.toHaveBeenCalled()
act(() => {
vi.advanceTimersByTime(200)
})
expect(onEscape).toHaveBeenCalledTimes(1)
unregister()
vi.useRealTimers()
})
it('should dispatch delayed escape when onBlur callback is not provided', async () => {
it('should handle blur when onBlur callback is not provided', async () => {
const { getEditor } = renderOnBlurBlock()
await waitFor(() => {
@@ -121,28 +101,16 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
vi.useFakeTimers()
const onEscape = vi.fn(() => true)
const unregister = editor!.registerCommand(
KEY_ESCAPE_COMMAND,
onEscape,
COMMAND_PRIORITY_EDITOR,
)
let handled = false
act(() => {
editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
})
act(() => {
vi.advanceTimersByTime(200)
handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
})
expect(onEscape).toHaveBeenCalledTimes(1)
unregister()
vi.useRealTimers()
expect(handled).toBe(true)
})
it('should skip onBlur and delayed escape when blur target is var-search-input', async () => {
it('should skip onBlur when blur target is var-search-input', async () => {
const onBlur = vi.fn()
const { getEditor } = renderOnBlurBlock({ onBlur })
@@ -152,31 +120,17 @@ describe('OnBlurBlock', () => {
const editor = getEditor()
expect(editor).not.toBeNull()
vi.useFakeTimers()
const target = document.createElement('input')
target.classList.add('var-search-input')
const onEscape = vi.fn(() => true)
const unregister = editor!.registerCommand(
KEY_ESCAPE_COMMAND,
onEscape,
COMMAND_PRIORITY_EDITOR,
)
let handled = false
act(() => {
handled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(target))
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(handled).toBe(true)
expect(onBlur).not.toHaveBeenCalled()
expect(onEscape).not.toHaveBeenCalled()
unregister()
vi.useRealTimers()
})
it('should handle focus command when onFocus callback is not provided', async () => {
@@ -198,59 +152,6 @@ describe('OnBlurBlock', () => {
})
})
describe('Clear timeout command', () => {
it('should clear scheduled escape timeout when clear command is dispatched', async () => {
const { getEditor } = renderOnBlurBlock({ onBlur: vi.fn() })
await waitFor(() => {
expect(getEditor()).not.toBeNull()
})
const editor = getEditor()
expect(editor).not.toBeNull()
vi.useFakeTimers()
const onEscape = vi.fn(() => true)
const unregister = editor!.registerCommand(
KEY_ESCAPE_COMMAND,
onEscape,
COMMAND_PRIORITY_EDITOR,
)
act(() => {
editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
})
act(() => {
editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(onEscape).not.toHaveBeenCalled()
unregister()
vi.useRealTimers()
})
it('should handle clear command when no timeout is scheduled', async () => {
const { getEditor } = renderOnBlurBlock()
await waitFor(() => {
expect(getEditor()).not.toBeNull()
})
const editor = getEditor()
expect(editor).not.toBeNull()
let handled = false
act(() => {
handled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
expect(handled).toBe(true)
})
})
describe('Lifecycle cleanup', () => {
it('should unregister commands when component unmounts', async () => {
const { getEditor, unmount } = renderOnBlurBlock()
@@ -266,16 +167,13 @@ describe('OnBlurBlock', () => {
let blurHandled = true
let focusHandled = true
let clearHandled = true
act(() => {
blurHandled = editor!.dispatchCommand(BLUR_COMMAND, createBlurEvent(document.createElement('div')))
focusHandled = editor!.dispatchCommand(FOCUS_COMMAND, createFocusEvent())
clearHandled = editor!.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
expect(blurHandled).toBe(false)
expect(focusHandled).toBe(false)
expect(clearHandled).toBe(false)
})
})
})

View File

@@ -1,14 +1,13 @@
import type { LexicalEditor } from 'lexical'
import { LexicalComposer } from '@lexical/react/LexicalComposer'
import { act, render, waitFor } from '@testing-library/react'
import { $getRoot, COMMAND_PRIORITY_EDITOR } from 'lexical'
import { $getRoot } from 'lexical'
import { CustomTextNode } from '../custom-text/node'
import { CaptureEditorPlugin } from '../test-utils'
import UpdateBlock, {
PROMPT_EDITOR_INSERT_QUICKLY,
PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER,
} from '../update-block'
import { CLEAR_HIDE_MENU_TIMEOUT } from '../workflow-variable-block'
const { mockUseEventEmitterContextContext } = vi.hoisted(() => ({
mockUseEventEmitterContextContext: vi.fn(),
@@ -157,7 +156,7 @@ describe('UpdateBlock', () => {
})
describe('Quick insert event', () => {
it('should insert slash and dispatch clear command when quick insert event matches instance id', async () => {
it('should insert slash when quick insert event matches instance id', async () => {
const { emit, getEditor } = setup({ instanceId: 'instance-1' })
await waitFor(() => {
@@ -168,13 +167,6 @@ describe('UpdateBlock', () => {
selectRootEnd(editor!)
const clearCommandHandler = vi.fn(() => true)
const unregister = editor!.registerCommand(
CLEAR_HIDE_MENU_TIMEOUT,
clearCommandHandler,
COMMAND_PRIORITY_EDITOR,
)
emit({
type: PROMPT_EDITOR_INSERT_QUICKLY,
instanceId: 'instance-1',
@@ -183,9 +175,6 @@ describe('UpdateBlock', () => {
await waitFor(() => {
expect(readEditorText(editor!)).toBe('/')
})
expect(clearCommandHandler).toHaveBeenCalledTimes(1)
unregister()
})
it('should ignore quick insert event when instance id does not match', async () => {

View File

@@ -23,6 +23,8 @@ import {
$createTextNode,
$getRoot,
$setSelection,
BLUR_COMMAND,
FOCUS_COMMAND,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import * as React from 'react'
@@ -631,4 +633,180 @@ describe('ComponentPicker (component-picker-block/index.tsx)', () => {
// With a single option group, the only divider should be the workflow-var/options separator.
expect(document.querySelectorAll('.bg-divider-subtle')).toHaveLength(1)
})
describe('blur/focus menu visibility', () => {
it('hides the menu after a 200ms delay when blur command is dispatched', async () => {
const captures: Captures = { editor: null, eventEmitter: null }
render((
<MinimalEditor
triggerString="{"
contextBlock={makeContextBlock()}
captures={captures}
/>
))
const editor = await waitForEditor(captures)
await setEditorText(editor, '{', true)
expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useFakeTimers()
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
})
expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
act(() => {
vi.advanceTimersByTime(200)
})
expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument()
vi.useRealTimers()
})
it('restores menu visibility when focus command is dispatched after blur hides it', async () => {
const captures: Captures = { editor: null, eventEmitter: null }
render((
<MinimalEditor
triggerString="{"
contextBlock={makeContextBlock()}
captures={captures}
/>
))
const editor = await waitForEditor(captures)
await setEditorText(editor, '{', true)
expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useFakeTimers()
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(screen.queryByText('common.promptEditor.context.item.title')).not.toBeInTheDocument()
act(() => {
editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus'))
})
vi.useRealTimers()
await setEditorText(editor, '{', true)
await waitFor(() => {
expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
})
})
it('cancels the blur timer when focus arrives before the 200ms timeout', async () => {
const captures: Captures = { editor: null, eventEmitter: null }
render((
<MinimalEditor
triggerString="{"
contextBlock={makeContextBlock()}
captures={captures}
/>
))
const editor = await waitForEditor(captures)
await setEditorText(editor, '{', true)
expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useFakeTimers()
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
})
act(() => {
editor.dispatchCommand(FOCUS_COMMAND, new FocusEvent('focus'))
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useRealTimers()
})
it('cancels a pending blur timer when a subsequent blur targets var-search-input', async () => {
const captures: Captures = { editor: null, eventEmitter: null }
render((
<MinimalEditor
triggerString="{"
contextBlock={makeContextBlock()}
captures={captures}
/>
))
const editor = await waitForEditor(captures)
await setEditorText(editor, '{', true)
expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useFakeTimers()
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: document.createElement('button') }))
})
const varInput = document.createElement('input')
varInput.classList.add('var-search-input')
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: varInput }))
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useRealTimers()
})
it('does not hide the menu when blur target is var-search-input', async () => {
const captures: Captures = { editor: null, eventEmitter: null }
render((
<MinimalEditor
triggerString="{"
contextBlock={makeContextBlock()}
captures={captures}
/>
))
const editor = await waitForEditor(captures)
await setEditorText(editor, '{', true)
expect(await screen.findByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useFakeTimers()
const target = document.createElement('input')
target.classList.add('var-search-input')
act(() => {
editor.dispatchCommand(BLUR_COMMAND, new FocusEvent('blur', { relatedTarget: target }))
})
act(() => {
vi.advanceTimersByTime(200)
})
expect(screen.queryByText('common.promptEditor.context.item.title')).toBeInTheDocument()
vi.useRealTimers()
})
})
})

View File

@@ -21,11 +21,19 @@ import {
} from '@floating-ui/react'
import { useLexicalComposerContext } from '@lexical/react/LexicalComposerContext'
import { LexicalTypeaheadMenuPlugin } from '@lexical/react/LexicalTypeaheadMenuPlugin'
import { KEY_ESCAPE_COMMAND } from 'lexical'
import { mergeRegister } from '@lexical/utils'
import {
BLUR_COMMAND,
COMMAND_PRIORITY_EDITOR,
FOCUS_COMMAND,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import {
Fragment,
memo,
useCallback,
useEffect,
useRef,
useState,
} from 'react'
import ReactDOM from 'react-dom'
@@ -87,6 +95,46 @@ const ComponentPicker = ({
})
const [queryString, setQueryString] = useState<string | null>(null)
const [blurHidden, setBlurHidden] = useState(false)
const blurTimerRef = useRef<ReturnType<typeof setTimeout> | null>(null)
const clearBlurTimer = useCallback(() => {
if (blurTimerRef.current) {
clearTimeout(blurTimerRef.current)
blurTimerRef.current = null
}
}, [])
useEffect(() => {
const unregister = mergeRegister(
editor.registerCommand(
BLUR_COMMAND,
(event) => {
clearBlurTimer()
const target = event?.relatedTarget as HTMLElement
if (!target?.classList?.contains('var-search-input'))
blurTimerRef.current = setTimeout(() => setBlurHidden(true), 200)
return false
},
COMMAND_PRIORITY_EDITOR,
),
editor.registerCommand(
FOCUS_COMMAND,
() => {
clearBlurTimer()
setBlurHidden(false)
return false
},
COMMAND_PRIORITY_EDITOR,
),
)
return () => {
if (blurTimerRef.current)
clearTimeout(blurTimerRef.current)
unregister()
}
}, [editor, clearBlurTimer])
eventEmitter?.useSubscription((v: any) => {
if (v.type === INSERT_VARIABLE_VALUE_BLOCK_COMMAND)
@@ -159,6 +207,8 @@ const ComponentPicker = ({
anchorElementRef,
{ options, selectedIndex, selectOptionAndCleanUp, setHighlightedIndex },
) => {
if (blurHidden)
return null
if (!(anchorElementRef.current && (allFlattenOptions.length || workflowVariableBlock?.show)))
return null
@@ -240,7 +290,7 @@ const ComponentPicker = ({
}
</>
)
}, [allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
}, [blurHidden, allFlattenOptions.length, workflowVariableBlock?.show, floatingStyles, isPositioned, refs, workflowVariableOptions, isSupportFileVar, handleClose, currentBlock?.generatorType, handleSelectWorkflowVariable, queryString, workflowVariableBlock?.showManageInputField, workflowVariableBlock?.onManageInputField])
return (
<LexicalTypeaheadMenuPlugin

View File

@@ -5,10 +5,8 @@ import {
BLUR_COMMAND,
COMMAND_PRIORITY_EDITOR,
FOCUS_COMMAND,
KEY_ESCAPE_COMMAND,
} from 'lexical'
import { useEffect, useRef } from 'react'
import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block'
import { useEffect } from 'react'
type OnBlurBlockProps = {
onBlur?: () => void
@@ -20,35 +18,13 @@ const OnBlurBlock: FC<OnBlurBlockProps> = ({
}) => {
const [editor] = useLexicalComposerContext()
const ref = useRef<ReturnType<typeof setTimeout> | null>(null)
useEffect(() => {
const clearHideMenuTimeout = () => {
if (ref.current) {
clearTimeout(ref.current)
ref.current = null
}
}
const unregister = mergeRegister(
editor.registerCommand(
CLEAR_HIDE_MENU_TIMEOUT,
() => {
clearHideMenuTimeout()
return true
},
COMMAND_PRIORITY_EDITOR,
),
return mergeRegister(
editor.registerCommand(
BLUR_COMMAND,
(event) => {
// Check if the clicked target element is var-search-input
const target = event?.relatedTarget as HTMLElement
if (!target?.classList?.contains('var-search-input')) {
clearHideMenuTimeout()
ref.current = setTimeout(() => {
editor.dispatchCommand(KEY_ESCAPE_COMMAND, new KeyboardEvent('keydown', { key: 'Escape' }))
}, 200)
if (onBlur)
onBlur()
}
@@ -66,11 +42,6 @@ const OnBlurBlock: FC<OnBlurBlockProps> = ({
COMMAND_PRIORITY_EDITOR,
),
)
return () => {
clearHideMenuTimeout()
unregister()
}
}, [editor, onBlur, onFocus])
return null

View File

@@ -3,7 +3,6 @@ import { $insertNodes } from 'lexical'
import { useEventEmitterContextContext } from '@/context/event-emitter'
import { textToEditorState } from '../utils'
import { CustomTextNode } from './custom-text/node'
import { CLEAR_HIDE_MENU_TIMEOUT } from './workflow-variable-block'
export const PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER = 'PROMPT_EDITOR_UPDATE_VALUE_BY_EVENT_EMITTER'
export const PROMPT_EDITOR_INSERT_QUICKLY = 'PROMPT_EDITOR_INSERT_QUICKLY'
@@ -30,8 +29,6 @@ const UpdateBlock = ({
editor.update(() => {
const textNode = new CustomTextNode('/')
$insertNodes([textNode])
editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
})
}
})

View File

@@ -9,7 +9,6 @@ import { $insertNodes, COMMAND_PRIORITY_EDITOR } from 'lexical'
import { Type } from '@/app/components/workflow/nodes/llm/types'
import { BlockEnum } from '@/app/components/workflow/types'
import {
CLEAR_HIDE_MENU_TIMEOUT,
DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND,
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
UPDATE_WORKFLOW_NODES_MAP,
@@ -134,7 +133,6 @@ describe('WorkflowVariableBlock', () => {
const insertHandler = mockRegisterCommand.mock.calls[0][1] as (variables: string[]) => boolean
const result = insertHandler(['node-1', 'answer'])
expect(mockDispatchCommand).toHaveBeenCalledWith(CLEAR_HIDE_MENU_TIMEOUT, undefined)
expect($createWorkflowVariableBlockNode).toHaveBeenCalledWith(
['node-1', 'answer'],
workflowNodesMap,

View File

@@ -18,7 +18,6 @@ import {
export const INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND')
export const DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND = createCommand('DELETE_WORKFLOW_VARIABLE_BLOCK_COMMAND')
export const CLEAR_HIDE_MENU_TIMEOUT = createCommand('CLEAR_HIDE_MENU_TIMEOUT')
export const UPDATE_WORKFLOW_NODES_MAP = createCommand('UPDATE_WORKFLOW_NODES_MAP')
export type WorkflowVariableBlockProps = {
@@ -49,7 +48,6 @@ const WorkflowVariableBlock = memo(({
editor.registerCommand(
INSERT_WORKFLOW_VARIABLE_BLOCK_COMMAND,
(variables: string[]) => {
editor.dispatchCommand(CLEAR_HIDE_MENU_TIMEOUT, undefined)
const workflowVariableBlockNode = $createWorkflowVariableBlockNode(variables, workflowNodesMap, getVarType)
$insertNodes([workflowVariableBlockNode])

View File

@@ -2768,7 +2768,7 @@
},
"app/components/base/prompt-editor/plugins/workflow-variable-block/index.tsx": {
"react-refresh/only-export-components": {
"count": 4
"count": 3
}
},
"app/components/base/prompt-editor/plugins/workflow-variable-block/node.tsx": {