Compare commits

...

43 Commits

Author SHA1 Message Date
GareArc
576eca2113 Merge branch '1.12.1-otel-ee' into deploy/enterprise
Some checks are pending
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Waiting to run
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Waiting to run
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Waiting to run
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Blocked by required conditions
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Blocked by required conditions
2026-02-05 23:07:48 -08:00
GareArc
8ded2d73f0 fix(telemetry): move EE guard to gateway routing level
Prevents CE users from enqueueing EE-only events (all METRIC_LOG cases)
to non-existent enterprise_telemetry Celery queue.

- Add _should_drop_ee_only_event() check in emit() before routing
- Remove redundant check from _emit_trace()
- Single guard at gateway level protects both trace and metric/log paths
2026-02-05 22:58:40 -08:00
GareArc
4a9b74f86b refactor(telemetry): simplify by eliminating TelemetryFacade
**Problem:**
The telemetry system had unnecessary abstraction layers and bad practices
from the last 3 commits introducing the gateway implementation:
- TelemetryFacade class wrapper around emit() function
- String literals instead of SignalType enum
- Dictionary mapping enum → string instead of enum → enum
- Unnecessary ENTERPRISE_TELEMETRY_GATEWAY_ENABLED feature flag
- Duplicate guard checks scattered across files
- Non-thread-safe TelemetryGateway singleton pattern
- Missing guard in ops_trace_task.py causing RuntimeError spam

**Solution:**
1. Deleted TelemetryFacade - replaced with thin emit() function in core/telemetry/__init__.py
2. Added SignalType enum ('trace' | 'metric_log') to enterprise/telemetry/contracts.py
3. Replaced CASE_TO_TRACE_TASK_NAME dict with CASE_TO_TRACE_TASK: dict[TelemetryCase, TraceTaskName]
4. Deleted is_gateway_enabled() and _emit_legacy() - using existing ENTERPRISE_ENABLED + ENTERPRISE_TELEMETRY_ENABLED instead
5. Extracted _should_drop_ee_only_event() helper to eliminate duplicate checks
6. Moved TelemetryGateway singleton to ext_enterprise_telemetry.py:
   - Init once in init_app() for thread-safety
   - Access via get_gateway() function
7. Re-added guard to ops_trace_task.py to prevent RuntimeError when EE=OFF but CE tracing enabled
8. Updated 11 caller files to import 'emit as telemetry_emit' instead of 'TelemetryFacade'

**Result:**
- 322 net lines deleted (533 removed, 211 added)
- All 91 tests pass
- Thread-safe singleton pattern
- Cleaner API surface: from TelemetryFacade.emit() to telemetry_emit()
- Proper enum usage throughout
- No RuntimeError spam in EE=OFF + CE=ON scenario
2026-02-05 22:41:09 -08:00
Xiyuan Chen
d7c3ae50dc Update api/services/tools/builtin_tools_manage_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-06 13:37:37 +08:00
NFish
b921711e9e fix: hide invite button if current user is not workspace manager (#31742) 2026-02-06 13:37:37 +08:00
yunlu.wen
fb38ad84e1 chore: upgrade deps, see pull #30976 2026-02-06 13:37:33 +08:00
Yunlu Wen
91c854b5be chore: sync enterprise release (#31626)
Co-authored-by: zhsama <torvalds@linux.do>
2026-02-06 13:35:28 +08:00
NFish
d35b231941 fix: enterprise CVE 2026 23864 (#31599) 2026-02-06 13:35:22 +08:00
GareArc
849b4b8c40 fix: add TYPE_CHECKING import for Account type annotation 2026-02-06 13:32:20 +08:00
GareArc
990e8feee8 security: fix IDOR and privilege escalation in set_default_provider
- Add tenant_id verification to prevent IDOR attacks
- Add admin check for enterprise tenant-wide default changes
- Preserve non-enterprise behavior (users can set own defaults)
2026-02-06 13:32:18 +08:00
GareArc
53641019b1 fix: remove user_id filter when clearing default provider (enterprise only)
When setting a new default credential in enterprise mode, the code was
only clearing is_default for credentials matching the current user_id.
This caused issues when:
1. Enterprise credential A (synced with system user_id) was default
2. User sets local credential B as default
3. A still had is_default=true (different user_id)
4. Both A and B were considered defaults

The fix removes user_id from the filter only for enterprise deployments,
since enterprise credentials may have different user_id than local ones.
Non-enterprise behavior is unchanged to avoid breaking existing setups.

Fixes EE-1511
2026-02-06 13:31:50 +08:00
GareArc
d1f10ff301 feat: add redis mq for account deletion cleanup 2026-02-06 13:31:50 +08:00
Xiyuan Chen
c8027e168b feat: implement workspace permission checks for member invitations an… (#31202) 2026-02-06 13:31:46 +08:00
NFish
aae3f76999 feat: ee workspace permission control (#30841) 2026-02-06 13:31:26 +08:00
NFish
2860c72b03 feat: ee workspace permission control (#30841) 2026-02-06 13:13:06 +08:00
GareArc
4d47339ce6 feat: Add parent trace context propagation for workflow-as-tool hierarchy
Enables distributed tracing for nested workflows across all trace providers
(Langfuse, LangSmith, community providers). When a workflow invokes another
workflow via workflow-as-tool, the child workflow now includes parent context
attributes that allow trace systems to reconstruct the full execution tree.

Changes:
- Add parent_trace_context field to WorkflowTool
- Set parent context in tool node when invoking workflow-as-tool
- Extract and pass parent context through app generator

This is a community enhancement (ungated) that improves distributed tracing
for all users. Parent context includes: trace_id, node_execution_id,
workflow_run_id, and app_id.
2026-02-05 20:19:29 -08:00
GareArc
6e47e163b8 fix(telemetry): use atomic Redis SET NX for idempotency and register Celery queue 2026-02-05 20:15:34 -08:00
GareArc
1663a7ab4c feat(telemetry): add gateway diagnostics and verify integration
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-Claude)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-02-05 20:15:13 -08:00
GareArc
51b0c5c89c feat(telemetry): implement gateway routing and enqueue logic
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-Claude)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-02-05 20:15:13 -08:00
GareArc
752b01ae91 refactor(telemetry): migrate event handlers to gateway-only producers
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-Claude)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-02-05 20:15:12 -08:00
GareArc
3d3e8d75d8 feat(telemetry): add gateway envelope contracts and routing table 2026-02-05 20:15:12 -08:00
GareArc
55c0fe503d fix(telemetry): correct enterprise-only trace filtering logic
The logic was inverted - we were blocking all CE traces and only allowing
enterprise traces. The correct logic should be:
- Allow all CE traces (workflow, message, tool, etc.)
- Only block enterprise-only traces when enterprise telemetry is disabled

Before: if event.name not in _ENTERPRISE_ONLY_TRACES: return
After: if event.name in _ENTERPRISE_ONLY_TRACES and not is_enterprise_telemetry_enabled(): return
2026-02-05 20:15:12 -08:00
GareArc
adadf1ec5f refactor(telemetry): migrate to type-safe enum-based event routing with centralized enterprise filtering
Changes:
- Change TelemetryEvent.name from str to TraceTaskName enum for type safety
- Remove hardcoded trace_task_name_map from facade (no mapping needed)
- Add centralized enterprise-only filter in TelemetryFacade.emit()
- Rename is_telemetry_enabled() to is_enterprise_telemetry_enabled()
- Update all 11 call sites to pass TraceTaskName enum values
- Remove redundant enterprise guard from draft_trace.py
- Add unit tests for TelemetryFacade.emit() routing (6 tests)
- Add unit tests for TraceQueueManager telemetry guard (5 tests)
- Fix test fixture scoping issue for full test suite compatibility
- Fix tenant_id handling in agent tool callback handler

Benefits:
- 100% type-safe: basedpyright catches errors at compile time
- No string literals: eliminates entire class of typo bugs
- Single point of control: centralized filtering in facade
- All guards removed except facade
- Zero regressions: 4887 tests passing

Verification:
- make lint: PASS
- make type-check: PASS (0 errors, 0 warnings)
- pytest: 4887 passed, 8 skipped
2026-02-05 20:15:12 -08:00
GareArc
ed222945aa refactor(telemetry): introduce TelemetryFacade to centralize event emission
Migrate from direct TraceQueueManager.add_trace_task calls to TelemetryFacade.emit
with TelemetryEvent abstraction. This reduces CE code invasion by consolidating
telemetry logic in core/telemetry/ with a single guard in ops_trace_manager.py.
2026-02-05 20:15:11 -08:00
GareArc
2d60be311d fix: extract model_provider from model_config in prompt generation trace
The model_provider field in prompt generation traces was being incorrectly
extracted by parsing the model name (e.g., 'deepseek-chat'), which resulted
in an empty string when the model name didn't contain a '/' character.

Now extracts the provider directly from the model_config parameter, with
a fallback to the old parsing logic for backward compatibility.

Changes:
- Update _emit_prompt_generation_trace to accept model_config parameter
- Extract provider from model_config.get('provider') when available
- Update all 6 call sites to pass model_config
- Maintain backward compatibility with fallback logic
2026-02-05 20:15:11 -08:00
GareArc
80ee2e982e fix(telemetry): prevent UUID validation error for tenant-prefixed storage IDs
- get_ops_trace_instance was trying to query App table with storage_id format "tenant-{uuid}"
- This caused psycopg2.errors.InvalidTextRepresentation when app_id is None
- Added early return for tenant-prefixed storage identifiers to skip App lookup
- Enterprise telemetry still works correctly with these storage IDs
2026-02-05 20:15:11 -08:00
GareArc
5bbc938a0d fix(telemetry): add prompt generation trace emission for no_variable=false path
- The no_variable=false code path in generate_rule_config was missing trace emission
- Added timing wrapper and _emit_prompt_generation_trace call to ensure metrics/logs are captured
- Trace now emitted on both success and failure cases for consistency with no_variable=true path
2026-02-05 20:15:10 -08:00
GareArc
052f50805f feat(telemetry): add node_execution_id and app_id support to trace metadata
- Forward kwargs to message_trace to preserve node_execution_id
- Add node_execution_id extraction to all trace methods
- Add app_id parameter to prompt generation API endpoints
- Enable app_id tracing for rule_generate, code_generate, and structured_output operations
2026-02-05 20:15:10 -08:00
GareArc
f5043a8ac8 fix(telemetry): enable metrics and logs for standalone prompt generation
Remove app_id parameter from three endpoints and update trace manager to use
tenant_id as storage identifier when app_id is unavailable. This allows
standalone prompt generation utilities to emit telemetry.

Changes:
- controllers/console/app/generator.py: Remove app_id=None from 3 endpoints
  (RuleGenerateApi, RuleCodeGenerateApi, RuleStructuredOutputGenerateApi)
- core/ops/ops_trace_manager.py: Use tenant_id fallback in send_to_celery
  - Extract tenant_id from task.kwargs when app_id is None
  - Use 'tenant-{tenant_id}' format as storage identifier
  - Skip traces only if neither app_id nor tenant_id available

The trace metadata still contains the actual tenant_id, so enterprise
telemetry correctly emits metrics and logs grouped by tenant.
2026-02-05 20:15:10 -08:00
GareArc
a4bebbb5b5 fix(telemetry): remove app_id parameter from standalone prompt generation endpoints
Remove app_id=None from three prompt generation endpoints that lack proper
app context. These standalone utilities only have tenant_id available, so
we don't pass app_id at all rather than passing incomplete information.

Affected endpoints:
- /rule-generate (RuleGenerateApi)
- /code-generate (RuleCodeGenerateApi)
- /structured-output-generate (RuleStructuredOutputGenerateApi)
2026-02-05 20:15:10 -08:00
GareArc
22c8d8d772 feat(telemetry): add prompt generation telemetry to Enterprise OTEL
- Add PromptGenerationTraceInfo trace entity with operation_type field
- Implement telemetry for rule-generate, code-generate, structured-output, instruction-modify operations
- Emit metrics: tokens (total/input/output), duration histogram, requests counter, errors counter
- Emit structured logs with model info and operation context
- Content redaction controlled by ENTERPRISE_INCLUDE_CONTENT env var
- Fix user_id propagation in TraceTask kwargs
- Fix latency calculation when llm_result is None

No spans exported - metrics and logs only for lightweight observability.
2026-02-05 20:14:49 -08:00
GareArc
e67afa7a5b feat(telemetry): add input/output token metrics and fix trace cleanup
- Add dify.tokens.input and dify.tokens.output OTEL metrics
- Remove token split from trace log attributes (keep metrics only)
- Emit split token metrics for workflows and node executions
- Gracefully handle trace file deletion failures to prevent task crashes

BREAKING: None
MIGRATION: None
2026-02-05 20:12:30 -08:00
GareArc
8ceb1ed96f feat(telemetry): add input/output token split to enterprise OTEL traces
- Add PROMPT_TOKENS and COMPLETION_TOKENS to WorkflowNodeExecutionMetadataKey
- Store prompt/completion tokens in node execution metadata JSON (no schema change)
- Calculate workflow-level token split by summing node executions on-the-fly
- Export gen_ai.usage.input_tokens and output_tokens to enterprise telemetry
- Add semantic convention constants for token attributes
- Maintain backward compatibility (historical data shows null)

BREAKING: None
MIGRATION: None (uses JSON metadata, no schema changes)
2026-02-05 20:12:30 -08:00
GareArc
701f02f853 feat(telemetry): add invoked_by user tracking to enterprise OTEL 2026-02-05 20:12:29 -08:00
GareArc
639fb304ca fix(enterprise): Remove OTEL log export 2026-02-05 20:12:29 -08:00
GareArc
df44e79599 feat(enterprise): Add independent metrics export with dedicated MeterProvider
- Create dedicated MeterProvider instance (independent from ext_otel.py)
- Add create_metric_exporter() to _ExporterFactory with HTTP/gRPC support
- Enterprise metrics now work without requiring standard OTEL to be enabled
- Add MeterProvider shutdown to cleanup lifecycle
- Update module docstring to reflect full independence (Tracer, Logger, Meter)
2026-02-05 20:12:29 -08:00
GareArc
0497fd7469 fix(enterprise): Scope log handler to telemetry logger only
Only export structured telemetry logs, not all application logs. The attach_log_handler method now attaches to the 'dify.telemetry' logger instead of the root logger.
2026-02-05 20:12:29 -08:00
GareArc
bb3fcbfd5c feat(enterprise): Add gRPC protocol support for OTLP telemetry
- Add ENTERPRISE_OTLP_PROTOCOL config (http/grpc, default: http)
- Introduce _ExporterFactory class for protocol-agnostic exporter creation
- Support both HTTP and gRPC OTLP endpoints for traces and logs
- Refactor endpoint path handling into factory methods
2026-02-05 20:12:28 -08:00
GareArc
4d7ab24eb1 feat(enterprise): Add OTEL logs export with span_id correlation
- Add ENTERPRISE_OTEL_LOGS_ENABLED and ENTERPRISE_OTLP_LOGS_ENDPOINT config options
- Implement EnterpriseLoggingHandler for log record translation with trace/span ID parsing
- Add LoggerProvider and BatchLogRecordProcessor for OTLP log export
- Correlate telemetry logs with spans via span_id_source parameter
- Attach log handler during enterprise telemetry initialization
2026-02-05 20:12:28 -08:00
GareArc
3461c3a8ef feat(enterprise): Add OTEL telemetry with slim traces, metrics, and structured logs
- Add EnterpriseOtelTrace handler with span emission for workflows and nodes
- Implement minimal-span strategy: slim spans + detailed companion logs
- Add deterministic span/trace IDs for cross-workflow trace correlation
- Add metric collection at 100% accuracy (counters & histograms)
- Add event handlers for app lifecycle and feedback telemetry
- Add cross-workflow trace linking with parent context propagation
- Add OTEL exporter with configurable sampling and privacy controls
- Wire enterprise telemetry into workflow execution pipeline
- Add telemetry configuration in enterprise configs
2026-02-05 20:12:28 -08:00
wangxiaolei
cd03e0a9ef fix: fix delete_draft_variables_batch cycle forever (#31934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 19:42:50 +08:00
zxhlyh
df2421d187 fix: auto summary env (#31930) 2026-02-04 19:42:26 +08:00
QuantumGhost
0ba321d840 chore: bump version in docker-compose and package manager to 1.12.1 (#31947) 2026-02-04 19:41:50 +08:00
78 changed files with 7234 additions and 1491 deletions

View File

@@ -106,10 +106,10 @@ ignore = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
"PT019", # @patch-injected params look like unused fixtures
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]

View File

@@ -122,7 +122,7 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_enterprise_telemetry,
ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
@@ -131,6 +132,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_fastopenapi,
ext_otel,
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
]

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
# Enterprise telemetry configs
EnterpriseTelemetryConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file

View File

@@ -18,3 +18,44 @@ class EnterpriseFeatureConfig(BaseSettings):
description="Allow customization of the enterprise logo.",
default=False,
)
class EnterpriseTelemetryConfig(BaseSettings):
"""
Configuration for enterprise telemetry.
"""
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
default=False,
)
ENTERPRISE_OTLP_ENDPOINT: str = Field(
description="Enterprise OTEL collector endpoint.",
default="",
)
ENTERPRISE_OTLP_HEADERS: str = Field(
description="Auth headers for OTLP export (key=value,key2=value2).",
default="",
)
ENTERPRISE_OTLP_PROTOCOL: str = Field(
description="OTLP protocol: 'http' or 'grpc' (default: http).",
default="http",
)
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
description="Include input/output content in traces (privacy toggle).",
default=True,
)
ENTERPRISE_SERVICE_NAME: str = Field(
description="Service name for OTEL resource.",
default="dify",
)
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
default=1.0,
)

View File

@@ -1,4 +1,5 @@
from collections.abc import Sequence
from typing import Any
from flask_restx import Resource
from pydantic import BaseModel, Field
@@ -11,12 +12,10 @@ from controllers.console.app.error import (
ProviderQuotaExceededError,
)
from controllers.console.wraps import account_initialization_required, setup_required
from core.app.app_config.entities import ModelConfig
from core.errors.error import ModelCurrentlyNotSupportError, ProviderTokenNotInitError, QuotaExceededError
from core.helper.code_executor.code_node_provider import CodeNodeProvider
from core.helper.code_executor.javascript.javascript_code_provider import JavascriptCodeProvider
from core.helper.code_executor.python3.python3_code_provider import Python3CodeProvider
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.llm_generator import LLMGenerator
from core.model_runtime.errors.invoke import InvokeError
from extensions.ext_database import db
@@ -27,14 +26,32 @@ from services.workflow_service import WorkflowService
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
class RuleCodeGeneratePayload(RuleGeneratePayload):
code_language: str = Field(default="javascript", description="Programming language for code generation")
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
class InstructionGeneratePayload(BaseModel):
flow_id: str = Field(..., description="Workflow/Flow ID")
node_id: str = Field(default="", description="Node ID for workflow context")
current: str = Field(default="", description="Current instruction text")
language: str = Field(default="javascript", description="Programming language (javascript/python)")
instruction: str = Field(..., description="Instruction for generation")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
model_config_data: dict[str, Any] = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
class InstructionTemplatePayload(BaseModel):
@@ -50,7 +67,6 @@ reg(RuleCodeGeneratePayload)
reg(RuleStructuredOutputPayload)
reg(InstructionGeneratePayload)
reg(InstructionTemplatePayload)
reg(ModelConfig)
@console_ns.route("/rule-generate")
@@ -66,10 +82,17 @@ class RuleGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -95,12 +118,16 @@ class RuleCodeGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -127,12 +154,15 @@ class RuleStructuredOutputGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -159,14 +189,14 @@ class InstructionGenerateApi(Resource):
@account_initialization_required
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
app_id = args.app_id or args.flow_id
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
@@ -183,33 +213,33 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
user_id=account.id,
app_id=app_id,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
user_id=account.id,
app_id=app_id,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
user_id=account.id,
app_id=app_id,
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow
if args.node_id == "" and args.current != "":
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
@@ -217,8 +247,10 @@ class InstructionGenerateApi(Resource):
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
user_id=account.id,
app_id=app_id,
)
if args.node_id != "" and args.current != "": # For workflow node
if args.node_id != "" and args.current != "":
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
@@ -228,6 +260,8 @@ class InstructionGenerateApi(Resource):
model_config=args.model_config_data,
ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
user_id=account.id,
app_id=app_id,
)
return {"error": "incompatible parameters"}, 400
except ProviderTokenNotInitError as ex:

View File

@@ -1,6 +1,7 @@
from typing import Any
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
@@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource):
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id,
tracing_provider=args.tracing_provider,
tracing_config=args.tracing_config,
account_id=current_user.id,
)
if not result:
raise TracingConfigIsExist()
@@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource):
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id,
tracing_provider=args.tracing_provider,
tracing_config=args.tracing_config,
account_id=current_user.id,
)
if not result:
raise TracingConfigNotExist()
@@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
result = OpsService.delete_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id
)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
tenant_id=current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"],
account=current_user,
)

View File

@@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner):
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id)
# init dataset tools
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=queue_manager,

View File

@@ -63,6 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -564,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
@@ -579,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -589,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -599,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -770,7 +770,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
logger.debug("Conversation name generation running as daemon thread")
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@@ -826,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
if trace_manager:
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MESSAGE_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"conversation_id": str(message.conversation_id),
"message_id": str(message.id),
},
),
trace_manager=trace_manager,
)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@@ -147,9 +147,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs: Mapping[str, Any] = args["inputs"]
extras = {
extras: dict[str, Any] = {
**extract_external_trace_id_from_args(args),
}
parent_trace_context = args.get("_parent_trace_context")
if parent_trace_context:
extras["parent_trace_context"] = parent_trace_context
workflow_run_id = str(uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs

View File

@@ -52,10 +52,11 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
@@ -409,10 +410,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MESSAGE_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"conversation_id": self._conversation_id,
"message_id": self._message_id,
},
),
trace_manager=trace_manager,
)
message_was_created.send(

View File

@@ -15,8 +15,7 @@ from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
@@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
self._enqueue_node_trace_task(domain_execution)
def _fail_running_node_executions(self, *, error_message: str) -> None:
now = naive_utc_now()
@@ -390,17 +390,131 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.WORKFLOW_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
user_id=self._trace_manager.user_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"workflow_execution": execution,
"conversation_id": conversation_id,
"user_id": self._trace_manager.user_id,
"external_trace_id": external_trace_id,
"parent_trace_context": parent_trace_context,
},
),
trace_manager=self._trace_manager,
)
def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None:
if not self._trace_manager:
return
execution = self._get_workflow_execution()
meta = domain_execution.metadata or {}
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
node_data: dict[str, Any] = {
"workflow_id": domain_execution.workflow_id,
"workflow_execution_id": execution.id_,
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"node_execution_id": domain_execution.id,
"node_id": domain_execution.node_id,
"node_type": str(domain_execution.node_type.value),
"title": domain_execution.title,
"status": str(domain_execution.status.value),
"error": domain_execution.error,
"elapsed_time": domain_execution.elapsed_time,
"index": domain_execution.index,
"predecessor_node_id": domain_execution.predecessor_node_id,
"created_at": domain_execution.created_at,
"finished_at": domain_execution.finished_at,
"total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
"prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS),
"completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS),
"total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
"currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
"tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
else None,
"iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
"iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
"loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
"loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
"parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
"node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None,
"node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None,
"process_data": dict(domain_execution.process_data) if domain_execution.process_data else None,
}
node_data["invoke_from"] = self._application_generate_entity.invoke_from.value
node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value)
if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs:
results = domain_execution.outputs.get("result") or []
dataset_ids: list[str] = []
dataset_names: list[str] = []
for doc in results:
if not isinstance(doc, dict):
continue
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
dname = doc_meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
if dataset_ids:
node_data["dataset_ids"] = dataset_ids
if dataset_names:
node_data["dataset_names"] = dataset_names
tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
if isinstance(tool_info, dict):
plugin_id = tool_info.get("plugin_unique_identifier")
if plugin_id:
node_data["plugin_name"] = plugin_id
credential_id = tool_info.get("credential_id")
if credential_id:
node_data["credential_id"] = credential_id
node_data["credential_provider_type"] = tool_info.get("provider_type")
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
if conversation_id:
node_data["conversation_id"] = conversation_id
if parent_trace_context:
node_data["parent_trace_context"] = parent_trace_context
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id=node_data.get("tenant_id"),
user_id=node_data.get("user_id"),
app_id=node_data.get("app_id"),
),
payload={"node_execution_data": node_data},
),
trace_manager=self._trace_manager,
)
self._trace_manager.add_trace_task(trace_task)
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state

View File

@@ -4,8 +4,9 @@ from typing import Any, TextIO, Union
from pydantic import BaseModel
from configs import dify_config
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.tools.entities.tool_entities import ToolInvokeMessage
_TEXT_COLOR_MAPPING = {
@@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel):
color: str | None = ""
current_loop: int = 1
tenant_id: str | None = None
def __init__(self, color: str | None = None):
def __init__(self, color: str | None = None, tenant_id: str | None = None):
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or "green"
self.current_loop = 1
self.tenant_id = tenant_id
def on_tool_start(
self,
@@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel):
print_text("\n")
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.TOOL_TRACE,
message_id=message_id,
tool_name=tool_name,
tool_inputs=tool_inputs,
tool_outputs=tool_outputs,
timer=timer,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.TOOL_TRACE,
context=TelemetryContext(
tenant_id=self.tenant_id,
app_id=trace_manager.app_id,
user_id=trace_manager.user_id,
),
payload={
"message_id": message_id,
"tool_name": tool_name,
"tool_inputs": tool_inputs,
"tool_outputs": tool_outputs,
"timer": timer,
},
),
trace_manager=trace_manager,
)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):

View File

@@ -6,8 +6,6 @@ from typing import Protocol, cast
import json_repair
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -27,10 +25,10 @@ from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage
@@ -73,8 +71,8 @@ class LLMGenerator:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = response.message.get_text_content()
if answer == "":
answer = cast(str, response.message.content)
if answer is None:
return ""
try:
result_dict = json.loads(answer)
@@ -96,15 +94,17 @@ class LLMGenerator:
name = name[:75] + "..."
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.GENERATE_NAME_TRACE,
conversation_id=conversation_id,
generate_conversation_name=name,
inputs=prompt,
timer=timer,
tenant_id=tenant_id,
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.GENERATE_NAME_TRACE,
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
payload={
"conversation_id": conversation_id,
"generate_conversation_name": name,
"inputs": prompt,
"timer": timer,
"tenant_id": tenant_id,
},
)
)
@@ -153,19 +153,27 @@ class LLMGenerator:
return questions
@classmethod
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
def generate_rule_config(
cls,
tenant_id: str,
instruction: str,
model_config: dict,
no_variable: bool,
user_id: str | None = None,
app_id: str | None = None,
):
output_parser = RuleConfigGeneratorOutputParser()
error = ""
error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = args.model_config_data.completion_params
if args.no_variable:
model_parameters = model_config.get("completion_params", {})
if no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -177,26 +185,45 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
llm_result = None
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = response.message.get_text_content()
rule_config["prompt"] = cast(str, llm_result.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
except InvokeError as e:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
error = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
prompt_value = rule_config.get("prompt", "")
generated_output = str(prompt_value) if prompt_value else ""
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="rule_generate",
instruction=instruction,
generated_output=generated_output,
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error or None,
)
return rule_config
# get rule config prompt, parameter and statement
@@ -211,7 +238,7 @@ class LLMGenerator:
# format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -222,84 +249,125 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
try:
llm_result = None
with measure_time() as timer:
try:
# the first step to generate the task prompt
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
try:
# the first step to generate the task prompt
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
llm_result = prompt_content
except InvokeError as e:
error = str(e)
error_step = "generate prefix prompt"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="rule_generate",
instruction=instruction,
generated_output="",
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
return rule_config
rule_config["prompt"] = cast(str, prompt_content.message.content)
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
},
remove_template_variables=False,
)
except InvokeError as e:
error = str(e)
error_step = "generate prefix prompt"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
return rule_config
rule_config["prompt"] = prompt_content.message.get_text_content()
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"INPUT_TEXT": prompt_content.message.content,
},
remove_template_variables=False,
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
try:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(
r'"\s*([^"]+)\s*"', cast(str, parameter_content.message.content)
)
except InvokeError as e:
error = str(e)
error_step = "generate variables"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
try:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = cast(str, statement_content.message.content)
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.get("name"))
rule_config["error"] = str(e)
error = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
generated_output = rule_config.get("prompt", "")
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="rule_generate",
instruction=instruction,
generated_output=str(generated_output) if generated_output else "",
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error or None,
)
return rule_config
@classmethod
def generate_code(
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
instruction: str,
model_config: dict,
code_language: str = "javascript",
user_id: str | None = None,
app_id: str | None = None,
):
if args.code_language == "python":
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": args.instruction,
"CODE_LANGUAGE": args.code_language,
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
},
remove_template_variables=False,
)
@@ -308,28 +376,49 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_parameters = model_config.get("completion_params", {})
llm_result = None
error = None
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = cast(str, llm_result.message.content)
result = {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
error = str(e)
result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", model_config.get("name"), code_language
)
error = str(e)
result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="code_generate",
instruction=instruction,
generated_output=result.get("code", ""),
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": args.code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
)
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
return result
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@@ -355,49 +444,76 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
answer = response.message.get_text_content()
answer = cast(str, response.message.content)
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
def generate_structured_output(
cls, tenant_id: str, instruction: str, model_config: dict, user_id: str | None = None, app_id: str | None = None
):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=args.instruction),
UserPromptMessage(content=instruction),
]
model_parameters = args.model_config_data.completion_params
model_parameters = model_config.get("model_parameters", {})
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
llm_result = None
error = None
result = {"output": "", "error": ""}
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = llm_result.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
result = {"output": generated_json_schema, "error": ""}
except InvokeError as e:
error = str(e)
result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", model_config.get("name"))
error = str(e)
result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="structured_output",
instruction=instruction,
generated_output=result.get("output", ""),
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
return result
@staticmethod
def instruction_modify_legacy(
@@ -405,14 +521,16 @@ class LLMGenerator:
flow_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
model_config: dict,
ideal_output: str | None,
user_id: str | None = None,
app_id: str | None = None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
if not last_run:
return LLMGenerator.__instruction_modify_common(
result = LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
@@ -421,22 +539,28 @@ class LLMGenerator:
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
else:
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
result = LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
return result
@staticmethod
def instruction_modify_workflow(
@@ -445,9 +569,11 @@ class LLMGenerator:
node_id: str,
current: str,
instruction: str,
model_config: ModelConfig,
model_config: dict,
ideal_output: str | None,
workflow_service: WorkflowServiceInterface,
user_id: str | None = None,
app_id: str | None = None,
):
session = db.session()
@@ -478,6 +604,8 @@ class LLMGenerator:
instruction=instruction,
node_type=node_type,
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
@@ -511,18 +639,22 @@ class LLMGenerator:
instruction=instruction,
node_type=last_run.node_type,
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
@staticmethod
def __instruction_modify_common(
tenant_id: str,
model_config: ModelConfig,
model_config: dict,
last_run: dict | None,
current: str | None,
error_message: str | None,
instruction: str,
node_type: str,
ideal_output: str | None,
user_id: str | None = None,
app_id: str | None = None,
):
LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"
@@ -537,8 +669,8 @@ class LLMGenerator:
model_instance = ModelManager().get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=model_config.provider,
model=model_config.name,
provider=model_config.get("provider", ""),
model=model_config.get("name", ""),
)
match node_type:
case "llm" | "agent":
@@ -562,24 +694,122 @@ class LLMGenerator:
]
model_parameters = {"temperature": 0.4}
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
llm_result = None
error = None
result = {}
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = llm_result.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
result = data
except InvokeError as e:
error = str(e)
result = {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s", json.dumps(model_config.get("name")), exc_info=True
)
error = str(e)
result = {"error": f"An unexpected error occurred: {str(e)}"}
if user_id:
generated_output = ""
if isinstance(result, dict):
for key in ["prompt", "code", "output", "modified"]:
if result.get(key):
generated_output = str(result[key])
break
LLMGenerator._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type="instruction_modify",
instruction=instruction,
generated_output=generated_output,
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
return {"error": f"An unexpected error occurred: {str(e)}"}
return result
@classmethod
def _emit_prompt_generation_trace(
cls,
tenant_id: str,
user_id: str,
app_id: str | None,
operation_type: str,
instruction: str,
generated_output: str,
llm_result: LLMResult | None,
model_config: dict | None = None,
timer=None,
error: str | None = None,
):
if llm_result:
prompt_tokens = llm_result.usage.prompt_tokens
completion_tokens = llm_result.usage.completion_tokens
total_tokens = llm_result.usage.total_tokens
model_name = llm_result.model
# Extract provider from model_config if available, otherwise fall back to parsing model name
if model_config and model_config.get("provider"):
model_provider = model_config.get("provider", "")
else:
model_provider = model_name.split("/")[0] if "/" in model_name else ""
latency = llm_result.usage.latency
total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None
currency = llm_result.usage.currency
else:
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
model_provider = model_config.get("provider", "") if model_config else ""
model_name = model_config.get("name", "") if model_config else ""
latency = 0.0
if timer:
start_time = timer.get("start")
end_time = timer.get("end")
if start_time and end_time:
latency = (end_time - start_time).total_seconds()
total_price = None
currency = None
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.PROMPT_GENERATION_TRACE,
context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id),
payload={
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id,
"operation_type": operation_type,
"instruction": instruction,
"generated_output": generated_output,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"model_provider": model_provider,
"model_name": model_name,
"latency": latency,
"total_price": total_price,
"currency": currency,
"timer": timer,
"error": error,
},
)
)

View File

@@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter):
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event)
existing_trace_id = getattr(record, "trace_id", "")
if not existing_trace_id:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# Keep existing trace_id; only fill span_id if missing
if not getattr(record, "span_id", ""):
record.span_id = ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
@@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
if not getattr(record, "tenant_id", ""):
record.tenant_id = identity.get("tenant_id", "")
if not getattr(record, "user_id", ""):
record.user_id = identity.get("user_id", "")
if not getattr(record, "user_type", ""):
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:

View File

@@ -5,9 +5,10 @@ from typing import Any
from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.utils import measure_time
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
logger = logging.getLogger(__name__)
@@ -49,14 +50,18 @@ class InputModeration:
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MODERATION_TRACE,
message_id=message_id,
moderation_result=moderation_result,
inputs=inputs,
timer=timer,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MODERATION_TRACE,
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
payload={
"message_id": message_id,
"moderation_result": moderation_result,
"inputs": inputs,
"timer": timer,
},
),
trace_manager=trace_manager,
)
if not moderation_result.flagged:

View File

@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class BaseTraceInfo(BaseModel):
message_id: str | None = None
message_data: Any | None = None
inputs: Union[str, dict[str, Any], list] | None = None
outputs: Union[str, dict[str, Any], list] | None = None
inputs: Union[str, dict[str, Any], list[Any]] | None = None
outputs: Union[str, dict[str, Any], list[Any]] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
metadata: dict[str, Any]
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
if v is None:
return None
if isinstance(v, str | dict | list):
@@ -48,10 +48,14 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_version: str
error: str | None = None
total_tokens: int
prompt_tokens: int | None = None
completion_tokens: int | None = None
file_list: list[str]
query: str
metadata: dict[str, Any]
invoked_by: str | None = None
class MessageTraceInfo(BaseTraceInfo):
conversation_model: str
@@ -59,7 +63,7 @@ class MessageTraceInfo(BaseTraceInfo):
answer_tokens: int
total_tokens: int
error: str | None = None
file_list: Union[str, dict[str, Any], list] | None = None
file_list: Union[str, dict[str, Any], list[Any]] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
@@ -106,7 +110,7 @@ class ToolTraceInfo(BaseTraceInfo):
tool_config: dict[str, Any]
time_cost: Union[int, float]
tool_parameters: dict[str, Any]
file_url: Union[str, None, list] = None
file_url: Union[str, None, list[str]] = None
class GenerateNameTraceInfo(BaseTraceInfo):
@@ -114,6 +118,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class PromptGenerationTraceInfo(BaseTraceInfo):
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
tenant_id: str
user_id: str
app_id: str | None = None
operation_type: str
instruction: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
model_provider: str
model_name: str
latency: float
total_price: float | None = None
currency: str | None = None
error: str | None = None
model_config = ConfigDict(protected_namespaces=())
class WorkflowNodeTraceInfo(BaseTraceInfo):
workflow_id: str
workflow_run_id: str
tenant_id: str
node_execution_id: str
node_id: str
node_type: str
title: str
status: str
error: str | None = None
elapsed_time: float
index: int
predecessor_node_id: str | None = None
total_tokens: int = 0
total_price: float = 0.0
currency: str | None = None
model_provider: str | None = None
model_name: str | None = None
prompt_tokens: int | None = None
completion_tokens: int | None = None
tool_name: str | None = None
iteration_id: str | None = None
iteration_index: int | None = None
loop_id: str | None = None
loop_index: int | None = None
parallel_id: str | None = None
node_inputs: Mapping[str, Any] | None = None
node_outputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
invoked_by: str | None = None
model_config = ConfigDict(protected_namespaces=())
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
pass
class TaskData(BaseModel):
app_id: str
trace_info_type: str
@@ -128,16 +205,22 @@ trace_info_info_map = {
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
}
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow"
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
SUGGESTED_QUESTION_TRACE = "suggested_question"
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
PROMPT_GENERATION_TRACE = "prompt_generation"
DATASOURCE_TRACE = "datasource"
NODE_EXECUTION_TRACE = "node_execution"

View File

@@ -3,6 +3,7 @@ import os
from datetime import datetime, timedelta
from langfuse import Langfuse
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
# Check for parent_trace_context to detect nested workflow
parent_trace_context = trace_info.metadata.get("parent_trace_context")
if parent_trace_context:
# Nested workflow: create span under outer trace
outer_trace_id = parent_trace_context.get("trace_id")
parent_node_execution_id = parent_trace_context.get("parent_node_execution_id")
parent_conversation_id = parent_trace_context.get("parent_conversation_id")
parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id")
# Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id
if parent_conversation_id:
session_factory = sessionmaker(bind=db.engine)
with session_factory() as session:
message_data_stmt = select(Message.id).where(
Message.conversation_id == parent_conversation_id,
Message.workflow_run_id == parent_workflow_run_id,
)
resolved_message_id = session.scalar(message_data_stmt)
if resolved_message_id:
outer_trace_id = resolved_message_id
else:
outer_trace_id = parent_workflow_run_id
else:
outer_trace_id = parent_workflow_run_id
# Create inner workflow span under outer trace
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=outer_trace_id,
parent_observation_id=parent_node_execution_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
status_message=trace_info.error or "",
)
self.add_span(langfuse_span_data=workflow_span_data)
# Use outer_trace_id for all node spans/generations
trace_id = outer_trace_id
elif trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
@@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance):
}
)
# Determine parent_observation_id for nested workflows
node_parent_observation_id = None
if parent_trace_context or trace_info.message_id:
node_parent_observation_id = trace_info.workflow_run_id
# add generation span
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
@@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
parent_observation_id=node_parent_observation_id,
usage=generation_usage,
)
@@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
parent_observation_id=node_parent_observation_id,
)
self.add_span(langfuse_span_data=span_data)

View File

@@ -6,6 +6,7 @@ from typing import cast
from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
# Check for parent_trace_context for cross-workflow linking
parent_trace_context = trace_info.metadata.get("parent_trace_context")
if parent_trace_context:
# Inner workflow: resolve outer trace_id and link to parent node
outer_trace_id = parent_trace_context.get("parent_workflow_run_id")
# Try to resolve message_id from conversation_id if available
if parent_trace_context.get("parent_conversation_id"):
try:
session_factory = sessionmaker(bind=db.engine)
with session_factory() as session:
message_data_stmt = select(Message.id).where(
Message.conversation_id == parent_trace_context["parent_conversation_id"],
Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"],
)
resolved_message_id = session.scalar(message_data_stmt)
if resolved_message_id:
outer_trace_id = resolved_message_id
except Exception as e:
logger.debug("Failed to resolve message_id from conversation_id: %s", str(e))
trace_id = outer_trace_id
parent_run_id = parent_trace_context.get("parent_node_execution_id")
else:
# Outer workflow: existing behavior
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
parent_run_id = trace_info.message_id or None
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
@@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
# Only create message_run for outer workflows (no parent_trace_context)
if trace_info.message_id and not parent_trace_context:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE,
@@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance):
},
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
parent_run_id=parent_run_id,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
dotted_order=None if parent_trace_context else workflow_dotted_order,
serialized=None,
events=[],
session_id=None,

View File

@@ -21,19 +21,25 @@ from core.ops.entities.config_entity import (
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.dataset import Dataset
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -43,6 +49,44 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
workspace_name = ""
if not app_id and not tenant_id:
return app_name, workspace_name
with Session(db.engine) as session:
if app_id:
name = session.scalar(select(App.name).where(App.id == app_id))
if name:
app_name = name
if tenant_id:
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
if name:
workspace_name = name
return app_name, workspace_name
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
"builtin": BuiltinToolProvider,
"plugin": BuiltinToolProvider,
"api": ApiToolProvider,
"workflow": WorkflowToolProvider,
"mcp": MCPToolProvider,
}
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
if not credential_id:
return ""
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
if not model_cls:
return ""
with Session(db.engine) as session:
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))
return str(name) if name else ""
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
@@ -317,6 +361,10 @@ class OpsTraceManager:
if app_id is None:
return None
# Handle storage_id format (tenant-{uuid}) - not a real app_id
if isinstance(app_id, str) and app_id.startswith("tenant-"):
return None
app: App | None = db.session.query(App).where(App.id == app_id).first()
if app is None:
@@ -479,6 +527,56 @@ class TraceTask:
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
"""Extract user ID from metadata, prioritizing end_user over account.
Returns the actual user ID (end_user or account) who invoked the workflow,
regardless of invoke_from context.
"""
# Priority 1: End user (external users via API/WebApp)
if user_id := metadata.get("from_end_user_id"):
return f"end_user:{user_id}"
# Priority 2: Account user (internal users via console/debugger)
if user_id := metadata.get("from_account_id"):
return f"account:{user_id}"
# Priority 3: User (internal users via console/debugger)
if user_id := metadata.get("user_id"):
return f"user:{user_id}"
return "anonymous"
@classmethod
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
with Session(db.engine) as session:
node_executions = session.scalars(
select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
).all()
total_prompt = 0
total_completion = 0
for node_exec in node_executions:
metadata = node_exec.execution_metadata_dict
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
if prompt is not None:
total_prompt += prompt
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
if completion is not None:
total_completion += completion
return (total_prompt, total_completion)
def __init__(
self,
trace_type: Any,
@@ -499,6 +597,8 @@ class TraceTask:
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
if user_id is not None and "user_id" not in self.kwargs:
self.kwargs["user_id"] = user_id
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:
self.trace_id = external_trace_id
@@ -512,7 +612,7 @@ class TraceTask:
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
@@ -528,6 +628,9 @@ class TraceTask:
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
}
return preprocess_map.get(self.trace_type, lambda: None)()
@@ -563,6 +666,10 @@ class TraceTask:
total_tokens = workflow_run.total_tokens
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
workflow_run_id=workflow_run_id, tenant_id=tenant_id
)
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
@@ -583,7 +690,9 @@ class TraceTask:
)
message_id = session.scalar(message_data_stmt)
metadata = {
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
metadata: dict[str, Any] = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
@@ -596,8 +705,14 @@ class TraceTask:
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
@@ -612,6 +727,8 @@ class TraceTask:
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
file_list=file_list,
query=query,
metadata=metadata,
@@ -619,10 +736,11 @@ class TraceTask:
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
invoked_by=self._get_user_id_from_metadata(metadata),
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
def message_trace(self, message_id: str | None, **kwargs):
if not message_id:
return {}
message_data = get_message_data(message_id)
@@ -645,6 +763,14 @@ class TraceTask:
streaming_metrics = self._extract_streaming_metrics(message_data)
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
@@ -656,7 +782,14 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"message_id": message_id,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
message_tokens = message_data.message_tokens
@@ -698,6 +831,8 @@ class TraceTask:
"preset_response": moderation_result.preset_response,
"query": moderation_result.query,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -739,6 +874,8 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -778,6 +915,36 @@ class TraceTask:
if not message_data:
return {}
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
doc_list = [doc.model_dump() for doc in documents] if documents else []
dataset_ids: set[str] = set()
for doc in doc_list:
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
if did:
dataset_ids.add(did)
embedding_models: dict[str, dict[str, str]] = {}
if dataset_ids:
with Session(db.engine) as session:
rows = session.execute(
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
Dataset.id.in_(list(dataset_ids))
)
).all()
for row in rows:
embedding_models[str(row[0])] = {
"embedding_model": row[1] or "",
"embedding_model_provider": row[2] or "",
}
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
@@ -788,13 +955,21 @@ class TraceTask:
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
"embedding_models": embedding_models,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
documents=doc_list,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -837,6 +1012,10 @@ class TraceTask:
"error": error,
"tool_parameters": tool_parameters,
}
if message_data.workflow_run_id:
metadata["workflow_run_id"] = message_data.workflow_run_id
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
@@ -891,6 +1070,8 @@ class TraceTask:
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
@@ -905,6 +1086,158 @@ class TraceTask:
return generate_name_trace_info
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
tenant_id = kwargs.get("tenant_id", "")
user_id = kwargs.get("user_id", "")
app_id = kwargs.get("app_id")
operation_type = kwargs.get("operation_type", "")
instruction = kwargs.get("instruction", "")
generated_output = kwargs.get("generated_output", "")
prompt_tokens = kwargs.get("prompt_tokens", 0)
completion_tokens = kwargs.get("completion_tokens", 0)
total_tokens = kwargs.get("total_tokens", 0)
model_provider = kwargs.get("model_provider", "")
model_name = kwargs.get("model_name", "")
latency = kwargs.get("latency", 0.0)
timer = kwargs.get("timer")
start_time = timer.get("start") if timer else None
end_time = timer.get("end") if timer else None
total_price = kwargs.get("total_price")
currency = kwargs.get("currency")
error = kwargs.get("error")
app_name = None
workspace_name = None
if app_id:
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
metadata = {
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id or "",
"app_name": app_name,
"workspace_name": workspace_name,
"operation_type": operation_type,
"model_provider": model_provider,
"model_name": model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
return PromptGenerationTraceInfo(
trace_id=self.trace_id,
inputs=instruction,
outputs=generated_output,
start_time=start_time,
end_time=end_time,
metadata=metadata,
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=operation_type,
instruction=instruction,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model_provider=model_provider,
model_name=model_name,
latency=latency,
total_price=total_price,
currency=currency,
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
if not node_data:
return {}
app_name, workspace_name = _lookup_app_and_workspace_names(node_data.get("app_id"), node_data.get("tenant_id"))
credential_name = _lookup_credential_name(
node_data.get("credential_id"), node_data.get("credential_provider_type")
)
metadata: dict[str, Any] = {
"tenant_id": node_data.get("tenant_id"),
"app_id": node_data.get("app_id"),
"app_name": app_name,
"workspace_name": workspace_name,
"user_id": node_data.get("user_id"),
"dataset_ids": node_data.get("dataset_ids"),
"dataset_names": node_data.get("dataset_names"),
"plugin_name": node_data.get("plugin_name"),
"credential_name": credential_name,
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_execution_id,
)
)
if msg_id:
message_id = str(msg_id)
metadata["message_id"] = message_id
return WorkflowNodeTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
start_time=node_data.get("created_at"),
end_time=node_data.get("finished_at"),
metadata=metadata,
workflow_id=node_data.get("workflow_id", ""),
workflow_run_id=node_data.get("workflow_execution_id", ""),
tenant_id=node_data.get("tenant_id", ""),
node_execution_id=node_data.get("node_execution_id", ""),
node_id=node_data.get("node_id", ""),
node_type=node_data.get("node_type", ""),
title=node_data.get("title", ""),
status=node_data.get("status", ""),
error=node_data.get("error"),
elapsed_time=node_data.get("elapsed_time", 0.0),
index=node_data.get("index", 0),
predecessor_node_id=node_data.get("predecessor_node_id"),
total_tokens=node_data.get("total_tokens", 0),
total_price=node_data.get("total_price", 0.0),
currency=node_data.get("currency"),
model_provider=node_data.get("model_provider"),
model_name=node_data.get("model_name"),
prompt_tokens=node_data.get("prompt_tokens"),
completion_tokens=node_data.get("completion_tokens"),
tool_name=node_data.get("tool_name"),
iteration_id=node_data.get("iteration_id"),
iteration_index=node_data.get("iteration_index"),
loop_id=node_data.get("loop_id"),
loop_index=node_data.get("loop_index"),
parallel_id=node_data.get("parallel_id"),
node_inputs=node_data.get("node_inputs"),
node_outputs=node_data.get("node_outputs"),
process_data=node_data.get("process_data"),
invoked_by=self._get_user_id_from_metadata(metadata),
)
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
node_trace = self.node_execution_trace(**kwargs)
if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo):
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
@@ -938,13 +1271,17 @@ class TraceQueueManager:
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
from core.telemetry import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
@@ -980,20 +1317,27 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
storage_id = task.app_id
if storage_id is None:
tenant_id = task.kwargs.get("tenant_id")
if tenant_id:
storage_id = f"tenant-{tenant_id}"
else:
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
app_id=storage_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
"app_id": storage_id,
}
process_trace_tasks.delay(file_info) # type: ignore

View File

@@ -27,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -56,6 +55,8 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -728,10 +729,21 @@ class DatasetRetrieval:
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
app_config = self.application_generate_entity.app_config if self.application_generate_entity else None
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
context=TelemetryContext(
tenant_id=app_config.tenant_id if app_config else None,
app_id=app_config.app_id if app_config else None,
),
payload={
"message_id": message_id,
"documents": documents,
"timer": timer,
},
),
trace_manager=trace_manager,
)
def _on_query(

View File

@@ -0,0 +1,60 @@
"""Community telemetry helpers.
Provides ``emit()`` which enqueues trace events into the CE trace pipeline
(``TraceQueueManager`` → ``ops_trace`` Celery queue → Langfuse / LangSmith / etc.).
Enterprise-only traces (node execution, draft node execution, prompt generation)
are silently dropped when enterprise telemetry is disabled.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
_ENTERPRISE_ONLY_TRACES: frozenset[TraceTaskName] = frozenset(
{
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TraceTaskName.NODE_EXECUTION_TRACE,
TraceTaskName.PROMPT_GENERATION_TRACE,
}
)
def _is_enterprise_telemetry_enabled() -> bool:
try:
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
return is_enterprise_telemetry_enabled()
except Exception:
return False
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
from core.ops.ops_trace_manager import TraceTask
if event.name in _ENTERPRISE_ONLY_TRACES and not _is_enterprise_telemetry_enabled():
return
queue_manager = trace_manager or LocalTraceQueueManager(
app_id=event.context.app_id,
user_id=event.context.user_id,
)
queue_manager.add_trace_task(TraceTask(event.name, **event.payload))
is_enterprise_telemetry_enabled = _is_enterprise_telemetry_enabled
__all__ = [
"TelemetryContext",
"TelemetryEvent",
"TraceTaskName",
"emit",
"is_enterprise_telemetry_enabled",
]

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.ops.entities.trace_entity import TraceTaskName
@dataclass(frozen=True)
class TelemetryContext:
tenant_id: str | None = None
user_id: str | None = None
app_id: str | None = None
@dataclass(frozen=True)
class TelemetryEvent:
name: TraceTaskName
context: TelemetryContext
payload: dict[str, Any]

View File

@@ -50,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
self.parent_trace_context: dict[str, str] | None = None
super().__init__(entity=entity, runtime=runtime)
@@ -90,11 +91,15 @@ class WorkflowTool(Tool):
self._latest_usage = LLMUsage.empty_usage()
args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
if self.parent_trace_context:
args["_parent_trace_context"] = self.parent_trace_context
result = generator.generate(
app_model=app,
workflow=workflow,
user=user,
args={"inputs": tool_parameters, "files": files},
args=args,
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,

View File

@@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
TOTAL_TOKENS = "total_tokens"
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"

View File

@@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]):
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens,
WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},

View File

@@ -61,6 +61,7 @@ class ToolNode(Node[ToolNodeData]):
"provider_type": self.node_data.provider_type.value,
"provider_id": self.node_data.provider_id,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
"credential_id": self.node_data.credential_id,
}
# get tool runtime
@@ -105,6 +106,20 @@ class ToolNode(Node[ToolNodeData]):
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
from core.tools.workflow_as_tool.tool import WorkflowTool
if isinstance(tool_runtime, WorkflowTool):
workflow_run_id_var = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID]
)
tool_runtime.parent_trace_context = {
"trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
"parent_node_execution_id": self.execution_id,
"parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
"parent_app_id": self.app_id,
"parent_conversation_id": conversation_id.text if conversation_id else None,
}
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
@@ -431,6 +446,8 @@ class ToolNode(Node[ToolNodeData]):
}
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens
metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency

View File

View File

View File

@@ -0,0 +1,83 @@
"""Telemetry gateway contracts and data structures.
This module defines the envelope format for telemetry events and the routing
configuration that determines how each event type is processed.
"""
from __future__ import annotations
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, field_validator
class TelemetryCase(StrEnum):
"""Enumeration of all known telemetry event cases."""
WORKFLOW_RUN = "workflow_run"
NODE_EXECUTION = "node_execution"
DRAFT_NODE_EXECUTION = "draft_node_execution"
MESSAGE_RUN = "message_run"
TOOL_EXECUTION = "tool_execution"
MODERATION_CHECK = "moderation_check"
SUGGESTED_QUESTION = "suggested_question"
DATASET_RETRIEVAL = "dataset_retrieval"
GENERATE_NAME = "generate_name"
PROMPT_GENERATION = "prompt_generation"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
FEEDBACK_CREATED = "feedback_created"
class SignalType(StrEnum):
"""Signal routing type for telemetry cases."""
TRACE = "trace"
METRIC_LOG = "metric_log"
class CaseRoute(BaseModel):
"""Routing configuration for a telemetry case.
Attributes:
signal_type: The type of signal (trace or metric_log).
ce_eligible: Whether this case is eligible for community edition tracing.
"""
signal_type: SignalType
ce_eligible: bool
class TelemetryEnvelope(BaseModel):
"""Envelope for telemetry events.
Attributes:
case: The telemetry case type.
tenant_id: The tenant identifier.
event_id: Unique event identifier for deduplication.
payload: The main event payload.
payload_fallback: Fallback payload (max 64KB).
metadata: Optional metadata dictionary.
"""
case: TelemetryCase
tenant_id: str
event_id: str
payload: dict[str, Any]
payload_fallback: bytes | None = None
metadata: dict[str, Any] | None = None
@field_validator("payload_fallback")
@classmethod
def validate_payload_fallback_size(cls, v: bytes | None) -> bytes | None:
"""Validate that payload_fallback does not exceed 64KB."""
if v is not None and len(v) > 65536: # 64 * 1024
raise ValueError("payload_fallback must not exceed 64KB")
return v
class Config:
"""Pydantic configuration."""
use_enum_values = False

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
def enqueue_draft_node_execution_trace(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
user_id: str,
) -> None:
node_data = _build_node_execution_data(
execution=execution,
outputs=outputs,
workflow_execution_id=workflow_execution_id,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id=execution.tenant_id,
user_id=user_id,
app_id=execution.app_id,
),
payload={"node_execution_data": node_data},
)
)
def _build_node_execution_data(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
) -> dict[str, Any]:
metadata = execution.execution_metadata_dict
node_outputs = outputs if outputs is not None else execution.outputs_dict
execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
return {
"workflow_id": execution.workflow_id,
"workflow_execution_id": execution_id,
"tenant_id": execution.tenant_id,
"app_id": execution.app_id,
"node_execution_id": execution.id,
"node_id": execution.node_id,
"node_type": execution.node_type,
"title": execution.title,
"status": execution.status,
"error": execution.error,
"elapsed_time": execution.elapsed_time,
"index": execution.index,
"predecessor_node_id": execution.predecessor_node_id,
"created_at": execution.created_at,
"finished_at": execution.finished_at,
"total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
"total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
"currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
"tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
else None,
"iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
"iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
"loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
"loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
"parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
"node_inputs": execution.inputs_dict,
"node_outputs": node_outputs,
"process_data": execution.process_data_dict,
}

View File

@@ -0,0 +1,844 @@
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
Only requires a matching ``trace(trace_info)`` method signature.
Signal strategy:
- **Traces (spans)**: workflow run, node execution, draft node execution only.
- **Metrics + structured logs**: all other event types.
"""
from __future__ import annotations
import json
import logging
from typing import Any, cast
from opentelemetry.util.types import AttributeValue
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from enterprise.telemetry.entities import (
EnterpriseTelemetryCounter,
EnterpriseTelemetryHistogram,
EnterpriseTelemetrySpan,
)
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
logger = logging.getLogger(__name__)
class EnterpriseOtelTrace:
"""Duck-typed enterprise trace handler.
``*_trace`` methods emit spans (workflow/node only) or structured logs
(all other events), plus metrics at 100 % accuracy.
"""
def __init__(self) -> None:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if exporter is None:
raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
self._exporter = exporter
def trace(self, trace_info: BaseTraceInfo) -> None:
if isinstance(trace_info, WorkflowTraceInfo):
self._workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self._message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self._tool_trace(trace_info)
elif isinstance(trace_info, DraftNodeExecutionTrace):
self._draft_node_execution_trace(trace_info)
elif isinstance(trace_info, WorkflowNodeTraceInfo):
self._node_execution_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self._moderation_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self._suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self._dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self._generate_name_trace(trace_info)
elif isinstance(trace_info, PromptGenerationTraceInfo):
self._prompt_generation_trace(trace_info)
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
metadata = self._metadata(trace_info)
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
return {
"dify.trace_id": trace_info.trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"dify.message.id": trace_info.message_id,
}
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
return trace_info.metadata
def _context_ids(
self,
trace_info: BaseTraceInfo,
metadata: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
return tenant_id, app_id, user_id
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
return dict(values)
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
if isinstance(value, str):
return value
if isinstance(value, dict):
return cast(dict[str, Any], value)
if isinstance(value, list):
items: list[object] = []
for item in cast(list[object], value):
items.append(item)
return items
return None
def _content_or_ref(self, value: Any, ref: str) -> Any:
if self._exporter.include_content:
return self._maybe_json(value)
return ref
def _maybe_json(self, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)
# ------------------------------------------------------------------
# SPAN-emitting handlers (workflow, node execution, draft node)
# ------------------------------------------------------------------
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Slim span attrs: identity + structure + status + timing only --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.workflow.status": info.workflow_run_status,
"dify.workflow.error": info.error,
"dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
"dify.invoke_from": metadata.get("triggered_from"),
"dify.conversation.id": info.conversation_id,
"dify.message.id": info.message_id,
"dify.invoked_by": info.invoked_by,
}
trace_correlation_override: str | None = None
parent_span_id_source: str | None = None
parent_ctx = metadata.get("parent_trace_context")
if isinstance(parent_ctx, dict):
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
trace_override_value = parent_ctx_dict.get("parent_workflow_run_id")
if isinstance(trace_override_value, str):
trace_correlation_override = trace_override_value
parent_span_value = parent_ctx_dict.get("parent_node_execution_id")
if isinstance(parent_span_value, str):
parent_span_id_source = parent_span_value
self._exporter.export_span(
EnterpriseTelemetrySpan.WORKFLOW_RUN,
span_attrs,
correlation_id=info.workflow_run_id,
span_id_source=info.workflow_run_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
parent_span_id_source=parent_span_id_source,
)
# -- Companion log: ALL attrs (span + detail) for full picture --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.workflow.version": info.workflow_run_version,
}
)
ref = f"ref:workflow_run_id={info.workflow_run_id}"
log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
emit_telemetry_log(
event_name="dify.workflow.run",
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.workflow_run_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, labels)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, labels)
invoke_from = metadata.get("triggered_from", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="workflow",
status=info.workflow_run_status,
invoke_from=invoke_from,
),
)
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
float(info.workflow_run_elapsed_time),
self._labels(
**labels,
status=info.workflow_run_status,
),
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="workflow",
),
)
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
self._emit_node_execution_trace(
info,
EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
"draft_node",
correlation_id_override=info.node_execution_id,
trace_correlation_override_param=info.workflow_run_id,
)
def _emit_node_execution_trace(
self,
info: WorkflowNodeTraceInfo,
span_name: EnterpriseTelemetrySpan,
request_type: str,
correlation_id_override: str | None = None,
trace_correlation_override_param: str | None = None,
) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Slim span attrs: identity + structure + status + timing --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.message.id": info.message_id,
"dify.conversation.id": metadata.get("conversation_id"),
"dify.node.execution_id": info.node_execution_id,
"dify.node.id": info.node_id,
"dify.node.type": info.node_type,
"dify.node.title": info.title,
"dify.node.status": info.status,
"dify.node.error": info.error,
"dify.node.elapsed_time": info.elapsed_time,
"dify.node.index": info.index,
"dify.node.predecessor_node_id": info.predecessor_node_id,
"dify.node.iteration_id": info.iteration_id,
"dify.node.loop_id": info.loop_id,
"dify.node.parallel_id": info.parallel_id,
"dify.node.invoked_by": info.invoked_by,
}
trace_correlation_override = trace_correlation_override_param
parent_ctx = metadata.get("parent_trace_context")
if isinstance(parent_ctx, dict):
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
override_value = parent_ctx_dict.get("parent_workflow_run_id")
if isinstance(override_value, str):
trace_correlation_override = override_value
effective_correlation_id = correlation_id_override or info.workflow_run_id
self._exporter.export_span(
span_name,
span_attrs,
correlation_id=effective_correlation_id,
span_id_source=info.node_execution_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
)
# -- Companion log: ALL attrs (span + detail) --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.invoke_from": metadata.get("invoke_from"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.node.total_price": info.total_price,
"dify.node.currency": info.currency,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.tool.name": info.tool_name,
"dify.node.iteration_index": info.iteration_index,
"dify.node.loop_index": info.loop_index,
"dify.plugin.name": metadata.get("plugin_name"),
"dify.credential.name": metadata.get("credential_name"),
"dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
}
)
ref = f"ref:node_execution_id={info.node_execution_id}"
log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
emit_telemetry_log(
event_name=span_name.value,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
node_type=info.node_type,
model_provider=info.model_provider or "",
)
if info.total_tokens:
token_labels = self._labels(
**labels,
model_name=info.model_name or "",
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type=request_type,
status=info.status,
),
)
duration_labels = dict(labels)
plugin_name = metadata.get("plugin_name")
if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
duration_labels["plugin_name"] = plugin_name
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type=request_type,
),
)
# ------------------------------------------------------------------
# METRIC-ONLY handlers (structured log + counters/histograms)
# ------------------------------------------------------------------
def _message_trace(self, info: MessageTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.invoke_from": metadata.get("from_source"),
"dify.conversation.id": metadata.get("conversation_id"),
"dify.conversation.mode": info.conversation_mode,
"gen_ai.provider.name": metadata.get("ls_provider"),
"gen_ai.request.model": metadata.get("ls_model_name"),
"gen_ai.usage.input_tokens": info.message_tokens,
"gen_ai.usage.output_tokens": info.answer_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.message.status": metadata.get("status"),
"dify.message.error": info.error,
"dify.message.from_source": metadata.get("from_source"),
"dify.message.from_end_user_id": metadata.get("from_end_user_id"),
"dify.message.from_account_id": metadata.get("from_account_id"),
"dify.streaming": info.is_streaming_request,
"dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
"dify.message.streaming_duration": info.llm_streaming_time_to_generate,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name="dify.message.run",
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
model_provider=metadata.get("ls_provider", ""),
model_name=metadata.get("ls_model_name", ""),
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels)
invoke_from = metadata.get("from_source", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="message",
status=metadata.get("status", ""),
invoke_from=invoke_from,
),
)
if info.start_time and info.end_time:
duration = (info.end_time - info.start_time).total_seconds()
self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
if info.gen_ai_server_time_to_first_token is not None:
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="message",
),
)
def _tool_trace(self, info: ToolTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"gen_ai.tool.name": info.tool_name,
"dify.tool.time_cost": info.time_cost,
"dify.tool.error": info.error,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
emit_metric_only_event(
event_name="dify.tool.execution",
attributes=attrs,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
tool_name=info.tool_name,
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="tool",
),
)
self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="tool",
),
)
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.moderation.flagged": info.flagged,
"dify.moderation.action": info.action,
"dify.moderation.preset_response": info.preset_response,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.moderation.query"] = self._content_or_ref(
info.query,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name="dify.moderation.check",
attributes=attrs,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="moderation",
),
)
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.suggested_question.status": info.status,
"dify.suggested_question.error": info.error,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_id,
"dify.suggested_question.count": len(info.suggested_question),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.suggested_question.questions"] = self._content_or_ref(
info.suggested_question,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name="dify.suggested_question.generation",
attributes=attrs,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="suggested_question",
),
)
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.dataset.error"] = info.error
attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
docs: list[dict[str, Any]] = []
documents_any: Any = info.documents
documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
for entry in documents_list:
if isinstance(entry, dict):
entry_dict: dict[str, Any] = cast(dict[str, Any], entry)
docs.append(entry_dict)
dataset_ids: list[str] = []
dataset_names: list[str] = []
structured_docs: list[dict[str, Any]] = []
for doc in docs:
meta_raw = doc.get("metadata")
meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
did = meta.get("dataset_id")
dname = meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
structured_docs.append(
{
"dataset_id": did,
"document_id": meta.get("document_id"),
"segment_id": meta.get("segment_id"),
"score": meta.get("score"),
}
)
attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
attrs["dify.retrieval.document_count"] = len(docs)
embedding_models_raw: Any = metadata.get("embedding_models")
embedding_models: dict[str, Any] = (
cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
)
if embedding_models:
providers: list[str] = []
models: list[str] = []
for ds_info in embedding_models.values():
if isinstance(ds_info, dict):
ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
p = ds_info_dict.get("embedding_model_provider", "")
m = ds_info_dict.get("embedding_model", "")
if p and p not in providers:
providers.append(p)
if m and m not in models:
models.append(m)
attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
ref = f"ref:message_id={info.message_id}"
retrieval_inputs = self._safe_payload_value(info.inputs)
attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
emit_metric_only_event(
event_name="dify.dataset.retrieval",
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="dataset_retrieval",
),
)
for did in dataset_ids:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
1,
self._labels(
**labels,
dataset_id=did,
),
)
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.conversation.id"] = info.conversation_id
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:conversation_id={info.conversation_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name="dify.generate_name.execution",
attributes=attrs,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="generate_name",
),
)
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = {
"dify.trace_id": info.trace_id,
"dify.tenant_id": tenant_id,
"dify.user.id": user_id,
"dify.app.id": app_id or "",
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.operation.type": info.operation_type,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.prompt_generation.latency": info.latency,
"dify.prompt_generation.error": info.error,
}
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
if info.total_price is not None:
attrs["dify.prompt_generation.total_price"] = info.total_price
attrs["dify.prompt_generation.currency"] = info.currency
ref = f"ref:trace_id={info.trace_id}"
outputs = self._safe_payload_value(info.outputs)
attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name="dify.prompt_generation.execution",
attributes=attrs,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, labels)
if info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, labels)
if info.completion_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, labels)
status = "failed" if info.error else "success"
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="prompt_generation",
status=status,
),
)
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
info.latency,
labels,
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="prompt_generation",
),
)

View File

@@ -0,0 +1,33 @@
from enum import StrEnum
class EnterpriseTelemetrySpan(StrEnum):
WORKFLOW_RUN = "dify.workflow.run"
NODE_EXECUTION = "dify.node.execution"
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
class EnterpriseTelemetryCounter(StrEnum):
TOKENS = "tokens"
INPUT_TOKENS = "input_tokens"
OUTPUT_TOKENS = "output_tokens"
REQUESTS = "requests"
ERRORS = "errors"
FEEDBACK = "feedback"
DATASET_RETRIEVALS = "dataset_retrievals"
class EnterpriseTelemetryHistogram(StrEnum):
WORKFLOW_DURATION = "workflow_duration"
NODE_DURATION = "node_duration"
MESSAGE_DURATION = "message_duration"
MESSAGE_TTFT = "message_ttft"
TOOL_DURATION = "tool_duration"
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
__all__ = [
"EnterpriseTelemetryCounter",
"EnterpriseTelemetryHistogram",
"EnterpriseTelemetrySpan",
]

View File

@@ -0,0 +1,130 @@
"""Blinker signal handlers for enterprise telemetry.
Registered at import time via ``@signal.connect`` decorators.
Import must happen during ``ext_enterprise_telemetry.init_app()`` to ensure handlers fire.
"""
from __future__ import annotations
import logging
import uuid
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from events.feedback_event import feedback_was_created
logger = logging.getLogger(__name__)
__all__ = [
"_handle_app_created",
"_handle_app_deleted",
"_handle_app_updated",
"_handle_feedback_created",
]
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
exporter = get_enterprise_exporter()
if not exporter:
return
tenant_id = str(getattr(sender, "tenant_id", "") or "")
payload = {
"app_id": getattr(sender, "id", None),
"mode": getattr(sender, "mode", None),
}
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id=tenant_id,
event_id=str(uuid.uuid4()),
payload=payload,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
@app_was_deleted.connect
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
exporter = get_enterprise_exporter()
if not exporter:
return
tenant_id = str(getattr(sender, "tenant_id", "") or "")
payload = {
"app_id": getattr(sender, "id", None),
}
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_DELETED,
tenant_id=tenant_id,
event_id=str(uuid.uuid4()),
payload=payload,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
@app_was_updated.connect
def _handle_app_updated(sender: object, **kwargs: object) -> None:
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
exporter = get_enterprise_exporter()
if not exporter:
return
tenant_id = str(getattr(sender, "tenant_id", "") or "")
payload = {
"app_id": getattr(sender, "id", None),
}
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_UPDATED,
tenant_id=tenant_id,
event_id=str(uuid.uuid4()),
payload=payload,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
@feedback_was_created.connect
def _handle_feedback_created(sender: object, **kwargs: object) -> None:
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
exporter = get_enterprise_exporter()
if not exporter:
return
tenant_id = str(kwargs.get("tenant_id", "") or "")
payload = {
"message_id": getattr(sender, "message_id", None),
"app_id": getattr(sender, "app_id", None),
"conversation_id": getattr(sender, "conversation_id", None),
"from_end_user_id": getattr(sender, "from_end_user_id", None),
"from_account_id": getattr(sender, "from_account_id", None),
"rating": getattr(sender, "rating", None),
"from_source": getattr(sender, "from_source", None),
"content": getattr(sender, "content", None),
}
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id=tenant_id,
event_id=str(uuid.uuid4()),
payload=payload,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())

View File

@@ -0,0 +1,252 @@
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
independent from ext_otel.py infrastructure).
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
"""
import logging
import socket
import uuid
from datetime import datetime
from typing import Any, cast
from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
from configs import dify_config
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_correlation_id,
set_span_id_source,
)
logger = logging.getLogger(__name__)
def is_enterprise_telemetry_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def _parse_otlp_headers(raw: str) -> dict[str, str]:
"""Parse ``key=value,key2=value2`` into a dict."""
if not raw:
return {}
headers: dict[str, str] = {}
for pair in raw.split(","):
if "=" not in pair:
continue
k, v = pair.split("=", 1)
headers[k.strip()] = v.strip()
return headers
def _datetime_to_ns(dt: datetime) -> int:
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
return int(dt.timestamp() * 1_000_000_000)
class _ExporterFactory:
def __init__(self, protocol: str, endpoint: str, headers: dict[str, str]):
self._protocol = protocol
self._endpoint = endpoint
self._headers = headers
self._grpc_headers = tuple(headers.items()) if headers else None
self._http_headers = headers or None
def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
if self._protocol == "grpc":
return GRPCSpanExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=True,
)
trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else ""
return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers)
def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
if self._protocol == "grpc":
return GRPCMetricExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=True,
)
metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else ""
return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers)
class EnterpriseExporter:
"""Shared OTEL exporter for all enterprise telemetry.
``export_span`` creates spans with optional real timestamps, deterministic
span/trace IDs, and cross-workflow parent linking.
``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy.
"""
def __init__(self, config: object) -> None:
endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)
id_generator = CorrelationIdGenerator()
self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
headers = _parse_otlp_headers(headers_raw)
factory = _ExporterFactory(protocol, endpoint, headers)
trace_exporter = factory.create_trace_exporter()
self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
metric_exporter = factory.create_metric_exporter()
self._meter_provider = MeterProvider(
resource=resource,
metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
)
meter = self._meter_provider.get_meter("dify.enterprise")
self._counters = {
EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
"dify.dataset.retrievals.total", unit="{retrieval}"
),
}
self._histograms = {
EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
"dify.message.time_to_first_token", unit="s"
),
EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
"dify.prompt_generation.duration", unit="s"
),
}
def export_span(
self,
name: str,
attributes: dict[str, Any],
correlation_id: str | None = None,
span_id_source: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
trace_correlation_override: str | None = None,
parent_span_id_source: str | None = None,
) -> None:
"""Export an OTEL span with optional deterministic IDs and real timestamps.
Args:
name: Span operation name.
attributes: Span attributes dict.
correlation_id: Source for trace_id derivation (groups spans in one trace).
span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
start_time: Real span start time. When None, uses current time.
end_time: Real span end time. When None, span ends immediately.
trace_correlation_override: Override trace_id source (for cross-workflow linking).
When set, trace_id is derived from this instead of ``correlation_id``.
parent_span_id_source: Override parent span_id source (for cross-workflow linking).
When set, parent span_id is derived from this value. When None and
``correlation_id`` is set, parent is the workflow root span.
"""
effective_trace_correlation = trace_correlation_override or correlation_id
set_correlation_id(effective_trace_correlation)
set_span_id_source(span_id_source)
try:
parent_context: Context | None = None
# A span is the "root" of its correlation group when span_id_source == correlation_id
# (i.e. a workflow root span). All other spans are children.
if parent_span_id_source:
# Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
parent_span_id = compute_deterministic_span_id(parent_span_id_source)
parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
elif correlation_id and correlation_id != span_id_source:
# Child span: parent is the correlation-group root (workflow root span)
parent_span_id = compute_deterministic_span_id(correlation_id)
parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
span_end_on_exit = end_time is None
with self._tracer.start_as_current_span(
name,
context=parent_context,
start_time=span_start_time,
end_on_exit=span_end_on_exit,
) as span:
for key, value in attributes.items():
if value is not None:
span.set_attribute(key, value)
if end_time is not None:
span.end(end_time=_datetime_to_ns(end_time))
except Exception:
logger.exception("Failed to export span %s", name)
finally:
set_correlation_id(None)
set_span_id_source(None)
def increment_counter(
self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
) -> None:
counter = self._counters.get(name)
if counter:
counter.add(value, cast(Attributes, labels))
def record_histogram(
self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
) -> None:
histogram = self._histograms.get(name)
if histogram:
histogram.record(value, cast(Attributes, labels))
def shutdown(self) -> None:
self._tracer_provider.shutdown()
self._meter_provider.shutdown()

View File

@@ -0,0 +1,199 @@
"""Telemetry gateway routing and dispatch.
Maps ``TelemetryCase`` → ``CaseRoute`` (signal type + CE eligibility)
and dispatches events to either the trace pipeline or the metric/log
Celery queue.
Singleton lifecycle is managed by ``ext_enterprise_telemetry.init_app()``
which creates the instance during single-threaded Flask app startup.
Access via ``ext_enterprise_telemetry.get_gateway()``.
"""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from core.ops.entities.trace_entity import TraceTaskName
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
logger = logging.getLogger(__name__)
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
CASE_TO_TRACE_TASK: dict[TelemetryCase, TraceTaskName] = {
TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
}
CASE_ROUTING: dict[TelemetryCase, CaseRoute] = {
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
}
def _is_enterprise_telemetry_enabled() -> bool:
try:
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
return is_enterprise_telemetry_enabled()
except Exception:
return False
def _should_drop_ee_only_event(route: CaseRoute) -> bool:
"""Return True when the event is enterprise-only and EE telemetry is disabled."""
return not route.ce_eligible and not _is_enterprise_telemetry_enabled()
class TelemetryGateway:
"""Routes telemetry events to the trace pipeline or the metric/log Celery queue.
Stateless — instantiated once during ``ext_enterprise_telemetry.init_app()``
and shared for the lifetime of the process.
"""
def emit(
self,
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None = None,
) -> None:
route = CASE_ROUTING.get(case)
if route is None:
logger.warning("Unknown telemetry case: %s, dropping event", case)
return
if _should_drop_ee_only_event(route):
logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
return
logger.debug(
"Gateway routing: case=%s, signal_type=%s, ce_eligible=%s",
case,
route.signal_type,
route.ce_eligible,
)
if route.signal_type is SignalType.TRACE:
self._emit_trace(case, context, payload, route, trace_manager)
else:
self._emit_metric_log(case, context, payload)
def _emit_trace(
self,
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
route: CaseRoute,
trace_manager: TraceQueueManager | None,
) -> None:
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
from core.ops.ops_trace_manager import TraceTask
trace_task_name = CASE_TO_TRACE_TASK.get(case)
if trace_task_name is None:
logger.warning("No TraceTaskName mapping for case: %s", case)
return
queue_manager = trace_manager or LocalTraceQueueManager(
app_id=context.get("app_id"),
user_id=context.get("user_id"),
)
queue_manager.add_trace_task(TraceTask(trace_task_name, **payload))
logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
def _emit_metric_log(
self,
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
) -> None:
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
tenant_id = context.get("tenant_id", "")
event_id = str(uuid.uuid4())
payload_for_envelope, payload_ref = self._handle_payload_sizing(payload, tenant_id, event_id)
envelope = TelemetryEnvelope(
case=case,
tenant_id=tenant_id,
event_id=event_id,
payload=payload_for_envelope,
metadata={"payload_ref": payload_ref} if payload_ref else None,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
logger.debug(
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
case,
tenant_id,
event_id,
)
def _handle_payload_sizing(
self,
payload: dict[str, Any],
tenant_id: str,
event_id: str,
) -> tuple[dict[str, Any], str | None]:
try:
payload_json = json.dumps(payload)
payload_size = len(payload_json.encode("utf-8"))
except (TypeError, ValueError):
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
return payload, None
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
return payload, None
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
try:
storage.save(storage_key, payload_json.encode("utf-8"))
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
return {}, storage_key
except Exception:
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
return payload, None
def emit(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None = None,
) -> None:
"""Module-level convenience wrapper.
Fetches the gateway singleton from the extension; no-ops when
enterprise telemetry is disabled (gateway is ``None``).
"""
from extensions.ext_enterprise_telemetry import get_gateway
gateway = get_gateway()
if gateway is not None:
gateway.emit(case, context, payload, trace_manager)

View File

@@ -0,0 +1,76 @@
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
When a span_id_source is set, the span_id is derived deterministically
from that value, enabling any span to reference another as parent
without depending on span creation order.
"""
import random
import uuid
from contextvars import ContextVar
from typing import cast
from opentelemetry.sdk.trace.id_generator import IdGenerator
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
def set_correlation_id(correlation_id: str | None) -> None:
_correlation_id_context.set(correlation_id)
def get_correlation_id() -> str | None:
return _correlation_id_context.get()
def set_span_id_source(source_id: str | None) -> None:
"""Set the source for deterministic span_id generation.
When set, ``generate_span_id()`` derives the span_id from this value
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
root spans or ``node_execution_id`` for node spans.
"""
_span_id_source_context.set(source_id)
def compute_deterministic_span_id(source_id: str) -> int:
"""Derive a deterministic span_id from any UUID string.
Uses the lower 64 bits of the UUID, guaranteeing non-zero output
(OTEL requires span_id != 0).
"""
span_id = cast(int, uuid.UUID(source_id).int) & ((1 << 64) - 1)
return span_id if span_id != 0 else 1
class CorrelationIdGenerator(IdGenerator):
"""ID generator that derives trace_id and optionally span_id from context.
- trace_id: always derived from correlation_id (groups all spans in one trace)
- span_id: derived from span_id_source when set (enables deterministic
parent-child linking), otherwise random
"""
def generate_trace_id(self) -> int:
correlation_id = _correlation_id_context.get()
if correlation_id:
try:
return cast(int, uuid.UUID(correlation_id).int)
except (ValueError, AttributeError):
pass
return random.getrandbits(128)
def generate_span_id(self) -> int:
source = _span_id_source_context.get()
if source:
try:
return compute_deterministic_span_id(source)
except (ValueError, AttributeError):
pass
span_id = random.getrandbits(64)
while span_id == 0:
span_id = random.getrandbits(64)
return span_id

View File

@@ -0,0 +1,371 @@
"""Enterprise metric/log event handler.
This module processes metric and log telemetry events after they've been
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
idempotency checking, and payload rehydration.
"""
from __future__ import annotations
import logging
from typing import Any
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_redis import redis_client
logger = logging.getLogger(__name__)
class EnterpriseMetricHandler:
"""Handler for enterprise metric and log telemetry events.
Processes envelopes from the enterprise_telemetry queue, routing each
case to the appropriate handler method. Implements idempotency checking
and payload rehydration with fallback.
"""
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
"""Increment a diagnostic counter for operational monitoring.
Args:
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
labels: Optional labels for the counter.
"""
try:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
return
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
logger.debug(
"Diagnostic counter: %s, labels=%s",
full_counter_name,
labels or {},
)
except Exception:
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
def handle(self, envelope: TelemetryEnvelope) -> None:
"""Main entry point for processing telemetry envelopes.
Args:
envelope: The telemetry envelope to process.
"""
# Check for duplicate events
if self._is_duplicate(envelope):
logger.debug(
"Skipping duplicate event: tenant_id=%s, event_id=%s",
envelope.tenant_id,
envelope.event_id,
)
self._increment_diagnostic_counter("deduped_total")
return
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.
Uses Redis with TTL for deduplication. Returns True if duplicate,
False if first time seeing this event.
Args:
envelope: The telemetry envelope to check.
Returns:
True if this event_id has been seen before, False otherwise.
"""
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
try:
# Atomic set-if-not-exists with 1h TTL
# Returns True if key was set (first time), None if already exists (duplicate)
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
return was_set is None
except Exception:
# Fail open: if Redis is unavailable, process the event
# (prefer occasional duplicate over lost data)
logger.warning(
"Redis unavailable for deduplication check, processing event anyway: %s",
envelope.event_id,
exc_info=True,
)
return False
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
"""Rehydrate payload from reference or fallback.
Attempts to resolve payload_ref to full data. If that fails,
falls back to payload_fallback. If both fail, emits a degraded
event marker.
Args:
envelope: The telemetry envelope containing payload data.
Returns:
The rehydrated payload dictionary.
"""
# For now, payload is directly in the envelope
# Future: implement payload_ref resolution from storage
payload = envelope.payload
if not payload and envelope.payload_fallback:
import pickle
try:
payload = pickle.loads(envelope.payload_fallback) # noqa: S301
logger.debug("Used payload_fallback for event_id=%s", envelope.event_id)
except Exception:
logger.warning(
"Failed to deserialize payload_fallback for event_id=%s",
envelope.event_id,
exc_info=True,
)
if not payload:
# Both ref and fallback failed - emit degraded event
logger.error(
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
envelope.event_id,
envelope.tenant_id,
envelope.case,
)
# Emit degraded event marker
from enterprise.telemetry.telemetry_log import emit_metric_only_event
emit_metric_only_event(
event_name="dify.telemetry.rehydration_failed",
attributes={
"dify.tenant_id": envelope.tenant_id,
"dify.event_id": envelope.event_id,
"dify.case": envelope.case,
"rehydration_failed": True,
},
tenant_id=envelope.tenant_id,
)
self._increment_diagnostic_counter("rehydration_failed_total")
return {}
return payload
# Stub methods for each metric/log case
# These will be implemented in later tasks with actual emission logic
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle app created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.app.mode": payload.get("mode"),
}
emit_metric_only_event(
event_name="dify.app.created",
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
{
"type": "app.created",
"tenant_id": envelope.tenant_id,
},
)
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
"""Handle app updated event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
}
emit_metric_only_event(
event_name="dify.app.updated",
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
{
"type": "app.updated",
"tenant_id": envelope.tenant_id,
},
)
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
"""Handle app deleted event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
}
emit_metric_only_event(
event_name="dify.app.deleted",
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
{
"type": "app.deleted",
"tenant_id": envelope.tenant_id,
},
)
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle feedback created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
include_content = exporter.include_content
attrs: dict = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.app_id": payload.get("app_id"),
"dify.conversation.id": payload.get("conversation_id"),
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
"dify.feedback.rating": payload.get("rating"),
"dify.feedback.from_source": payload.get("from_source"),
}
if include_content:
attrs["dify.feedback.content"] = payload.get("content")
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
emit_metric_only_event(
event_name="dify.feedback.created",
attributes=attrs,
tenant_id=envelope.tenant_id,
user_id=str(user_id or ""),
)
exporter.increment_counter(
EnterpriseTelemetryCounter.FEEDBACK,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"rating": str(payload.get("rating", "")),
},
)
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
"""Handle message run event (stub)."""
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
"""Handle tool execution event (stub)."""
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
"""Handle moderation check event (stub)."""
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
"""Handle suggested question event (stub)."""
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
"""Handle dataset retrieval event (stub)."""
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
"""Handle generate name event (stub)."""
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
"""Handle prompt generation event (stub)."""
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)

View File

@@ -0,0 +1,119 @@
"""Structured-log emitter for enterprise telemetry events.
Emits structured JSON log lines correlated with OTEL traces via trace_id.
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
"""
from __future__ import annotations
import logging
import uuid
from functools import lru_cache
from typing import Any
logger = logging.getLogger("dify.telemetry")
@lru_cache(maxsize=4096)
def compute_trace_id_hex(uuid_str: str | None) -> str:
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
Returns empty string when *uuid_str* is ``None`` or invalid.
"""
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
return f"{uuid.UUID(normalized).int:032x}"
except (ValueError, AttributeError):
return ""
@lru_cache(maxsize=4096)
def compute_span_id_hex(uuid_str: str | None) -> str:
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
return f"{compute_deterministic_span_id(normalized):016x}"
except (ValueError, AttributeError):
return ""
def emit_telemetry_log(
*,
event_name: str,
attributes: dict[str, Any],
signal: str = "metric_only",
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Emit a structured log line for a telemetry event.
Parameters
----------
event_name:
Canonical event name, e.g. ``"dify.workflow.run"``.
attributes:
All event-specific attributes (already built by the caller).
signal:
``"metric_only"`` for events with no span, ``"span_detail"``
for detail logs accompanying a slim span.
trace_id_source:
A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
trace_id for cross-signal correlation.
tenant_id:
Tenant identifier (for the ``IdentityContextFilter``).
user_id:
User identifier (for the ``IdentityContextFilter``).
"""
if not logger.isEnabledFor(logging.INFO):
return
attrs = {
"dify.event.name": event_name,
"dify.event.signal": signal,
**attributes,
}
extra: dict[str, Any] = {"attributes": attrs}
trace_id_hex = compute_trace_id_hex(trace_id_source)
if trace_id_hex:
extra["trace_id"] = trace_id_hex
span_id_hex = compute_span_id_hex(span_id_source)
if span_id_hex:
extra["span_id"] = span_id_hex
if tenant_id:
extra["tenant_id"] = tenant_id
if user_id:
extra["user_id"] = user_id
logger.info("telemetry.%s", signal, extra=extra)
def emit_metric_only_event(
*,
event_name: str,
attributes: dict[str, Any],
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
emit_telemetry_log(
event_name=event_name,
attributes=attributes,
signal="metric_only",
trace_id_source=trace_id_source,
span_id_source=span_id_source,
tenant_id=tenant_id,
user_id=user_id,
)

View File

@@ -3,6 +3,12 @@ from blinker import signal
# sender: app
app_was_created = signal("app-was-created")
# sender: app
app_was_deleted = signal("app-was-deleted")
# sender: app
app_was_updated = signal("app-was-updated")
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal("app-model-config-was-updated")

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: MessageFeedback, kwargs: tenant_id
feedback_was_created = signal("feedback-was-created")

View File

@@ -0,0 +1,58 @@
"""Flask extension for enterprise telemetry lifecycle management.
Initializes the EnterpriseExporter and TelemetryGateway singletons during
``create_app()`` (single-threaded), registers blinker event handlers,
and hooks atexit for graceful shutdown.
Skipped entirely when ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
are false (``is_enabled()`` gate).
"""
from __future__ import annotations
import atexit
import logging
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from dify_app import DifyApp
from enterprise.telemetry.exporter import EnterpriseExporter
from enterprise.telemetry.gateway import TelemetryGateway
logger = logging.getLogger(__name__)
_exporter: EnterpriseExporter | None = None
_gateway: TelemetryGateway | None = None
def is_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def init_app(app: DifyApp) -> None:
global _exporter, _gateway
if not is_enabled():
return
from enterprise.telemetry.exporter import EnterpriseExporter
from enterprise.telemetry.gateway import TelemetryGateway
_exporter = EnterpriseExporter(dify_config)
_gateway = TelemetryGateway()
atexit.register(_exporter.shutdown)
# Import to trigger @signal.connect decorator registration
import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport]
logger.info("Enterprise telemetry initialized")
def get_enterprise_exporter() -> EnterpriseExporter | None:
return _exporter
def get_gateway() -> TelemetryGateway | None:
return _gateway

View File

@@ -21,3 +21,15 @@ class DifySpanAttributes:
INVOKE_FROM = "dify.invoke_from"
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
INVOKED_BY = "dify.invoked_by"
"""Invoked by, e.g. end_user, account, user."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens (prompt tokens) used."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens (completion tokens) generated."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens used."""

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
requires-python = ">=3.11,<3.13"
dependencies = [

View File

@@ -327,6 +327,12 @@ class AccountService:
@staticmethod
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
# Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
from services.enterprise.account_deletion_sync import sync_account_deletion
sync_account_deletion(account_id=account.id, source="account_deleted")
# Now proceed with async account deletion
delete_account_task.delay(account.id)
@staticmethod
@@ -1230,6 +1236,11 @@ class TenantService:
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(tenant.id)
# Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
from services.enterprise.account_deletion_sync import sync_workspace_member_removal
sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""

View File

@@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from events.app_event import app_was_created, app_was_deleted
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
@@ -340,6 +340,8 @@ class AppService:
db.session.delete(app)
db.session.commit()
app_was_deleted.send(app)
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)

View File

@@ -0,0 +1,115 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin
logger = logging.getLogger(__name__)
ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"
def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Queue an account deletion sync task to Redis.
Internal helper function. Do not call directly - use the public functions instead.
Args:
workspace_id: The workspace/tenant ID to sync
member_id: The member/account ID that was removed
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"member_id": member_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
"type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
workspace_id,
member_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error(
"Failed to queue account deletion sync for workspace %s, member %s: %s",
workspace_id,
member_id,
str(e),
exc_info=True,
)
# Don't raise - we don't want to fail member deletion if queueing fails
return False
def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Sync a single workspace member removal (enterprise only).
Queues a task for the enterprise backend to reassign resources from the removed member.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
workspace_id: The workspace/tenant ID
member_id: The member/account ID that was removed
source: Source of the sync request (e.g., "workspace_member_removed")
Returns:
bool: True if task was queued (or skipped in community), False if queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
def sync_account_deletion(account_id: str, *, source: str) -> bool:
"""
Sync full account deletion across all workspaces (enterprise only).
Fetches all workspace memberships for the account and queues a sync task for each.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
account_id: The account ID being deleted
source: Source of the sync request (e.g., "account_deleted")
Returns:
bool: True if all tasks were queued (or skipped in community), False if any queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
# Fetch all workspaces the account belongs to
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
# Queue sync task for each workspace
success = True
for join in workspace_joins:
if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
success = False
return success

View File

@@ -7,9 +7,10 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from events.feedback_event import feedback_was_created
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
@@ -179,6 +180,9 @@ class MessageService:
db.session.commit()
if feedback and rating:
feedback_was_created.send(feedback, tenant_id=app_model.tenant_id)
return feedback
@classmethod
@@ -294,10 +298,15 @@ class MessageService:
questions: list[str] = list(questions_sequence)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id),
payload={
"message_id": message_id,
"suggested_question": questions,
"timer": timer,
},
)
)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
@@ -5,6 +6,8 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
logger = logging.getLogger(__name__)
class OpsService:
@classmethod
@@ -135,12 +138,13 @@ class OpsService:
return trace_config_data.to_dict()
@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
"""
Create tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:param account_id: account id of the user creating the config
:return:
"""
try:
@@ -207,15 +211,19 @@ class OpsService:
db.session.add(trace_config_data)
db.session.commit()
# Log the creation with modifier information
logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id)
return {"result": "success"}
@classmethod
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
"""
Update tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:param account_id: account id of the user updating the config
:return:
"""
try:
@@ -251,14 +259,18 @@ class OpsService:
current_trace_config.tracing_config = tracing_config
db.session.commit()
# Log the update with modifier information
logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id)
return current_trace_config.to_dict()
@classmethod
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str):
"""
Delete tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param account_id: account id of the user deleting the config
:return:
"""
trace_config = (
@@ -270,6 +282,9 @@ class OpsService:
if not trace_config:
return None
# Log the deletion with modifier information
logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id)
db.session.delete(trace_config)
db.session.commit()

View File

@@ -2,7 +2,10 @@ import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.account import Account
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
@@ -406,20 +409,37 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
# get provider (verify tenant ownership to prevent IDOR)
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
if dify_config.ENTERPRISE_ENABLED:
# Enterprise: verify admin permission for tenant-wide operation
from models.account import TenantAccountRole
if account is None:
# In enterprise mode, an account context is required to perform permission checks
raise ValueError("Account is required to set default credentials in enterprise mode")
if not TenantAccountRole.is_privileged_role(account.current_role):
raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")
# Enterprise: clear ALL defaults for this provider in the tenant
# (regardless of user_id, since enterprise credentials may have different user_id)
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, provider=provider, is_default=True
).update({"is_default": False})
else:
# Non-enterprise: only clear defaults for the current user
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True

View File

@@ -27,6 +27,7 @@ from core.workflow.nodes.start.entities import StartNodeData
from core.workflow.runtime import VariablePool
from core.workflow.system_variable import SystemVariable
from core.workflow.workflow_entry import WorkflowEntry
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
from enums.cloud_plan import CloudPlan
from events.app_event import app_draft_workflow_was_synced, app_published_workflow_was_updated
from extensions.ext_database import db
@@ -647,6 +648,7 @@ class WorkflowService:
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
workflow_execution_id: str | None = None
if node_type.is_start_node:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
@@ -672,10 +674,13 @@ class WorkflowService:
node_type=node_type,
conversation_id=conversation_id,
)
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
else:
workflow_execution_id = str(uuid.uuid4())
system_variable = SystemVariable(workflow_execution_id=workflow_execution_id)
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -729,6 +734,13 @@ class WorkflowService:
with Session(db.engine) as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
outputs=outputs,
workflow_execution_id=workflow_execution_id,
user_id=account.id,
)
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
@@ -784,19 +796,20 @@ class WorkflowService:
Returns:
WorkflowNodeExecution: The execution result
"""
created_at = naive_utc_now()
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
finished_at = naive_utc_now()
# Create base node execution
node_execution = WorkflowNodeExecution(
id=str(uuid.uuid4()),
id=node.execution_id or str(uuid.uuid4()),
workflow_id="", # Single-step execution has no workflow ID
index=1,
node_id=node_id,
node_type=node.node_type,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=naive_utc_now(),
finished_at=naive_utc_now(),
created_at=created_at,
finished_at=finished_at,
)
# Populate execution result data

View File

@@ -0,0 +1,52 @@
"""Celery worker for enterprise metric/log telemetry events.
This module defines the Celery task that processes telemetry envelopes
from the enterprise_telemetry queue. It deserializes envelopes and
dispatches them to the EnterpriseMetricHandler.
"""
import json
import logging
from celery import shared_task
from enterprise.telemetry.contracts import TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
logger = logging.getLogger(__name__)
@shared_task(queue="enterprise_telemetry")
def process_enterprise_telemetry(envelope_json: str) -> None:
"""Process enterprise metric/log telemetry envelope.
This task is enqueued by the TelemetryGateway for metric/log-only
events. It deserializes the envelope and dispatches to the handler.
Best-effort processing: logs errors but never raises, to avoid
failing user requests due to telemetry issues.
Args:
envelope_json: JSON-serialized TelemetryEnvelope.
"""
try:
# Deserialize envelope
envelope_dict = json.loads(envelope_json)
envelope = TelemetryEnvelope.model_validate(envelope_dict)
# Process through handler
handler = EnterpriseMetricHandler()
handler.handle(envelope)
logger.debug(
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
envelope.tenant_id,
envelope.event_id,
envelope.case,
)
except Exception:
# Best-effort: log and drop on error, never fail user request
logger.warning(
"Failed to process enterprise telemetry envelope, dropping event",
exc_info=True,
)

View File

@@ -39,12 +39,24 @@ def process_trace_tasks(file_info):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.warning("Enterprise trace failed for app_id: %s", app_id, exc_info=True)
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
@@ -52,4 +64,12 @@ def process_trace_tasks(file_info):
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)

View File

@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(workflow_archive_log_id: str):
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables

View File

@@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
_delete_draft_variables,
delete_draft_variables_batch,
)
@pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.return_value = None
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
if app2_obj:
session.delete(app2_obj)
session.commit()
class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
tenant, app = app_and_tenant
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
yield data
with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id
# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert deleted_count == 10
assert remaining_count == 0
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
assert deleted_count == 10
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
)
assert remaining_vars == 0
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
"""Test session behavior when deleting from an empty dataset."""
nonexistent_app_id = str(uuid.uuid4())
# Should not raise any errors and should return 0
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
assert deleted_count == 0
def test_session_commit_with_single_batch(self, setup_commit_test_data):
"""Test that commit happens correctly when all data fits in a single batch."""
data = setup_commit_test_data
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Delete all in a single batch
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
assert deleted_count == 10
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
"""Test that invalid batch size raises ValueError."""
data = setup_commit_test_data
app_id = data["app"].id
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=0)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=-1)
@patch("extensions.ext_storage.storage")
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that session commits correctly when cleaning up offload data."""
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
mock_storage.delete.return_value = None
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete variables with offload data
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count == 3
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage cleanup was called
assert mock_storage.delete.call_count == 2

View File

@@ -0,0 +1,200 @@
"""Unit tests for TraceQueueManager telemetry guard.
This test suite verifies that TraceQueueManager correctly drops trace tasks
when telemetry is disabled, proving Bug 1 from code review is a false positive.
The guard logic moved from persistence.py to TraceQueueManager.add_trace_task()
at line 1282 of ops_trace_manager.py:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
Tasks are only enqueued if EITHER:
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
- A third-party trace instance (Langfuse, etc.) is configured
When BOTH are false, tasks are silently dropped (correct behavior).
"""
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def trace_queue_manager_and_task(monkeypatch):
"""Fixture to provide TraceQueueManager and TraceTask with delayed imports."""
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type):
self.trace_type = trace_type
self.app_id = None
class StubTraceQueueManager:
def __init__(self, app_id=None):
self.app_id = app_id
from core.telemetry import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.ops.entities.trace_entity import TraceTaskName
ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
TraceQueueManager = ops_module.TraceQueueManager
TraceTask = ops_module.TraceTask
return TraceQueueManager, TraceTask, TraceTaskName
class TestTraceQueueManagerTelemetryGuard:
"""Test TraceQueueManager's telemetry guard in add_trace_task()."""
def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
"""Verify task is NOT enqueued when telemetry disabled and no trace instance.
This is the core guard: when _enterprise_telemetry_enabled=False AND
trace_instance=None, the task should be silently dropped.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_not_called()
def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when enterprise telemetry is enabled.
When _enterprise_telemetry_enabled=True, the task should be enqueued
regardless of trace_instance state.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when third-party trace instance is configured.
When trace_instance is not None (e.g., Langfuse configured), the task
should be enqueued even if enterprise telemetry is disabled.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=False),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when both telemetry and trace instance are enabled.
When both _enterprise_telemetry_enabled=True AND trace_instance is set,
the task should definitely be enqueued.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
"""Verify app_id is set on the task before enqueuing.
The guard logic sets trace_task.app_id = self.app_id before calling
trace_manager_queue.put(trace_task). This test verifies that behavior.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="expected-app-id")
manager.add_trace_task(trace_task)
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "expected-app-id"

View File

@@ -0,0 +1,181 @@
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
from __future__ import annotations
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
@pytest.fixture
def telemetry_test_setup(monkeypatch):
module_name = "core.ops.ops_trace_manager"
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type, **kwargs):
self.trace_type = trace_type
self.app_id = None
self.kwargs = kwargs
class StubTraceQueueManager:
def __init__(self, app_id=None, user_id=None):
self.app_id = app_id
self.user_id = user_id
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.telemetry import emit
return emit, ops_stub.trace_manager_queue
class TestTelemetryEmit:
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
def test_emit_enterprise_trace_creates_trace_task(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"key": "value"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
def test_emit_community_trace_enqueued(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.WORKFLOW_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_not_called()
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
enterprise_only_traces = [
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TraceTaskName.NODE_EXECUTION_TRACE,
TraceTaskName.PROMPT_GENERATION_TRACE,
]
for trace_name in enterprise_only_traces:
mock_queue.reset_mock()
event = TelemetryEvent(
name=trace_name,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == trace_name
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
def test_emit_passes_name_directly_to_trace_task(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"extra": "data"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
assert isinstance(called_task.trace_type, TraceTaskName)
@patch("core.telemetry._is_enterprise_telemetry_enabled", return_value=True)
def test_emit_with_provided_trace_manager(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
mock_trace_manager = MagicMock()
mock_trace_manager.add_trace_task = MagicMock()
event = TelemetryEvent(
name=TraceTaskName.NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event, trace_manager=mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
called_task = mock_trace_manager.add_trace_task.call_args[0][0]
assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE

View File

@@ -0,0 +1,252 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.telemetry import is_enterprise_telemetry_enabled
from enterprise.telemetry.contracts import TelemetryCase
from enterprise.telemetry.gateway import TelemetryGateway
class TestTelemetryCoreExports:
def test_is_enterprise_telemetry_enabled_exported(self) -> None:
from core.telemetry import is_enterprise_telemetry_enabled as exported_func
assert callable(exported_func)
@pytest.fixture
def mock_ops_trace_manager():
mock_module = MagicMock()
mock_trace_task_class = MagicMock()
mock_trace_task_class.return_value = MagicMock()
mock_module.TraceTask = mock_trace_task_class
mock_module.TraceQueueManager = MagicMock()
mock_trace_entity = MagicMock()
mock_trace_task_name = MagicMock()
mock_trace_task_name.return_value = "workflow"
mock_trace_entity.TraceTaskName = mock_trace_task_name
with (
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
):
yield mock_module, mock_trace_entity
class TestGatewayIntegrationTraceRouting:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_to_trace_manager(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"workflow_run_id": "run-abc"}
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_when_ee_disabled(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_dropped_when_ee_disabled(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_routed_when_ee_enabled(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationMetricRouting:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
def test_metric_case_routes_to_celery_task(
self,
gateway: TelemetryGateway,
) -> None:
from enterprise.telemetry.contracts import TelemetryEnvelope
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc", "name": "My App"}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-123"
assert envelope.payload["app_id"] == "app-abc"
def test_tool_execution_metric_routed(
self,
gateway: TelemetryGateway,
) -> None:
from enterprise.telemetry.contracts import TelemetryEnvelope
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"}
gateway.emit(TelemetryCase.TOOL_EXECUTION, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.TOOL_EXECUTION
def test_moderation_check_metric_routed(
self,
gateway: TelemetryGateway,
) -> None:
from enterprise.telemetry.contracts import TelemetryEnvelope
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}}
gateway.emit(TelemetryCase.MODERATION_CHECK, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.MODERATION_CHECK
class TestGatewayIntegrationCEEligibility:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_workflow_run_is_ce_eligible(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_message_run_is_ce_eligible(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"message_id": "msg-abc", "conversation_id": "conv-123"}
gateway.emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_node_execution_not_ce_eligible(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_draft_node_execution_not_ce_eligible(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_execution_data": {}}
gateway.emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_prompt_generation_not_ce_eligible(
self,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
) -> None:
with patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"operation_type": "generate", "instruction": "test"}
gateway.emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
class TestIsEnterpriseTelemetryEnabled:
def test_returns_false_when_exporter_import_fails(self) -> None:
with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
result = is_enterprise_telemetry_enabled()
assert result is False
def test_function_is_callable(self) -> None:
assert callable(is_enterprise_telemetry_enabled)

View File

@@ -0,0 +1,264 @@
"""Unit tests for telemetry gateway contracts."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
from enterprise.telemetry.gateway import CASE_ROUTING
class TestTelemetryCase:
"""Tests for TelemetryCase enum."""
def test_all_cases_defined(self) -> None:
"""Verify all 14 telemetry cases are defined."""
expected_cases = {
"WORKFLOW_RUN",
"NODE_EXECUTION",
"DRAFT_NODE_EXECUTION",
"MESSAGE_RUN",
"TOOL_EXECUTION",
"MODERATION_CHECK",
"SUGGESTED_QUESTION",
"DATASET_RETRIEVAL",
"GENERATE_NAME",
"PROMPT_GENERATION",
"APP_CREATED",
"APP_UPDATED",
"APP_DELETED",
"FEEDBACK_CREATED",
}
actual_cases = {case.name for case in TelemetryCase}
assert actual_cases == expected_cases
def test_case_values(self) -> None:
"""Verify case enum values are correct."""
assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run"
assert TelemetryCase.NODE_EXECUTION.value == "node_execution"
assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution"
assert TelemetryCase.MESSAGE_RUN.value == "message_run"
assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution"
assert TelemetryCase.MODERATION_CHECK.value == "moderation_check"
assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question"
assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval"
assert TelemetryCase.GENERATE_NAME.value == "generate_name"
assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation"
assert TelemetryCase.APP_CREATED.value == "app_created"
assert TelemetryCase.APP_UPDATED.value == "app_updated"
assert TelemetryCase.APP_DELETED.value == "app_deleted"
assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created"
class TestCaseRoute:
"""Tests for CaseRoute model."""
def test_valid_trace_route(self) -> None:
"""Verify valid trace route creation."""
route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_valid_metric_log_route(self) -> None:
"""Verify valid metric_log route creation."""
route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_invalid_signal_type(self) -> None:
"""Verify invalid signal_type is rejected."""
with pytest.raises(ValidationError):
CaseRoute(signal_type="invalid", ce_eligible=True)
class TestTelemetryEnvelope:
"""Tests for TelemetryEnvelope model."""
def test_valid_envelope_minimal(self) -> None:
"""Verify valid minimal envelope creation."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
assert envelope.case == TelemetryCase.WORKFLOW_RUN
assert envelope.tenant_id == "tenant-123"
assert envelope.event_id == "event-456"
assert envelope.payload == {"key": "value"}
assert envelope.payload_fallback is None
assert envelope.metadata is None
def test_valid_envelope_full(self) -> None:
"""Verify valid envelope with all fields."""
metadata = {"source": "api"}
fallback = b"fallback data"
envelope = TelemetryEnvelope(
case=TelemetryCase.MESSAGE_RUN,
tenant_id="tenant-789",
event_id="event-012",
payload={"message": "hello"},
payload_fallback=fallback,
metadata=metadata,
)
assert envelope.case == TelemetryCase.MESSAGE_RUN
assert envelope.tenant_id == "tenant-789"
assert envelope.event_id == "event-012"
assert envelope.payload == {"message": "hello"}
assert envelope.payload_fallback == fallback
assert envelope.metadata == metadata
def test_missing_required_case(self) -> None:
"""Verify missing case field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_tenant_id(self) -> None:
"""Verify missing tenant_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_event_id(self) -> None:
"""Verify missing event_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
payload={"key": "value"},
)
def test_missing_required_payload(self) -> None:
"""Verify missing payload field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
)
def test_payload_fallback_within_limit(self) -> None:
"""Verify payload_fallback within 64KB limit is accepted."""
fallback = b"x" * 65536
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
payload_fallback=fallback,
)
assert envelope.payload_fallback == fallback
def test_payload_fallback_exceeds_limit(self) -> None:
"""Verify payload_fallback exceeding 64KB is rejected."""
fallback = b"x" * 65537
with pytest.raises(ValidationError) as exc_info:
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
payload_fallback=fallback,
)
assert "64KB" in str(exc_info.value)
def test_payload_fallback_none(self) -> None:
"""Verify payload_fallback can be None."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
payload_fallback=None,
)
assert envelope.payload_fallback is None
class TestCaseRouting:
"""Tests for CASE_ROUTING table."""
def test_all_cases_routed(self) -> None:
"""Verify all 14 cases have routing entries."""
assert len(CASE_ROUTING) == 14
for case in TelemetryCase:
assert case in CASE_ROUTING
def test_trace_ce_eligible_cases(self) -> None:
"""Verify trace cases with CE eligibility."""
ce_eligible_trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
}
for case in ce_eligible_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_trace_enterprise_only_cases(self) -> None:
"""Verify trace cases that are enterprise-only."""
enterprise_only_trace_cases = {
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
}
for case in enterprise_only_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is False
def test_metric_log_cases(self) -> None:
"""Verify metric/log-only cases."""
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
}
for case in metric_log_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_routing_table_completeness(self) -> None:
"""Verify routing table covers all cases with correct types."""
trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
}
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
}
all_cases = trace_cases | metric_log_cases
assert len(all_cases) == 14
assert all_cases == set(TelemetryCase)
for case in trace_cases:
assert CASE_ROUTING[case].signal_type == SignalType.TRACE
for case in metric_log_cases:
assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG

View File

@@ -0,0 +1,134 @@
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry import event_handlers
from enterprise.telemetry.contracts import TelemetryCase
@pytest.fixture
def mock_exporter():
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock:
exporter = MagicMock()
mock.return_value = exporter
yield exporter
@pytest.fixture
def mock_task():
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry") as mock:
yield mock
def test_handle_app_created_calls_task(mock_exporter, mock_task):
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
sender.mode = "chat"
event_handlers._handle_app_created(sender)
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[0][0]
assert "app_created" in call_args
assert "tenant-456" in call_args
assert "app-123" in call_args
assert "chat" in call_args
def test_handle_app_created_no_exporter(mock_task):
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None):
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
event_handlers._handle_app_created(sender)
mock_task.delay.assert_not_called()
def test_handle_app_updated_calls_task(mock_exporter, mock_task):
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
event_handlers._handle_app_updated(sender)
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[0][0]
assert "app_updated" in call_args
assert "tenant-456" in call_args
assert "app-123" in call_args
def test_handle_app_deleted_calls_task(mock_exporter, mock_task):
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
event_handlers._handle_app_deleted(sender)
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[0][0]
assert "app_deleted" in call_args
assert "tenant-456" in call_args
assert "app-123" in call_args
def test_handle_feedback_created_calls_task(mock_exporter, mock_task):
sender = MagicMock()
sender.message_id = "msg-123"
sender.app_id = "app-456"
sender.conversation_id = "conv-789"
sender.from_end_user_id = "user-001"
sender.from_account_id = None
sender.rating = "like"
sender.from_source = "api"
sender.content = "Great response!"
event_handlers._handle_feedback_created(sender, tenant_id="tenant-456")
mock_task.delay.assert_called_once()
call_args = mock_task.delay.call_args[0][0]
assert "feedback_created" in call_args
assert "tenant-456" in call_args
assert "msg-123" in call_args
assert "app-456" in call_args
assert "conv-789" in call_args
assert "user-001" in call_args
assert "like" in call_args
assert "api" in call_args
assert "Great response!" in call_args
def test_handle_feedback_created_no_exporter(mock_task):
with patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter", return_value=None):
sender = MagicMock()
sender.message_id = "msg-123"
event_handlers._handle_feedback_created(sender, tenant_id="tenant-456")
mock_task.delay.assert_not_called()
def test_handlers_create_valid_envelopes(mock_exporter, mock_task):
import json
from enterprise.telemetry.contracts import TelemetryEnvelope
sender = MagicMock()
sender.id = "app-123"
sender.tenant_id = "tenant-456"
sender.mode = "chat"
event_handlers._handle_app_created(sender)
call_args = mock_task.delay.call_args[0][0]
envelope_dict = json.loads(call_args)
envelope = TelemetryEnvelope(**envelope_dict)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-456"
assert envelope.event_id
assert envelope.payload["app_id"] == "app-123"
assert envelope.payload["mode"] == "chat"

View File

@@ -0,0 +1,301 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from enterprise.telemetry.contracts import SignalType, TelemetryCase, TelemetryEnvelope
from enterprise.telemetry.gateway import (
CASE_ROUTING,
CASE_TO_TRACE_TASK,
PAYLOAD_SIZE_THRESHOLD_BYTES,
TelemetryGateway,
emit,
)
class TestCaseRoutingTable:
def test_all_cases_have_routing(self) -> None:
for case in TelemetryCase:
assert case in CASE_ROUTING, f"Missing routing for {case}"
def test_trace_cases(self) -> None:
trace_cases = [
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
]
for case in trace_cases:
assert CASE_ROUTING[case].signal_type is SignalType.TRACE, f"{case} should be trace"
def test_metric_log_cases(self) -> None:
metric_log_cases = [
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
]
for case in metric_log_cases:
assert CASE_ROUTING[case].signal_type is SignalType.METRIC_LOG, f"{case} should be metric_log"
def test_ce_eligible_cases(self) -> None:
ce_eligible_cases = [TelemetryCase.WORKFLOW_RUN, TelemetryCase.MESSAGE_RUN]
for case in ce_eligible_cases:
assert CASE_ROUTING[case].ce_eligible is True, f"{case} should be CE eligible"
def test_enterprise_only_cases(self) -> None:
enterprise_only_cases = [
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
]
for case in enterprise_only_cases:
assert CASE_ROUTING[case].ce_eligible is False, f"{case} should be enterprise-only"
def test_trace_cases_have_task_name_mapping(self) -> None:
trace_cases = [c for c in TelemetryCase if CASE_ROUTING[c].signal_type is SignalType.TRACE]
for case in trace_cases:
assert case in CASE_TO_TRACE_TASK, f"Missing TraceTaskName mapping for {case}"
@pytest.fixture
def mock_ops_trace_manager():
mock_module = MagicMock()
mock_trace_task_class = MagicMock()
mock_trace_task_class.return_value = MagicMock()
mock_module.TraceTask = mock_trace_task_class
mock_module.TraceQueueManager = MagicMock()
mock_trace_entity = MagicMock()
mock_trace_task_name = MagicMock()
mock_trace_task_name.return_value = "workflow"
mock_trace_entity.TraceTaskName = mock_trace_task_name
with (
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
):
yield mock_module, mock_trace_entity
class TestTelemetryGatewayTraceRouting:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
def test_trace_case_routes_to_trace_manager(
self,
_mock_ee_enabled: MagicMock,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"workflow_run_id": "run-abc"}
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
def test_ce_eligible_trace_enqueued_when_ee_disabled(
self,
_mock_ee_enabled: MagicMock,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
gateway.emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=False)
def test_enterprise_only_trace_dropped_when_ee_disabled(
self,
_mock_ee_enabled: MagicMock,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
def test_enterprise_only_trace_enqueued_when_ee_enabled(
self,
_mock_ee_enabled: MagicMock,
gateway: TelemetryGateway,
mock_trace_manager: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
gateway.emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestTelemetryGatewayMetricLogRouting:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_metric_case_routes_to_celery_task(
self,
mock_delay: MagicMock,
gateway: TelemetryGateway,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc", "name": "My App"}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-123"
assert envelope.payload["app_id"] == "app-abc"
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_envelope_has_unique_event_id(
self,
mock_delay: MagicMock,
gateway: TelemetryGateway,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc"}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
assert mock_delay.call_count == 2
envelope1 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[0][0][0])
envelope2 = TelemetryEnvelope.model_validate_json(mock_delay.call_args_list[1][0][0])
assert envelope1.event_id != envelope2.event_id
class TestTelemetryGatewayPayloadSizing:
@pytest.fixture
def gateway(self) -> TelemetryGateway:
return TelemetryGateway()
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_small_payload_inlined(
self,
mock_delay: MagicMock,
gateway: TelemetryGateway,
) -> None:
context = {"tenant_id": "tenant-123"}
payload = {"key": "small_value"}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == payload
assert envelope.metadata is None
@patch("enterprise.telemetry.gateway.storage")
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_large_payload_stored(
self,
mock_delay: MagicMock,
mock_storage: MagicMock,
gateway: TelemetryGateway,
) -> None:
context = {"tenant_id": "tenant-123"}
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
payload = {"key": large_value}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
mock_storage.save.assert_called_once()
storage_key = mock_storage.save.call_args[0][0]
assert storage_key.startswith("telemetry/tenant-123/")
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == {}
assert envelope.metadata is not None
assert envelope.metadata["payload_ref"] == storage_key
@patch("enterprise.telemetry.gateway.storage")
@patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay")
def test_large_payload_fallback_on_storage_error(
self,
mock_delay: MagicMock,
mock_storage: MagicMock,
gateway: TelemetryGateway,
) -> None:
mock_storage.save.side_effect = Exception("Storage failure")
context = {"tenant_id": "tenant-123"}
large_value = "x" * (PAYLOAD_SIZE_THRESHOLD_BYTES + 1000)
payload = {"key": large_value}
gateway.emit(TelemetryCase.APP_CREATED, context, payload)
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.payload == payload
assert envelope.metadata is None
class TestModuleLevelFunctions:
@patch("extensions.ext_enterprise_telemetry.get_gateway")
@patch("enterprise.telemetry.gateway._is_enterprise_telemetry_enabled", return_value=True)
def test_emit_function_uses_gateway(
self,
_mock_ee_enabled: MagicMock,
mock_get_gateway: MagicMock,
mock_ops_trace_manager: tuple[MagicMock, MagicMock],
) -> None:
mock_gateway = TelemetryGateway()
mock_get_gateway.return_value = mock_gateway
mock_trace_manager = MagicMock()
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
with patch.object(mock_gateway, "emit") as mock_emit:
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_emit.assert_called_once_with(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
class TestTraceTaskNameMapping:
def test_workflow_run_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.WORKFLOW_RUN] is TraceTaskName.WORKFLOW_TRACE
def test_message_run_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.MESSAGE_RUN] is TraceTaskName.MESSAGE_TRACE
def test_node_execution_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.NODE_EXECUTION] is TraceTaskName.NODE_EXECUTION_TRACE
def test_draft_node_execution_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.DRAFT_NODE_EXECUTION] is TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
def test_prompt_generation_mapping(self) -> None:
assert CASE_TO_TRACE_TASK[TelemetryCase.PROMPT_GENERATION] is TraceTaskName.PROMPT_GENERATION_TRACE

View File

@@ -0,0 +1,452 @@
"""Unit tests for EnterpriseMetricHandler."""
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
@pytest.fixture
def mock_redis():
with patch("enterprise.telemetry.metric_handler.redis_client") as mock:
yield mock
@pytest.fixture
def sample_envelope():
return TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-123",
payload={"app_id": "app-123", "name": "Test App"},
)
def test_dispatch_app_created(sample_envelope, mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_created") as mock_handler:
handler.handle(sample_envelope)
mock_handler.assert_called_once_with(sample_envelope)
def test_dispatch_app_updated(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_UPDATED,
tenant_id="test-tenant",
event_id="test-event-456",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_updated") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_app_deleted(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_DELETED,
tenant_id="test-tenant",
event_id="test-event-789",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_deleted") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_feedback_created(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="test-tenant",
event_id="test-event-abc",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_feedback_created") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_message_run(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.MESSAGE_RUN,
tenant_id="test-tenant",
event_id="test-event-msg",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_message_run") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_tool_execution(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.TOOL_EXECUTION,
tenant_id="test-tenant",
event_id="test-event-tool",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_tool_execution") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_moderation_check(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.MODERATION_CHECK,
tenant_id="test-tenant",
event_id="test-event-mod",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_moderation_check") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_suggested_question(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.SUGGESTED_QUESTION,
tenant_id="test-tenant",
event_id="test-event-sq",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_suggested_question") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_dataset_retrieval(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.DATASET_RETRIEVAL,
tenant_id="test-tenant",
event_id="test-event-ds",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_dataset_retrieval") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_generate_name(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.GENERATE_NAME,
tenant_id="test-tenant",
event_id="test-event-gn",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_generate_name") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_dispatch_prompt_generation(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.PROMPT_GENERATION,
tenant_id="test-tenant",
event_id="test-event-pg",
payload={},
)
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_prompt_generation") as mock_handler:
handler.handle(envelope)
mock_handler.assert_called_once_with(envelope)
def test_all_known_cases_have_handlers(mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
for case in TelemetryCase:
envelope = TelemetryEnvelope(
case=case,
tenant_id="test-tenant",
event_id=f"test-{case.value}",
payload={},
)
handler.handle(envelope)
def test_idempotency_duplicate(sample_envelope, mock_redis):
mock_redis.set.return_value = None
handler = EnterpriseMetricHandler()
with patch.object(handler, "_on_app_created") as mock_handler:
handler.handle(sample_envelope)
mock_handler.assert_not_called()
def test_idempotency_first_seen(sample_envelope, mock_redis):
mock_redis.set.return_value = True
handler = EnterpriseMetricHandler()
is_dup = handler._is_duplicate(sample_envelope)
assert is_dup is False
mock_redis.set.assert_called_once_with(
"telemetry:dedup:test-tenant:test-event-123",
b"1",
nx=True,
ex=3600,
)
def test_idempotency_redis_failure_fails_open(sample_envelope, mock_redis, caplog):
mock_redis.set.side_effect = Exception("Redis unavailable")
handler = EnterpriseMetricHandler()
is_dup = handler._is_duplicate(sample_envelope)
assert is_dup is False
assert "Redis unavailable for deduplication check" in caplog.text
def test_rehydration_uses_payload(sample_envelope):
handler = EnterpriseMetricHandler()
payload = handler._rehydrate(sample_envelope)
assert payload == {"app_id": "app-123", "name": "Test App"}
def test_rehydration_fallback():
import pickle
fallback_data = {"fallback": "data"}
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-fb",
payload={},
payload_fallback=pickle.dumps(fallback_data),
)
handler = EnterpriseMetricHandler()
payload = handler._rehydrate(envelope)
assert payload == fallback_data
def test_rehydration_emits_degraded_event_on_failure():
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-fail",
payload={},
payload_fallback=None,
)
handler = EnterpriseMetricHandler()
with patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit:
payload = handler._rehydrate(envelope)
assert payload == {}
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == "dify.telemetry.rehydration_failed"
assert call_args[1]["attributes"]["rehydration_failed"] is True
def test_on_app_created_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789", "mode": "chat"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_created(envelope)
mock_emit.assert_called_once_with(
event_name="dify.app.created",
attributes={
"dify.app.id": "app-789",
"dify.tenant_id": "tenant-123",
"dify.app.mode": "chat",
},
tenant_id="tenant-123",
)
mock_exporter.increment_counter.assert_called_once()
call_args = mock_exporter.increment_counter.call_args
assert call_args[0][1] == 1
assert call_args[0][2]["type"] == "app.created"
assert call_args[0][2]["tenant_id"] == "tenant-123"
def test_on_app_updated_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_UPDATED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_updated(envelope)
mock_emit.assert_called_once_with(
event_name="dify.app.updated",
attributes={
"dify.app.id": "app-789",
"dify.tenant_id": "tenant-123",
},
tenant_id="tenant-123",
)
mock_exporter.increment_counter.assert_called_once()
call_args = mock_exporter.increment_counter.call_args
assert call_args[0][2]["type"] == "app.updated"
def test_on_app_deleted_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_DELETED,
tenant_id="tenant-123",
event_id="event-456",
payload={"app_id": "app-789"},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_get_exporter.return_value = mock_exporter
handler._on_app_deleted(envelope)
mock_emit.assert_called_once_with(
event_name="dify.app.deleted",
attributes={
"dify.app.id": "app-789",
"dify.tenant_id": "tenant-123",
},
tenant_id="tenant-123",
)
mock_exporter.increment_counter.assert_called_once()
call_args = mock_exporter.increment_counter.call_args
assert call_args[0][2]["type"] == "app.deleted"
def test_on_feedback_created_emits_correct_event(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={
"message_id": "msg-001",
"app_id": "app-789",
"conversation_id": "conv-123",
"from_end_user_id": "user-456",
"from_account_id": None,
"rating": "like",
"from_source": "api",
"content": "Great!",
},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_exporter.include_content = True
mock_get_exporter.return_value = mock_exporter
handler._on_feedback_created(envelope)
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert call_args[1]["event_name"] == "dify.feedback.created"
assert call_args[1]["attributes"]["dify.message.id"] == "msg-001"
assert call_args[1]["attributes"]["dify.feedback.content"] == "Great!"
assert call_args[1]["tenant_id"] == "tenant-123"
assert call_args[1]["user_id"] == "user-456"
mock_exporter.increment_counter.assert_called_once()
counter_args = mock_exporter.increment_counter.call_args
assert counter_args[0][2]["app_id"] == "app-789"
assert counter_args[0][2]["rating"] == "like"
def test_on_feedback_created_without_content(mock_redis):
mock_redis.set.return_value = True
envelope = TelemetryEnvelope(
case=TelemetryCase.FEEDBACK_CREATED,
tenant_id="tenant-123",
event_id="event-456",
payload={
"message_id": "msg-001",
"app_id": "app-789",
"conversation_id": "conv-123",
"from_end_user_id": "user-456",
"from_account_id": None,
"rating": "like",
"from_source": "api",
"content": "Great!",
},
)
handler = EnterpriseMetricHandler()
with (
patch("extensions.ext_enterprise_telemetry.get_enterprise_exporter") as mock_get_exporter,
patch("enterprise.telemetry.telemetry_log.emit_metric_only_event") as mock_emit,
):
mock_exporter = MagicMock()
mock_exporter.include_content = False
mock_get_exporter.return_value = mock_exporter
handler._on_feedback_created(envelope)
mock_emit.assert_called_once()
call_args = mock_emit.call_args
assert "dify.feedback.content" not in call_args[1]["attributes"]

View File

@@ -0,0 +1,69 @@
"""Unit tests for enterprise telemetry Celery task."""
import json
from unittest.mock import MagicMock, patch
import pytest
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
@pytest.fixture
def sample_envelope_json():
envelope = TelemetryEnvelope(
case=TelemetryCase.APP_CREATED,
tenant_id="test-tenant",
event_id="test-event-123",
payload={"app_id": "app-123"},
)
return envelope.model_dump_json()
def test_process_enterprise_telemetry_success(sample_envelope_json):
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
mock_handler = MagicMock()
mock_handler_class.return_value = mock_handler
process_enterprise_telemetry(sample_envelope_json)
mock_handler.handle.assert_called_once()
call_args = mock_handler.handle.call_args[0][0]
assert isinstance(call_args, TelemetryEnvelope)
assert call_args.case == TelemetryCase.APP_CREATED
assert call_args.tenant_id == "test-tenant"
assert call_args.event_id == "test-event-123"
def test_process_enterprise_telemetry_invalid_json(caplog):
invalid_json = "not valid json"
process_enterprise_telemetry(invalid_json)
assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_handler_exception(sample_envelope_json, caplog):
with patch("tasks.enterprise_telemetry_task.EnterpriseMetricHandler") as mock_handler_class:
mock_handler = MagicMock()
mock_handler.handle.side_effect = Exception("Handler error")
mock_handler_class.return_value = mock_handler
process_enterprise_telemetry(sample_envelope_json)
assert "Failed to process enterprise telemetry envelope" in caplog.text
def test_process_enterprise_telemetry_validation_error(caplog):
invalid_envelope = json.dumps(
{
"case": "INVALID_CASE",
"tenant_id": "test-tenant",
"event_id": "test-event",
"payload": {},
}
)
process_enterprise_telemetry(invalid_envelope)
assert "Failed to process enterprise telemetry envelope" in caplog.text

View File

@@ -350,7 +350,7 @@ class TestDeleteWorkflowArchiveLogs:
mock_query.where.return_value = mock_delete_query
mock_db.session.query.return_value = mock_query
delete_func("log-1")
delete_func(mock_db.session, "log-1")
mock_db.session.query.assert_called_once_with(WorkflowArchiveLog)
mock_query.where.assert_called_once()

8
api/uv.lock generated
View File

@@ -1368,7 +1368,7 @@ wheels = [
[[package]]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
source = { virtual = "." }
dependencies = [
{ name = "aliyun-log-python-sdk" },
@@ -4433,15 +4433,15 @@ wheels = [
[[package]]
name = "pdfminer-six"
version = "20251230"
version = "20260107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "charset-normalizer" },
{ name = "cryptography" },
]
sdist = { url = "https://files.pythonhosted.org/packages/46/9a/d79d8fa6d47a0338846bb558b39b9963b8eb2dfedec61867c138c1b17eeb/pdfminer_six-20251230.tar.gz", hash = "sha256:e8f68a14c57e00c2d7276d26519ea64be1b48f91db1cdc776faa80528ca06c1e", size = 8511285, upload-time = "2025-12-30T15:49:13.104Z" }
sdist = { url = "https://files.pythonhosted.org/packages/34/a4/5cec1112009f0439a5ca6afa8ace321f0ab2f48da3255b7a1c8953014670/pdfminer_six-20260107.tar.gz", hash = "sha256:96bfd431e3577a55a0efd25676968ca4ce8fd5b53f14565f85716ff363889602", size = 8512094, upload-time = "2026-01-07T13:29:12.937Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/d7/b288ea32deb752a09aab73c75e1e7572ab2a2b56c3124a5d1eb24c62ceb3/pdfminer_six-20251230-py3-none-any.whl", hash = "sha256:9ff2e3466a7dfc6de6fd779478850b6b7c2d9e9405aa2a5869376a822771f485", size = 6591909, upload-time = "2025-12-30T15:49:10.76Z" },
{ url = "https://files.pythonhosted.org/packages/20/8b/28c4eaec9d6b036a52cb44720408f26b1a143ca9bce76cc19e8f5de00ab4/pdfminer_six-20260107-py3-none-any.whl", hash = "sha256:366585ba97e80dffa8f00cebe303d2f381884d8637af4ce422f1df3ef38111a9", size = 6592252, upload-time = "2026-01-07T13:29:10.742Z" },
]
[[package]]

View File

@@ -21,7 +21,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -63,7 +63,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -102,7 +102,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -132,7 +132,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.12.0
image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -707,7 +707,7 @@ services:
# API service
api:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -749,7 +749,7 @@ services:
# worker service
# The Celery worker for processing all queues (dataset, workflow, mail, etc.)
worker:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -788,7 +788,7 @@ services:
# worker_beat service
# Celery beat for scheduling periodic tasks.
worker_beat:
image: langgenius/dify-api:1.12.0
image: langgenius/dify-api:1.12.1
restart: always
environment:
# Use the shared environment variables.
@@ -818,7 +818,7 @@ services:
# Frontend web application.
web:
image: langgenius/dify-web:1.12.0
image: langgenius/dify-web:1.12.1
restart: always
environment:
CONSOLE_API_URL: ${CONSOLE_API_URL:-}

View File

@@ -154,7 +154,7 @@ export const GeneralChunkingOptions: FC<GeneralChunkingOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

View File

@@ -12,6 +12,7 @@ import Divider from '@/app/components/base/divider'
import { ParentChildChunk } from '@/app/components/base/icons/src/vender/knowledge'
import RadioCard from '@/app/components/base/radio-card'
import SummaryIndexSetting from '@/app/components/datasets/settings/summary-index-setting'
import { IS_CE_EDITION } from '@/config'
import { ChunkingMode } from '@/models/datasets'
import FileList from '../../assets/file-list-3-fill.svg'
import Note from '../../assets/note-mod.svg'
@@ -191,7 +192,7 @@ export const ParentChildOptions: FC<ParentChildOptionsProps> = ({
</div>
))}
{
showSummaryIndexSetting && (
showSummaryIndexSetting && IS_CE_EDITION && (
<div className="mt-3">
<SummaryIndexSetting
entry="create-document"

View File

@@ -26,6 +26,7 @@ import CustomPopover from '@/app/components/base/popover'
import Switch from '@/app/components/base/switch'
import { ToastContext } from '@/app/components/base/toast'
import Tooltip from '@/app/components/base/tooltip'
import { IS_CE_EDITION } from '@/config'
import { DataSourceType, DocumentActionType } from '@/models/datasets'
import {
useDocumentArchive,
@@ -263,10 +264,14 @@ const Operations = ({
<span className={s.actionName}>{t('list.action.sync', { ns: 'datasetDocuments' })}</span>
</div>
)}
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
{
IS_CE_EDITION && (
<div className={s.actionItem} onClick={() => onOperate('summary')}>
<SearchLinesSparkle className="h-4 w-4 text-text-tertiary" />
<span className={s.actionName}>{t('list.action.summary', { ns: 'datasetDocuments' })}</span>
</div>
)
}
<Divider className="my-1" />
</>
)}

View File

@@ -7,6 +7,7 @@ import Button from '@/app/components/base/button'
import Confirm from '@/app/components/base/confirm'
import Divider from '@/app/components/base/divider'
import { SearchLinesSparkle } from '@/app/components/base/icons/src/vender/knowledge'
import { IS_CE_EDITION } from '@/config'
import { cn } from '@/utils/classnames'
const i18nPrefix = 'batchAction'
@@ -87,7 +88,7 @@ const BatchAction: FC<IBatchActionProps> = ({
<span className="px-0.5">{t('metadata.metadata', { ns: 'dataset' })}</span>
</Button>
)}
{onBatchSummary && (
{onBatchSummary && IS_CE_EDITION && (
<Button
variant="ghost"
className="gap-x-0.5 px-3"

View File

@@ -21,6 +21,7 @@ import RetrievalMethodConfig from '@/app/components/datasets/common/retrieval-me
import { ModelTypeEnum } from '@/app/components/header/account-setting/model-provider-page/declarations'
import { useModelList } from '@/app/components/header/account-setting/model-provider-page/hooks'
import ModelSelector from '@/app/components/header/account-setting/model-provider-page/model-selector'
import { IS_CE_EDITION } from '@/config'
import { useSelector as useAppContextWithSelector } from '@/context/app-context'
import { useDatasetDetailContextWithSelector } from '@/context/dataset-detail'
import { useDocLink } from '@/context/i18n'
@@ -359,7 +360,7 @@ const Form = () => {
{
indexMethod === IndexingType.QUALIFIED
&& [ChunkingMode.text, ChunkingMode.parentChild].includes(currentDataset?.doc_form as ChunkingMode)
&& (
&& IS_CE_EDITION && (
<>
<Divider
type="horizontal"

View File

@@ -104,7 +104,7 @@ const MembersPage = () => {
<UpgradeBtn className="mr-2" loc="member-invite" />
)}
<div className="shrink-0">
<InviteButton disabled={!isCurrentWorkspaceManager || isMemberFull} onClick={() => setInviteModalVisible(true)} />
{isCurrentWorkspaceManager && <InviteButton disabled={isMemberFull} onClick={() => setInviteModalVisible(true)} />}
</div>
</div>
<div className="overflow-visible lg:overflow-visible">

View File

@@ -18,6 +18,7 @@ import {
Group,
} from '@/app/components/workflow/nodes/_base/components/layout'
import VarReferencePicker from '@/app/components/workflow/nodes/_base/components/variable/var-reference-picker'
import { IS_CE_EDITION } from '@/config'
import Split from '../_base/components/split'
import ChunkStructure from './components/chunk-structure'
import EmbeddingModel from './components/embedding-model'
@@ -172,7 +173,7 @@ const Panel: FC<NodePanelProps<KnowledgeBaseNodeType>> = ({
{
data.indexing_technique === IndexMethodEnum.QUALIFIED
&& [ChunkStructureEnum.general, ChunkStructureEnum.parent_child].includes(data.chunk_structure)
&& (
&& IS_CE_EDITION && (
<>
<SummaryIndexSetting
summaryIndexSetting={data.summary_index_setting}

View File

@@ -1,7 +1,7 @@
{
"name": "dify-web",
"type": "module",
"version": "1.12.0",
"version": "1.12.1",
"private": true,
"packageManager": "pnpm@10.27.0+sha512.72d699da16b1179c14ba9e64dc71c9a40988cbdc65c264cb0e489db7de917f20dcf4d64d8723625f2969ba52d4b7e2a1170682d9ac2a5dcaeaab732b7e16f04a",
"imports": {

2205
web/pnpm-lock.yaml generated

File diff suppressed because it is too large Load Diff

View File

@@ -9,7 +9,6 @@ import type {
} from '@/types/workflow'
import { get, post } from './base'
import { getFlowPrefix } from './utils'
import { sanitizeWorkflowDraftPayload } from './workflow-payload'
export const fetchWorkflowDraft = (url: string) => {
return get(url, {}, { silent: true }) as Promise<FetchWorkflowDraftResponse>
@@ -19,8 +18,7 @@ export const syncWorkflowDraft = ({ url, params }: {
url: string
params: Pick<FetchWorkflowDraftResponse, 'graph' | 'features' | 'environment_variables' | 'conversation_variables'>
}) => {
const sanitized = sanitizeWorkflowDraftPayload(params)
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: sanitized }, { silent: true })
return post<CommonResponse & { updated_at: number, hash: string }>(url, { body: params }, { silent: true })
}
export const fetchNodesDefaultConfigs = (url: string) => {