Compare commits

...

81 Commits

Author SHA1 Message Date
GareArc
a677009b23 fix(api): re-add reg(ModelConfig) after merge from release branch 2026-03-04 17:33:29 -08:00
GareArc
43e0565985 Merge remote-tracking branch 'origin/release/e-1.12.1' into 1.12.1-otel-ee 2026-03-04 17:33:05 -08:00
GareArc
ee13650e3d fix(api): restore missing reg(ModelConfig) from 1.12.1 refactor 2026-03-04 17:31:19 -08:00
GareArc
7ef139cadd Squash merge 1.12.1-otel-ee into release/e-1.12.1 2026-03-04 16:59:37 -08:00
GareArc
9fa8f6235e Merge branch 'release/e-1.12.1' into 1.12.1-otel-ee 2026-03-04 16:59:21 -08:00
L1nSn0w
bf5a327156 fix(api): ensure enterprise workspace join occurs on account registration failure
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-04 14:56:21 +08:00
L1nSn0w
d94af41f07 fix(api): ensure default workspace join occurs even if personal workspace creation fails 2026-03-04 14:56:21 +08:00
GareArc
6536489195 fix(telemetry): restore TRACE_TASK_TO_CASE lookup broken by CE safety refactor
The CE safety commit (8a3485454a) converted module-level dicts to lazy
functions but forgot to update __init__.py, which still imported the
now-deleted TRACE_TASK_TO_CASE constant causing an ImportError at startup.

Add get_trace_task_to_case() to gateway.py as a lazy public wrapper
(inverse of _get_case_to_trace_task) and update __init__.py to call it.
2026-03-02 19:59:20 -08:00
GareArc
8a3485454a fix(telemetry): ensure CE safety for enterprise-only imports and DB lookups
- Move enqueue_draft_node_execution_trace import inside call site in workflow_service.py
- Make gateway.py enterprise type imports lazy (routing dicts built on first call)
- Restore typed ModelConfig in llm_generator method signatures (revert dict regression)
- Fix generate_structured_output using wrong key model_parameters -> completion_params
- Replace unsafe cast(str, msg.content) with get_text_content() across llm_generator
- Remove duplicated payload classes from generator.py, import from core.llm_generator.entities
- Gate _lookup_app_and_workspace_names and credential lookups in ops_trace_manager behind is_enterprise_telemetry_enabled()
2026-03-02 18:45:33 -08:00
GareArc
8d8552cbb9 Merge branch 'fix/otel-upgrade-e-1.12.1' into release/e-1.12.1
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-03-02 17:21:39 -08:00
GareArc
d6de27a25a feat(telemetry): promote gen_ai scalar fields from log-only to span attributes
Move gen_ai.usage.*, gen_ai.request.model, gen_ai.provider.name, and
gen_ai.user.id from companion-log-only to span attributes on workflow
and node execution spans.

These are small scalars with no size risk. Having them on spans enables
filtering and grouping in trace UIs (Tempo, Jaeger, Datadog) without
requiring a cross-signal join to companion logs.

Data dictionary updated: span tables gain the new fields; companion log
'additional attributes' tables trimmed to only list fields not already
covered by 'All span attributes'.
2026-03-02 15:55:10 -08:00
GareArc
fe741140d5 fix(telemetry): fix zero-value message and workflow duration histograms
Workflow RT: replace float(info.workflow_run_elapsed_time) with
(end_time - start_time).total_seconds() using workflow_run.created_at and
workflow_run.finished_at. The elapsed_time DB field defaults to 0 and can
be stale if the workflow_storage Celery task has not committed yet when the
trace fires. Wall-clock timestamps are more reliable; elapsed_time is kept
as fallback.

Message RT: change end_time from created_at + provider_response_latency to
message.updated_at when updated_at > created_at. The pipeline explicitly
sets message.updated_at = naive_utc_now() at the moment the LLM response
is complete, making it the canonical response-complete timestamp.
Falls back to the latency-based calculation for error/aborted messages.
2026-03-02 04:14:57 -08:00
GareArc
9b5b355a4e fix(telemetry): gate ObservabilityLayer content attrs behind ENTERPRISE_INCLUDE_CONTENT
Add should_include_content() helper to extensions/otel/parser/base.py that
returns True in CE (no behaviour change) and respects ENTERPRISE_INCLUDE_CONTENT
in EE. Gate all content-bearing span attributes in LLM, retrieval, tool, and
default node parsers so that gen_ai.completion, gen_ai.prompt, retrieval.document,
tool call arguments/results, and node input/output values are suppressed when
ENTERPRISE_ENABLED=True and ENTERPRISE_INCLUDE_CONTENT=False.
2026-03-02 04:04:26 -08:00
GareArc
3364003f90 fix(telemetry): add credential_name lookup with async-safe fallback 2026-03-02 02:27:31 -08:00
GareArc
6df00c83ae fix(telemetry): populate LLM credential info in node execution traces
- Add _lookup_llm_credential_info() to query Provider/ProviderModel tables
- Lookup LLM credentials when tool credential_id is null
- Fall back to provider-level credential if no model-specific credential
2026-03-02 01:47:39 -08:00
GareArc
a2a5b02a53 docs(telemetry): add token consumption query patterns to data dictionary
Add token hierarchy diagram, common PromQL queries (totals, drill-down,
rates), and app name lookup via trace query.
2026-03-02 01:07:18 -08:00
GareArc
1fcb05432d fix(telemetry): populate missing fields in node execution trace
- Extract model_provider/model_name from process_data (LLM nodes store
  model info there, not in execution_metadata)
- Add invoke_from to node execution trace metadata dict
- Add credential_id to node execution trace metadata dict
- Add conversation_id to metadata after message_id lookup
- Add tool_name to tool_info dict in tool node
2026-03-02 01:07:10 -08:00
L1nSn0w
58524fd7fd feat(enterprise): auto-join newly registered accounts to the default workspace (#32308)
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Co-authored-by: Yunlu Wen <yunlu.wen@dify.ai>
2026-03-02 16:38:43 +08:00
GareArc
9d4b2715e8 fix(celery): register enterprise_telemetry_task in worker imports
Fixes Celery worker error where process_enterprise_telemetry task
was unregistered despite being dispatched from the app.

Added conditional import when ENTERPRISE_TELEMETRY_ENABLED=true
to ensure the task is available in the worker process.

Resolves: KeyError 'tasks.enterprise_telemetry_task.process_enterprise_telemetry'
2026-03-01 22:27:44 -08:00
GareArc
2d7bffcc11 fix: upgrade OpenTelemetry packages from 0.48b0 to 0.49b0
Fixes "Failed to detach context" error in production by upgrading to OTEL 0.49b0,
which includes None token guards in Celery instrumentor (PR opentelemetry-python-contrib#2927).

Package Updates:
- OTEL instrumentation: 0.48b0 → 0.49b0
- OTEL SDK/API: 1.27.0 → 1.28.0
- protobuf: 4.25.8 → 5.29.6 (required by opentelemetry-proto 1.28.0)
- Google Cloud packages upgraded for protobuf 5.x compatibility:
  - google-api-core: 2.18.0 → 2.19.1+
  - google-auth: 2.29.0 → 2.47.0+
  - google-cloud-aiplatform: 1.49.0 → 1.123.0+
  - googleapis-common-protos: 1.63.0 → 1.65.0+
  - google-cloud-storage: 2.16.0 → 3.0.0+
- httpx: 0.27.0 → 0.28.0 (required by google-genai 1.37+)

Also removed duplicate opentelemetry-instrumentation-httpx entry in pyproject.toml.
2026-03-01 21:47:51 -08:00
GareArc
83f5850d0a refactor(telemetry): add resolved_parent_context property and fix edge cases
- Add resolved_parent_context property to BaseTraceInfo for reusable parent context extraction
- Refactor enterprise_trace.py to use property instead of duplicated dict plucking (~19 lines eliminated)
- Fix UUID validation in exporter.py with specific error logging for invalid trace correlation IDs
- Add error isolation in event_handlers.py to prevent telemetry failures from breaking user operations
- Replace pickle-based payload_fallback with JSON storage rehydration for security
- Update TelemetryEnvelope to use Pydantic v2 ConfigDict with extra='forbid'
- Update tests to reflect contract changes and new error handling behavior
2026-03-01 19:33:59 -08:00
yunlu.wen
7a92c1764f fix token label 2026-03-02 10:10:01 +08:00
GareArc
9952a17fed fix(telemetry): use URL scheme instead of API key for gRPC TLS detection
- Change insecure parameter from API key-based to URL scheme-based detection
- https:// endpoints now correctly use TLS (insecure=False)
- All other endpoints (http://, no scheme) use insecure=True
- Update tests to reflect URL scheme-based logic
- Remove incorrect documentation claiming API key controls TLS
2026-03-01 02:24:25 -08:00
GareArc
36ff9b447d Merge origin/release/e-1.12.1 into 1.12.1-otel-ee
Sync enterprise 1.12.1 changes:
- feat: implement heartbeat mechanism for database migration lock
- refactor: replace AutoRenewRedisLock with DbMigrationAutoRenewLock
- fix: improve logging for database migration lock release
- fix: make flask upgrade-db fail on error
- fix: include sso_verified in access_mode validation
- fix: inherit web app permission from original app
- fix: make e-1.12.1 enterprise migrations database-agnostic
- fix: get_message_event_type return wrong message type
- refactor: document_indexing_sync_task split db session
- fix: trigger output schema miss
- test: remove unrelated enterprise service test

Conflict resolution:
- Combined OTEL telemetry imports with tool signature import in easy_ui_based_generate_task_pipeline.py
2026-03-01 00:18:46 -08:00
GareArc
ff877ee39c fix(telemetry): add resolved_trace_id property to eliminate trace_id inconsistencies
Add computed property to BaseTraceInfo that provides intelligent fallback:
1. External trace_id (from X-Trace-Id header)
2. workflow_run_id (for workflow-related traces)
3. message_id (as final fallback)

This ensures attribute dify.trace_id always matches log-level trace_id,
eliminating inconsistencies where attribute was null but log-level had value.

Changes:
- Add resolved_trace_id property to BaseTraceInfo (trace_entity.py)
- Replace 4 direct trace_id attribute assignments with resolved_trace_id
- Add trace_id_source parameter to 5 emit_metric_only_event calls

Fixes trace_id inconsistency found in MESSAGE_RUN, TOOL_EXECUTION,
MODERATION_CHECK, SUGGESTED_QUESTION_GENERATION, GENERATE_NAME_EXECUTION,
DATASET_RETRIEVAL, and PROMPT_GENERATION_EXECUTION events.

All 78 telemetry tests passing.
2026-02-28 20:32:15 -08:00
GareArc
abcf14a571 refactor(telemetry): move gateway to core as stateless module-level functions
Move routing table, emit(), and is_enterprise_telemetry_enabled() from
enterprise/telemetry/gateway.py into core/telemetry/gateway.py so both
CE and EE share one code path. The ce_eligible flag in CASE_ROUTING
controls which events flow in CE — flipping it is the only change needed
to enable an event in community edition.

- Delete enterprise/telemetry/gateway.py (class-based singleton)
- Create core/telemetry/gateway.py (stateless functions, no shared state)
- Simplify core/telemetry/__init__.py to thin facade over gateway
- Remove TelemetryGateway class and get_gateway() from ext_enterprise_telemetry
- Single-source is_enterprise_telemetry_enabled in core.telemetry.gateway
- Fix pre-existing test bugs (missing dify.event.id in metric handler tests)
- Update all imports and mock paths across 7 test files
2026-02-28 19:27:24 -08:00
GareArc
5e57f73598 feat(telemetry): add model provider and name tags to all trace metrics
Add comprehensive model tracking across all OTEL metrics and logs:
- Node execution metrics now include model_name for LLM operations
- Suggested question metrics include model_provider and model_name
- Dataset retrieval captures both embedding and rerank model info
- Updated DATA_DICTIONARY.md with complete metric label documentation

This enables granular cost tracking, performance analysis, and usage monitoring per model across all operation types.
2026-02-28 00:06:44 -08:00
GareArc
62592be60b docs(enterprise): split telemetry docs into README and data dictionary
Separate background/configuration instructions from the data dictionary:
- README.md: Overview, configuration, correlation model, content gating
- DATA_DICTIONARY.md: Pure reference format with signals and attributes

The data dictionary is now concise (465 lines vs 911) and focuses on
attribute types and relationships without verbose explanations.
2026-02-27 12:32:48 -08:00
L1nSn0w
5025e29220 test: remove unrelated enterprise service test
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-14 16:34:49 +08:00
L1nSn0w
3cdc9c119e refactor(api): enhance DbMigrationAutoRenewLock acquisition logic
- Added a check to prevent double acquisition of the DB migration lock, raising an error if an attempt is made to acquire it while already held.
- Implemented logic to reuse the lock object if it has already been created, improving efficiency and clarity in lock management.
- Reset the lock object to None upon release to ensure proper state management.

(cherry picked from commit d4b102d3c8a473c4fd6409dba7c198289bb5f921)
2026-02-14 16:28:38 +08:00
L1nSn0w
18ba367b11 refactor(api): improve DbMigrationAutoRenewLock configuration and logging
- Introduced constants for minimum and maximum join timeout values, enhancing clarity and maintainability.
- Updated the renewal interval calculation to use defined constants for better readability.
- Improved logging messages to include context information, making it easier to trace issues during lock operations.

(cherry picked from commit 1471b77bf5156a95417bde148753702d44221929)
2026-02-14 16:28:38 +08:00
autofix-ci[bot]
d0bd74fccb [autofix.ci] apply automated fixes
(cherry picked from commit 907e63cdc57f8006017837a74c2da2fbe274dcfb)
2026-02-14 16:28:38 +08:00
L1nSn0w
5ccbc00eb9 refactor(api): replace AutoRenewRedisLock with DbMigrationAutoRenewLock
- Updated the database migration locking mechanism to use DbMigrationAutoRenewLock for improved clarity and functionality.
- Removed the AutoRenewRedisLock implementation and its associated tests.
- Adjusted integration and unit tests to reflect the new locking class and its usage in the upgrade_db command.

(cherry picked from commit c812ad9ff26bed3eb59862bd7a5179b7ee83f11f)
2026-02-14 16:28:38 +08:00
L1nSn0w
94603b5408 refactor(api): replace heartbeat mechanism with AutoRenewRedisLock for database migration
- Removed the manual heartbeat function for renewing the Redis lock during database migrations.
- Integrated AutoRenewRedisLock to handle lock renewal automatically, simplifying the upgrade_db command.
- Updated unit tests to reflect changes in lock handling and error management during migrations.

(cherry picked from commit 8814256eb5fa20b29e554264f3b659b027bc4c9a)
2026-02-14 16:28:38 +08:00
L1nSn0w
8d4bd5636b refactor(tests): replace hardcoded wait time with constant for clarity
- Introduced HEARTBEAT_WAIT_TIMEOUT_SECONDS constant to improve readability and maintainability of test code.
- Updated test assertions to use the new constant instead of a hardcoded value.

(cherry picked from commit 0d53743d83b03ae0e68fad143711ffa5f6354093)
2026-02-14 16:28:38 +08:00
autofix-ci[bot]
ee0c4a8852 [autofix.ci] apply automated fixes
(cherry picked from commit 326cffa553ffac1bcd39a051c899c35b0ebe997d)
2026-02-14 16:28:38 +08:00
L1nSn0w
6032c598b0 fix(api): improve logging for database migration lock release
- Added a migration_succeeded flag to track the success of database migrations.
- Enhanced logging messages to indicate the status of the migration when releasing the lock, providing clearer context for potential issues.

(cherry picked from commit e74be0392995d16d288eed2175c51148c9e5b9c0)
2026-02-14 16:28:38 +08:00
L1nSn0w
afdd5b6c86 feat(api): implement heartbeat mechanism for database migration lock
- Added a heartbeat function to renew the Redis lock during database migrations, preventing long blockages from crashed processes.
- Updated the upgrade_db command to utilize the new locking mechanism with a configurable TTL.
- Removed the deprecated MIGRATION_LOCK_TTL from DeploymentConfig and related files.
- Enhanced unit tests to cover the new lock renewal behavior and error handling during migrations.

(cherry picked from commit a3331c622435f9f215b95f6b0261f43ae56a9d9c)
2026-02-14 16:28:38 +08:00
L1nSn0w
9acdfbde2f feat(api): enhance database migration locking mechanism and configuration
- Introduced a configurable Redis lock TTL for database migrations in DeploymentConfig.
- Updated the upgrade_db command to handle lock release errors gracefully.
- Added documentation for the new MIGRATION_LOCK_TTL environment variable in the .env.example file and docker-compose.yaml.

(cherry picked from commit 4a05fb120622908bc109a3715686706aab3d3b59)
2026-02-14 16:28:38 +08:00
longbingljw
1977e68b2d fix: make flask upgrade-db fail on error (#32024)
(cherry picked from commit d9530f7bb7)
2026-02-14 16:28:38 +08:00
Xiyuan Chen
e9a7e8f77f fix: include sso_verified in access_mode validation (#32325) 2026-02-13 23:40:37 -08:00
Xiyuan Chen
9e2b28c950 fix(app-copy): inherit web app permission from original app (#32322) 2026-02-13 22:33:51 -08:00
L1nSn0w
affd07ae94 fix: make e-1.12.1 enterprise migrations database-agnostic for MySQL/TiDB (#32267)
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
Co-authored-by: Cursor <cursoragent@cursor.com>
2026-02-12 15:45:24 +08:00
NFish
111c76b71f Merge remote-tracking branch 'origin/hotfix/1.12.1-fix.6' into release/e-1.12.1 2026-02-12 13:26:12 +08:00
GareArc
262b7d4d08 docs(enterprise): add telemetry data dictionary for OTEL signals
- Comprehensive reference for all enterprise telemetry signals
- Documents 3 span types, 10 counters, 6 histograms, 13 log events
- Includes trace correlation model with ASCII diagrams
- Configuration reference for all 8 ENTERPRISE_* variables
- Per-emission-site label tables for metrics
- Full JSON schemas for structured log events
- Content gating behavior and token double-counting warnings
2026-02-10 19:51:14 -08:00
wangxiaolei
793d22754e fix: fix get_message_event_type return wrong message type (#32019)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-11 11:00:40 +08:00
GareArc
b5dbabf5d0 feat(telemetry): add missing ID fields for name attributes
- Add dify.credential.id to node execution events
- Add dify.event.id to all telemetry events (APP_CREATED, APP_UPDATED, APP_DELETED, FEEDBACK_CREATED)

This ensures all .name fields have corresponding .id fields for reliable aggregation and deduplication.
2026-02-10 00:09:41 -08:00
GareArc
aa34ec0d25 test(enterprise-telemetry): add unit tests for OTEL bearer auth and insecure flag 2026-02-09 01:44:21 -08:00
GareArc
ffa8aedc48 feat(enterprise-telemetry): wire bearer token auth and configurable insecure flag into OTEL exporter 2026-02-09 01:44:21 -08:00
GareArc
f78b0f1f36 feat(enterprise-telemetry): add ENTERPRISE_OTLP_API_KEY config field 2026-02-09 01:44:21 -08:00
wangxiaolei
b62965034e refactor: document_indexing_sync_task split db session (#32129)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 17:16:17 +08:00
wangxiaolei
016d72a8c6 fix: fix trigger output schema miss (#32116) 2026-02-09 17:16:08 +08:00
NFish
08b8eff933 Merge remote-tracking branch 'origin/hotfix/1.12.1-fix.4' into release/e-1.12.1
Some checks failed
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/amd64, build-api-amd64) (push) Has been cancelled
Build and Push API & Web / build (api, DIFY_API_IMAGE_NAME, linux/arm64, build-api-arm64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/amd64, build-web-amd64) (push) Has been cancelled
Build and Push API & Web / build (web, DIFY_WEB_IMAGE_NAME, linux/arm64, build-web-arm64) (push) Has been cancelled
Build and Push API & Web / create-manifest (api, DIFY_API_IMAGE_NAME, merge-api-images) (push) Has been cancelled
Build and Push API & Web / create-manifest (web, DIFY_WEB_IMAGE_NAME, merge-web-images) (push) Has been cancelled
2026-02-09 15:54:32 +08:00
NFish
579cdea820 fix: include app id in automatic generation requests (#32138) 2026-02-09 15:52:22 +08:00
wangxiaolei
125f7e3ab4 refactor: document_indexing_update_task split database session (#32105)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-09 10:51:45 +08:00
wangxiaolei
400ed2fd72 refactor: partition Celery task sessions into smaller, discrete execu… (#32085)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-08 21:05:03 +08:00
GareArc
1b3a21e6f8 feat(telemetry): unify token metric label structure with Pydantic enforcement
- Add TokenMetricLabels BaseModel to enforce consistent label structure
- All dify.token.* metrics now use identical 6-label structure:
  * tenant_id, app_id, operation_type, model_provider, model_name, node_type
- Pydantic validation ensures runtime enforcement (extra='forbid', frozen=True)
- Enables filtering by operation_type to avoid double-counting:
  * workflow: aggregated workflow-level tokens
  * node_execution: individual node-level tokens
  * message: direct message tokens
  * rule_generate/code_generate: prompt generation tokens

Previously, inconsistent label cardinality made aggregation impossible:
- WORKFLOW: 3 labels
- NODE_EXECUTION: 6 labels
- MESSAGE: 5 labels
- PROMPT_GENERATION: 5 labels

Now all use the same 6-label structure for consistent querying.
2026-02-06 03:10:20 -08:00
GareArc
11c74d741a feat: add dedicated app event counters and convert event names to StrEnum
- Add APP_CREATED, APP_UPDATED, APP_DELETED counters to EnterpriseTelemetryCounter
- Create EnterpriseTelemetryEvent StrEnum for type-safe event names
- Update metric_handler to use new app-specific counters with labels (tenant_id, app_id, mode)
- Convert all event_name strings to EnterpriseTelemetryEvent enum values
- Update exporter to create OTEL meters for new app counters (dify.app.created.total, etc.)
- Update tests to verify new counter behavior and enum usage
2026-02-06 02:38:19 -08:00
GareArc
ea9081f22d feat(telemetry): add operation_type labels for token metrics
Ultraworked with [Sisyphus](https://github.com/code-yeongyu/oh-my-opencode)

Co-authored-by: Sisyphus <clio-agent@sisyphuslabs.ai>
2026-02-06 01:06:07 -08:00
QuantumGhost
840a8f3fc2 perf: use batch delete method instead of single delete (#32036)
Co-authored-by: fatelei <fatelei@gmail.com>
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
Co-authored-by: FFXN <lizy@dify.ai>
2026-02-06 15:13:17 +08:00
GareArc
ac8e96bd9d docs(telemetry): clarify enterprise_telemetry queue is EE-only 2026-02-05 23:10:37 -08:00
GareArc
91a6fe25d1 feat(telemetry): add enterprise OTEL telemetry with gateway, traces, metrics, and logs 2026-02-05 23:10:30 -08:00
wangxiaolei
b4a5296fd1 fix: fix tool type is miss (#32042) 2026-02-06 14:38:54 +08:00
Xiyuan Chen
d7c3ae50dc Update api/services/tools/builtin_tools_manage_service.py
Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
2026-02-06 13:37:37 +08:00
NFish
b921711e9e fix: hide invite button if current user is not workspace manager (#31742) 2026-02-06 13:37:37 +08:00
yunlu.wen
fb38ad84e1 chore: upgrade deps, see pull #30976 2026-02-06 13:37:33 +08:00
Yunlu Wen
91c854b5be chore: sync enterprise release (#31626)
Co-authored-by: zhsama <torvalds@linux.do>
2026-02-06 13:35:28 +08:00
NFish
d35b231941 fix: enterprise CVE 2026 23864 (#31599) 2026-02-06 13:35:22 +08:00
GareArc
849b4b8c40 fix: add TYPE_CHECKING import for Account type annotation 2026-02-06 13:32:20 +08:00
GareArc
990e8feee8 security: fix IDOR and privilege escalation in set_default_provider
- Add tenant_id verification to prevent IDOR attacks
- Add admin check for enterprise tenant-wide default changes
- Preserve non-enterprise behavior (users can set own defaults)
2026-02-06 13:32:18 +08:00
GareArc
53641019b1 fix: remove user_id filter when clearing default provider (enterprise only)
When setting a new default credential in enterprise mode, the code was
only clearing is_default for credentials matching the current user_id.
This caused issues when:
1. Enterprise credential A (synced with system user_id) was default
2. User sets local credential B as default
3. A still had is_default=true (different user_id)
4. Both A and B were considered defaults

The fix removes user_id from the filter only for enterprise deployments,
since enterprise credentials may have different user_id than local ones.
Non-enterprise behavior is unchanged to avoid breaking existing setups.

Fixes EE-1511
2026-02-06 13:31:50 +08:00
GareArc
d1f10ff301 feat: add redis mq for account deletion cleanup 2026-02-06 13:31:50 +08:00
Xiyuan Chen
c8027e168b feat: implement workspace permission checks for member invitations an… (#31202) 2026-02-06 13:31:46 +08:00
NFish
aae3f76999 feat: ee workspace permission control (#30841) 2026-02-06 13:31:26 +08:00
NFish
2860c72b03 feat: ee workspace permission control (#30841) 2026-02-06 13:13:06 +08:00
wangxiaolei
fcb53383df fix: fix agent node tool type is not right (#32008)
Infer real tool type via querying relevant database tables.

The root cause for incorrect `type` field is still not clear.
2026-02-06 11:25:29 +08:00
QuantumGhost
540e1db83c perf(api): Optimize the response time of AppListApi endpoint (#31999) 2026-02-06 10:46:25 +08:00
wangxiaolei
2f75e38c08 fix: fix miss use db.session (#31971) 2026-02-05 15:59:37 +08:00
wangxiaolei
cd03e0a9ef fix: fix delete_draft_variables_batch cycle forever (#31934)
Co-authored-by: autofix-ci[bot] <114827586+autofix-ci[bot]@users.noreply.github.com>
2026-02-04 19:42:50 +08:00
zxhlyh
df2421d187 fix: auto summary env (#31930) 2026-02-04 19:42:26 +08:00
QuantumGhost
0ba321d840 chore: bump version in docker-compose and package manager to 1.12.1 (#31947) 2026-02-04 19:41:50 +08:00
128 changed files with 10964 additions and 2534 deletions

View File

@@ -106,10 +106,10 @@ ignore = [
"N803", # invalid-argument-name
]
"tests/*" = [
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
"F811", # redefined-while-unused
"T201", # allow print in tests,
"S110", # allow ignoring exceptions in tests code (currently)
"PT019", # @patch-injected params look like unused fixtures
]
"controllers/console/explore/trial.py" = ["TID251"]
"controllers/console/human_input_form.py" = ["TID251"]

View File

@@ -122,7 +122,8 @@ These commands assume you start from the repository root.
```bash
cd api
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
# Note: enterprise_telemetry queue is only used in Enterprise Edition
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry
```
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).

View File

@@ -81,6 +81,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_compress,
ext_database,
ext_enterprise_telemetry,
ext_fastopenapi,
ext_forward_refs,
ext_hosting_provider,
@@ -131,6 +132,7 @@ def initialize_extensions(app: DifyApp):
ext_commands,
ext_fastopenapi,
ext_otel,
ext_enterprise_telemetry,
ext_request_logging,
ext_session_factory,
]

View File

@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
from extensions.storage.opendal_storage import OpenDALStorage
from extensions.storage.storage_type import StorageType
from libs.db_migration_lock import DbMigrationAutoRenewLock
from libs.helper import email as email_validate
from libs.password import hash_password, password_pattern, valid_password
from libs.rsa import generate_key_pair
@@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
logger = logging.getLogger(__name__)
DB_UPGRADE_LOCK_TTL_SECONDS = 60
@click.command("reset-password", help="Reset the account password.")
@click.option("--email", prompt=True, help="Account email to reset password for")
@@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
@click.command("upgrade-db", help="Upgrade the database")
def upgrade_db():
click.echo("Preparing database migration...")
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
lock = DbMigrationAutoRenewLock(
redis_client=redis_client,
name="db_upgrade_lock",
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
logger=logger,
log_context="db_migration",
)
if lock.acquire(blocking=False):
migration_succeeded = False
try:
click.echo(click.style("Starting database migration.", fg="green"))
@@ -737,12 +747,16 @@ def upgrade_db():
flask_migrate.upgrade()
migration_succeeded = True
click.echo(click.style("Database migration successful!", fg="green"))
except Exception:
except Exception as e:
logger.exception("Failed to execute database migration")
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
raise SystemExit(1)
finally:
lock.release()
status = "successful" if migration_succeeded else "failed"
lock.release_safely(status=status)
else:
click.echo("Database migration skipped")

View File

@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
from libs.file_utils import search_file_upwards
from .deploy import DeploymentConfig
from .enterprise import EnterpriseFeatureConfig
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
from .extra import ExtraServiceConfig
from .feature import FeatureConfig
from .middleware import MiddlewareConfig
@@ -73,6 +73,8 @@ class DifyConfig(
# Enterprise feature configs
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
EnterpriseFeatureConfig,
# Enterprise telemetry configs
EnterpriseTelemetryConfig,
):
model_config = SettingsConfigDict(
# read from dotenv format config file

View File

@@ -18,3 +18,49 @@ class EnterpriseFeatureConfig(BaseSettings):
description="Allow customization of the enterprise logo.",
default=False,
)
class EnterpriseTelemetryConfig(BaseSettings):
"""
Configuration for enterprise telemetry.
"""
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
default=False,
)
ENTERPRISE_OTLP_ENDPOINT: str = Field(
description="Enterprise OTEL collector endpoint.",
default="",
)
ENTERPRISE_OTLP_HEADERS: str = Field(
description="Auth headers for OTLP export (key=value,key2=value2).",
default="",
)
ENTERPRISE_OTLP_PROTOCOL: str = Field(
description="OTLP protocol: 'http' or 'grpc' (default: http).",
default="http",
)
ENTERPRISE_OTLP_API_KEY: str = Field(
description="Bearer token for enterprise OTLP export authentication.",
default="",
)
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
description="Include input/output content in traces (privacy toggle).",
default=True,
)
ENTERPRISE_SERVICE_NAME: str = Field(
description="Service name for OTEL resource.",
default="dify",
)
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
default=1.0,
)

View File

@@ -1,3 +1,4 @@
import logging
import uuid
from datetime import datetime
from typing import Any, Literal, TypeAlias
@@ -54,6 +55,8 @@ ALLOW_CREATE_APP_MODES = ["chat", "agent-chat", "advanced-chat", "workflow", "co
register_enum_models(console_ns, IconType)
_logger = logging.getLogger(__name__)
class AppListQuery(BaseModel):
page: int = Field(default=1, ge=1, le=99999, description="Page number (1-99999)")
@@ -499,6 +502,7 @@ class AppListApi(Resource):
select(Workflow).where(
Workflow.version == Workflow.VERSION_DRAFT,
Workflow.app_id.in_(workflow_capable_app_ids),
Workflow.tenant_id == current_tenant_id,
)
)
.scalars()
@@ -510,12 +514,14 @@ class AppListApi(Resource):
NodeType.TRIGGER_PLUGIN,
}
for workflow in draft_workflows:
node_id = None
try:
for _, node_data in workflow.walk_nodes():
for node_id, node_data in workflow.walk_nodes():
if node_data.get("type") in trigger_node_types:
draft_trigger_app_ids.add(str(workflow.app_id))
break
except Exception:
_logger.exception("error while walking nodes, workflow_id=%s, node_id=%s", workflow.id, node_id)
continue
for app in app_pagination.items:
@@ -654,6 +660,19 @@ class AppCopyApi(Resource):
)
session.commit()
# Inherit web app permission from original app
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
try:
# Get the original app's access mode
original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
access_mode = original_settings.access_mode
except Exception:
# If original app has no settings (old app), default to public to match fallback behavior
access_mode = "public"
# Apply the same access mode to the copied app
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
stmt = select(App).where(App.id == result.app_id)
app = session.scalar(stmt)

View File

@@ -35,6 +35,7 @@ class InstructionGeneratePayload(BaseModel):
instruction: str = Field(..., description="Instruction for generation")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
ideal_output: str = Field(default="", description="Expected ideal output")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
class InstructionTemplatePayload(BaseModel):
@@ -66,10 +67,17 @@ class RuleGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
rules = LLMGenerator.generate_rule_config(
tenant_id=current_tenant_id,
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=args.no_variable,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
except QuotaExceededError:
@@ -95,12 +103,16 @@ class RuleCodeGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
code_result = LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.code_language,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -127,12 +139,15 @@ class RuleStructuredOutputGenerateApi(Resource):
@account_initialization_required
def post(self):
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
try:
structured_output = LLMGenerator.generate_structured_output(
tenant_id=current_tenant_id,
args=args,
instruction=args.instruction,
model_config=args.model_config_data,
user_id=account.id,
app_id=args.app_id,
)
except ProviderTokenNotInitError as ex:
raise ProviderNotInitializeError(ex.description)
@@ -159,14 +174,14 @@ class InstructionGenerateApi(Resource):
@account_initialization_required
def post(self):
args = InstructionGeneratePayload.model_validate(console_ns.payload)
_, current_tenant_id = current_account_with_tenant()
account, current_tenant_id = current_account_with_tenant()
app_id = args.app_id or args.flow_id
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
code_provider: type[CodeNodeProvider] | None = next(
(p for p in providers if p.is_accept_language(args.language)), None
)
code_template = code_provider.get_default_code() if code_provider else ""
try:
# Generate from nothing for a workflow node
if (args.current in (code_template, "")) and args.node_id != "":
app = db.session.query(App).where(App.id == args.flow_id).first()
if not app:
@@ -183,33 +198,33 @@ class InstructionGenerateApi(Resource):
case "llm":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
user_id=account.id,
app_id=app_id,
)
case "agent":
return LLMGenerator.generate_rule_config(
current_tenant_id,
args=RuleGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
),
instruction=args.instruction,
model_config=args.model_config_data,
no_variable=True,
user_id=account.id,
app_id=app_id,
)
case "code":
return LLMGenerator.generate_code(
tenant_id=current_tenant_id,
args=RuleCodeGeneratePayload(
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
),
instruction=args.instruction,
model_config=args.model_config_data,
code_language=args.language,
user_id=account.id,
app_id=app_id,
)
case _:
return {"error": f"invalid node type: {node_type}"}
if args.node_id == "" and args.current != "": # For legacy app without a workflow
if args.node_id == "" and args.current != "":
return LLMGenerator.instruction_modify_legacy(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
@@ -217,8 +232,10 @@ class InstructionGenerateApi(Resource):
instruction=args.instruction,
model_config=args.model_config_data,
ideal_output=args.ideal_output,
user_id=account.id,
app_id=app_id,
)
if args.node_id != "" and args.current != "": # For workflow node
if args.node_id != "" and args.current != "":
return LLMGenerator.instruction_modify_workflow(
tenant_id=current_tenant_id,
flow_id=args.flow_id,
@@ -228,6 +245,8 @@ class InstructionGenerateApi(Resource):
model_config=args.model_config_data,
ideal_output=args.ideal_output,
workflow_service=WorkflowService(),
user_id=account.id,
app_id=app_id,
)
return {"error": "incompatible parameters"}, 400
except ProviderTokenNotInitError as ex:

View File

@@ -1,6 +1,7 @@
from typing import Any
from flask import request
from flask_login import current_user
from flask_restx import Resource, fields
from pydantic import BaseModel, Field
from werkzeug.exceptions import BadRequest
@@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource):
try:
result = OpsService.create_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id,
tracing_provider=args.tracing_provider,
tracing_config=args.tracing_config,
account_id=current_user.id,
)
if not result:
raise TracingConfigIsExist()
@@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource):
try:
result = OpsService.update_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
app_id=app_id,
tracing_provider=args.tracing_provider,
tracing_config=args.tracing_config,
account_id=current_user.id,
)
if not result:
raise TracingConfigNotExist()
@@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource):
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
try:
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
result = OpsService.delete_tracing_app_config(
app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id
)
if not result:
raise TracingConfigNotExist()
return {"result": "success"}, 204

View File

@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
current_user, current_tenant_id = current_account_with_tenant()
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
return BuiltinToolManageService.set_default_provider(
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
tenant_id=current_tenant_id,
user_id=current_user.id,
provider=provider,
id=args["id"],
account=current_user,
)

View File

@@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner):
self.model_instance = model_instance
# init callback
self.agent_callback = DifyAgentCallbackHandler()
self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id)
# init dataset tools
hit_callback = DatasetIndexToolCallbackHandler(
queue_manager=queue_manager,

View File

@@ -63,6 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.model_runtime.entities.llm_entities import LLMUsage
from core.model_runtime.utils.encoders import jsonable_encoder
from core.ops.ops_trace_manager import TraceQueueManager
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.enums import WorkflowExecutionStatus
from core.workflow.nodes import NodeType
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
@@ -564,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle stop events."""
_ = trace_manager
resolved_state = None
if self._workflow_run_id:
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
@@ -579,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
)
with self._database_session() as session:
# Save message
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield workflow_finish_resp
elif event.stopped_by in (
@@ -589,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
):
# When hitting input-moderation or annotation-reply, the workflow will not start
with self._database_session() as session:
# Save message
self._save_message(session=session)
self._save_message(session=session, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -599,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
event: QueueAdvancedChatMessageEndEvent,
*,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
**kwargs,
) -> Generator[StreamResponse, None, None]:
"""Handle advanced chat message end events."""
@@ -616,7 +616,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
# Save message
with self._database_session() as session:
self._save_message(session=session, graph_runtime_state=resolved_state)
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
yield self._message_end_to_stream_response()
@@ -770,7 +770,13 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
if self._conversation_name_generate_thread:
logger.debug("Conversation name generation running as daemon thread")
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
def _save_message(
self,
*,
session: Session,
graph_runtime_state: GraphRuntimeState | None = None,
trace_manager: TraceQueueManager | None = None,
):
message = self._get_message(session=session)
# If there are assistant files, remove markdown image links from answer
@@ -826,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
]
session.add_all(message_files)
if trace_manager:
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MESSAGE_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"conversation_id": str(message.conversation_id),
"message_id": str(message.id),
},
),
trace_manager=trace_manager,
)
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
"""Bootstrap the cached runtime state from the queue manager when present."""
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state

View File

@@ -147,9 +147,12 @@ class WorkflowAppGenerator(BaseAppGenerator):
inputs: Mapping[str, Any] = args["inputs"]
extras = {
extras: dict[str, Any] = {
**extract_external_trace_id_from_args(args),
}
parent_trace_context = args.get("_parent_trace_context")
if parent_trace_context:
extras["parent_trace_context"] = parent_trace_context
workflow_run_id = str(uuid.uuid4())
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
# trigger shouldn't prepare user inputs

View File

@@ -45,6 +45,8 @@ from core.app.entities.task_entities import (
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
from core.app.task_pipeline.message_cycle_manager import MessageCycleManager
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
from core.file import helpers as file_helpers
from core.file.enums import FileTransferMethod
from core.model_manager import ModelInstance
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk, LLMResultChunkDelta, LLMUsage
from core.model_runtime.entities.message_entities import (
@@ -52,14 +54,16 @@ from core.model_runtime.entities.message_entities import (
TextPromptMessageContent,
)
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.prompt.utils.prompt_message_util import PromptMessageUtil
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.tools.signature import sign_tool_file
from events.message_event import message_was_created
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.model import AppMode, Conversation, Message, MessageAgentThought
from models.model import AppMode, Conversation, Message, MessageAgentThought, MessageFile, UploadFile
logger = logging.getLogger(__name__)
@@ -409,10 +413,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
message.message_metadata = self._task_state.metadata.model_dump_json()
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MESSAGE_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"conversation_id": self._conversation_id,
"message_id": self._message_id,
},
),
trace_manager=trace_manager,
)
message_was_created.send(
@@ -463,6 +476,85 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
metadata=metadata_dict,
)
def _record_files(self):
with Session(db.engine, expire_on_commit=False) as session:
message_files = session.scalars(select(MessageFile).where(MessageFile.message_id == self._message_id)).all()
if not message_files:
return None
files_list = []
upload_file_ids = [
mf.upload_file_id
for mf in message_files
if mf.transfer_method == FileTransferMethod.LOCAL_FILE and mf.upload_file_id
]
upload_files_map = {}
if upload_file_ids:
upload_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(upload_file_ids))).all()
upload_files_map = {uf.id: uf for uf in upload_files}
for message_file in message_files:
upload_file = None
if message_file.transfer_method == FileTransferMethod.LOCAL_FILE and message_file.upload_file_id:
upload_file = upload_files_map.get(message_file.upload_file_id)
url = None
filename = "file"
mime_type = "application/octet-stream"
size = 0
extension = ""
if message_file.transfer_method == FileTransferMethod.REMOTE_URL:
url = message_file.url
if message_file.url:
filename = message_file.url.split("/")[-1].split("?")[0] # Remove query params
elif message_file.transfer_method == FileTransferMethod.LOCAL_FILE:
if upload_file:
url = file_helpers.get_signed_file_url(upload_file_id=str(upload_file.id))
filename = upload_file.name
mime_type = upload_file.mime_type or "application/octet-stream"
size = upload_file.size or 0
extension = f".{upload_file.extension}" if upload_file.extension else ""
elif message_file.upload_file_id:
# Fallback: generate URL even if upload_file not found
url = file_helpers.get_signed_file_url(upload_file_id=str(message_file.upload_file_id))
elif message_file.transfer_method == FileTransferMethod.TOOL_FILE and message_file.url:
# For tool files, use URL directly if it's HTTP, otherwise sign it
if message_file.url.startswith("http"):
url = message_file.url
filename = message_file.url.split("/")[-1].split("?")[0]
else:
# Extract tool file id and extension from URL
url_parts = message_file.url.split("/")
if url_parts:
file_part = url_parts[-1].split("?")[0] # Remove query params first
# Use rsplit to correctly handle filenames with multiple dots
if "." in file_part:
tool_file_id, ext = file_part.rsplit(".", 1)
extension = f".{ext}"
else:
tool_file_id = file_part
extension = ".bin"
url = sign_tool_file(tool_file_id=tool_file_id, extension=extension)
filename = file_part
transfer_method_value = message_file.transfer_method
remote_url = message_file.url if message_file.transfer_method == FileTransferMethod.REMOTE_URL else ""
file_dict = {
"related_id": message_file.id,
"extension": extension,
"filename": filename,
"size": size,
"mime_type": mime_type,
"transfer_method": transfer_method_value,
"type": message_file.type,
"url": url or "",
"upload_file_id": message_file.upload_file_id or message_file.id,
"remote_url": remote_url,
}
files_list.append(file_dict)
return files_list or None
def _agent_message_to_stream_response(self, answer: str, message_id: str) -> AgentMessageStreamResponse:
"""
Agent message to stream response.

View File

@@ -64,7 +64,13 @@ class MessageCycleManager:
# Use SQLAlchemy 2.x style session.scalar(select(...))
with session_factory.create_session() as session:
message_file = session.scalar(select(MessageFile).where(MessageFile.message_id == message_id))
message_file = session.scalar(
select(MessageFile)
.where(
MessageFile.message_id == message_id,
)
.where(MessageFile.belongs_to == "assistant")
)
if message_file:
self._message_has_file.add(message_id)

View File

@@ -15,8 +15,7 @@ from datetime import datetime
from typing import Any, Union
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
from core.workflow.enums import (
@@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
self._workflow_node_execution_repository.save(domain_execution)
self._workflow_node_execution_repository.save_execution_data(domain_execution)
self._enqueue_node_trace_task(domain_execution)
def _fail_running_node_executions(self, *, error_message: str) -> None:
now = naive_utc_now()
@@ -390,17 +390,138 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
external_trace_id = None
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
trace_task = TraceTask(
TraceTaskName.WORKFLOW_TRACE,
workflow_execution=execution,
conversation_id=conversation_id,
user_id=self._trace_manager.user_id,
external_trace_id=external_trace_id,
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.WORKFLOW_TRACE,
context=TelemetryContext(
tenant_id=self._application_generate_entity.app_config.tenant_id,
user_id=self._trace_manager.user_id,
app_id=self._application_generate_entity.app_config.app_id,
),
payload={
"workflow_execution": execution,
"conversation_id": conversation_id,
"user_id": self._trace_manager.user_id,
"external_trace_id": external_trace_id,
"parent_trace_context": parent_trace_context,
},
),
trace_manager=self._trace_manager,
)
def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None:
if not self._trace_manager:
return
execution = self._get_workflow_execution()
meta = domain_execution.metadata or {}
parent_trace_context = None
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
node_data: dict[str, Any] = {
"workflow_id": domain_execution.workflow_id,
"workflow_execution_id": execution.id_,
"tenant_id": self._application_generate_entity.app_config.tenant_id,
"app_id": self._application_generate_entity.app_config.app_id,
"node_execution_id": domain_execution.id,
"node_id": domain_execution.node_id,
"node_type": str(domain_execution.node_type.value),
"title": domain_execution.title,
"status": str(domain_execution.status.value),
"error": domain_execution.error,
"elapsed_time": domain_execution.elapsed_time,
"index": domain_execution.index,
"predecessor_node_id": domain_execution.predecessor_node_id,
"created_at": domain_execution.created_at,
"finished_at": domain_execution.finished_at,
"total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
"prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS),
"completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS),
"total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
"currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
"tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
else None,
"iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
"iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
"loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
"loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
"parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
"node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None,
"node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None,
"process_data": dict(domain_execution.process_data) if domain_execution.process_data else None,
}
node_data["invoke_from"] = self._application_generate_entity.invoke_from.value
node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value)
# Extract model info from process_data — LLM nodes store provider/model there,
if domain_execution.process_data:
if mp := domain_execution.process_data.get("model_provider"):
node_data["model_provider"] = mp
if mn := domain_execution.process_data.get("model_name"):
node_data["model_name"] = mn
if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs:
results = domain_execution.outputs.get("result") or []
dataset_ids: list[str] = []
dataset_names: list[str] = []
for doc in results:
if not isinstance(doc, dict):
continue
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
dname = doc_meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
if dataset_ids:
node_data["dataset_ids"] = dataset_ids
if dataset_names:
node_data["dataset_names"] = dataset_names
tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
if isinstance(tool_info, dict):
plugin_id = tool_info.get("plugin_unique_identifier")
if plugin_id:
node_data["plugin_name"] = plugin_id
credential_id = tool_info.get("credential_id")
if credential_id:
node_data["credential_id"] = credential_id
node_data["credential_provider_type"] = tool_info.get("provider_type")
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
if conversation_id:
node_data["conversation_id"] = conversation_id
if parent_trace_context:
node_data["parent_trace_context"] = parent_trace_context
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id=node_data.get("tenant_id"),
user_id=node_data.get("user_id"),
app_id=node_data.get("app_id"),
),
payload={"node_execution_data": node_data},
),
trace_manager=self._trace_manager,
)
self._trace_manager.add_trace_task(trace_task)
def _system_variables(self) -> Mapping[str, Any]:
runtime_state = self.graph_runtime_state

View File

@@ -4,8 +4,9 @@ from typing import Any, TextIO, Union
from pydantic import BaseModel
from configs import dify_config
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.tools.entities.tool_entities import ToolInvokeMessage
_TEXT_COLOR_MAPPING = {
@@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel):
color: str | None = ""
current_loop: int = 1
tenant_id: str | None = None
def __init__(self, color: str | None = None):
def __init__(self, color: str | None = None, tenant_id: str | None = None):
super().__init__()
"""Initialize callback handler."""
# use a specific color is not specified
self.color = color or "green"
self.current_loop = 1
self.tenant_id = tenant_id
def on_tool_start(
self,
@@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel):
print_text("\n")
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.TOOL_TRACE,
message_id=message_id,
tool_name=tool_name,
tool_inputs=tool_inputs,
tool_outputs=tool_outputs,
timer=timer,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.TOOL_TRACE,
context=TelemetryContext(
tenant_id=self.tenant_id,
app_id=trace_manager.app_id,
user_id=trace_manager.user_id,
),
payload={
"message_id": message_id,
"tool_name": tool_name,
"tool_inputs": tool_inputs,
"tool_outputs": tool_outputs,
"timer": timer,
},
),
trace_manager=trace_manager,
)
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):

View File

@@ -9,6 +9,7 @@ class RuleGeneratePayload(BaseModel):
instruction: str = Field(..., description="Rule generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
no_variable: bool = Field(default=False, description="Whether to exclude variables")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
class RuleCodeGeneratePayload(RuleGeneratePayload):
@@ -18,3 +19,4 @@ class RuleCodeGeneratePayload(RuleGeneratePayload):
class RuleStructuredOutputPayload(BaseModel):
instruction: str = Field(..., description="Structured output generation instruction")
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")

View File

@@ -7,7 +7,6 @@ from typing import Protocol, cast
import json_repair
from core.app.app_config.entities import ModelConfig
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
from core.llm_generator.prompts import (
@@ -27,10 +26,11 @@ from core.model_runtime.entities.llm_entities import LLMResult
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
from core.model_runtime.entities.model_entities import ModelType
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.entities.trace_entity import OperationType
from core.ops.utils import measure_time
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from extensions.ext_storage import storage
@@ -74,7 +74,7 @@ class LLMGenerator:
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
)
answer = response.message.get_text_content()
if answer == "":
if answer is None:
return ""
try:
result_dict = json.loads(answer)
@@ -96,15 +96,17 @@ class LLMGenerator:
name = name[:75] + "..."
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.GENERATE_NAME_TRACE,
conversation_id=conversation_id,
generate_conversation_name=name,
inputs=prompt,
timer=timer,
tenant_id=tenant_id,
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.GENERATE_NAME_TRACE,
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
payload={
"conversation_id": conversation_id,
"generate_conversation_name": name,
"inputs": prompt,
"timer": timer,
"tenant_id": tenant_id,
},
)
)
@@ -153,19 +155,27 @@ class LLMGenerator:
return questions
@classmethod
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
def generate_rule_config(
cls,
tenant_id: str,
instruction: str,
model_config: ModelConfig,
no_variable: bool,
user_id: str | None = None,
app_id: str | None = None,
):
output_parser = RuleConfigGeneratorOutputParser()
error = ""
error_step = ""
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
model_parameters = args.model_config_data.completion_params
if args.no_variable:
model_parameters = model_config.completion_params
if no_variable:
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
prompt_generate = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -177,26 +187,45 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.provider,
model=model_config.name,
)
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
llm_result = None
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
rule_config["prompt"] = response.message.get_text_content()
rule_config["prompt"] = llm_result.message.get_text_content() or ""
except InvokeError as e:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
except InvokeError as e:
error = str(e)
error_step = "generate rule config"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.name)
rule_config["error"] = str(e)
error = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
prompt_value = rule_config.get("prompt", "")
generated_output = str(prompt_value) if prompt_value else ""
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.RULE_GENERATE,
instruction=instruction,
generated_output=generated_output,
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error or None,
)
return rule_config
# get rule config prompt, parameter and statement
@@ -211,7 +240,7 @@ class LLMGenerator:
# format the prompt_generate_prompt
prompt_generate_prompt = prompt_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"TASK_DESCRIPTION": instruction,
},
remove_template_variables=False,
)
@@ -222,84 +251,125 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.provider,
model=model_config.name,
)
try:
llm_result = None
with measure_time() as timer:
try:
# the first step to generate the task prompt
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
try:
# the first step to generate the task prompt
prompt_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
llm_result = prompt_content
except InvokeError as e:
error = str(e)
error_step = "generate prefix prompt"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.RULE_GENERATE,
instruction=instruction,
generated_output="",
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
return rule_config
rule_config["prompt"] = prompt_content.message.get_text_content() or ""
if not isinstance(prompt_content.message.content, str):
raise NotImplementedError("prompt content is not a string")
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.content,
},
remove_template_variables=False,
)
except InvokeError as e:
error = str(e)
error_step = "generate prefix prompt"
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
return rule_config
rule_config["prompt"] = prompt_content.message.get_text_content()
parameter_generate_prompt = parameter_template.format(
inputs={
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": args.instruction,
"INPUT_TEXT": prompt_content.message.get_text_content(),
},
remove_template_variables=False,
)
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
# the second step to generate the task_parameter and task_statement
statement_generate_prompt = statement_template.format(
inputs={
"TASK_DESCRIPTION": instruction,
"INPUT_TEXT": prompt_content.message.content,
},
remove_template_variables=False,
)
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
except InvokeError as e:
error = str(e)
error_step = "generate variables"
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
try:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = statement_content.message.get_text_content()
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
try:
parameter_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
)
rule_config["variables"] = re.findall(
r'"\s*([^"]+)\s*"', prompt_content.message.get_text_content() or ""
)
except InvokeError as e:
error = str(e)
error_step = "generate variables"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
rule_config["error"] = str(e)
try:
statement_content: LLMResult = model_instance.invoke_llm(
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
)
rule_config["opening_statement"] = statement_content.message.get_text_content() or ""
except InvokeError as e:
error = str(e)
error_step = "generate conversation opener"
except Exception as e:
logger.exception("Failed to generate rule config, model: %s", model_config.name)
rule_config["error"] = str(e)
error = str(e)
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
if user_id:
generated_output = rule_config.get("prompt", "")
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.RULE_GENERATE,
instruction=instruction,
generated_output=str(generated_output) if generated_output else "",
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error or None,
)
return rule_config
@classmethod
def generate_code(
cls,
tenant_id: str,
args: RuleCodeGeneratePayload,
instruction: str,
model_config: ModelConfig,
code_language: str = "javascript",
user_id: str | None = None,
app_id: str | None = None,
):
if args.code_language == "python":
if code_language == "python":
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
else:
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
prompt = prompt_template.format(
inputs={
"INSTRUCTION": args.instruction,
"CODE_LANGUAGE": args.code_language,
"INSTRUCTION": instruction,
"CODE_LANGUAGE": code_language,
},
remove_template_variables=False,
)
@@ -308,28 +378,49 @@ class LLMGenerator:
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.provider,
model=model_config.name,
)
prompt_messages = [UserPromptMessage(content=prompt)]
model_parameters = args.model_config_data.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
model_parameters = model_config.completion_params
llm_result = None
error = None
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_code = llm_result.message.get_text_content() or ""
result = {"code": generated_code, "language": code_language, "error": ""}
except InvokeError as e:
error = str(e)
result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", model_config.name, code_language
)
error = str(e)
result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.CODE_GENERATE,
instruction=instruction,
generated_output=result.get("code", ""),
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
generated_code = response.message.get_text_content()
return {"code": generated_code, "language": args.code_language, "error": ""}
except InvokeError as e:
error = str(e)
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception(
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
)
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
return result
@classmethod
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
@@ -355,49 +446,81 @@ class LLMGenerator:
raise TypeError("Expected LLMResult when stream=False")
response = result
answer = response.message.get_text_content()
answer = response.message.get_text_content() or ""
return answer.strip()
@classmethod
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
def generate_structured_output(
cls,
tenant_id: str,
instruction: str,
model_config: ModelConfig,
user_id: str | None = None,
app_id: str | None = None,
):
model_manager = ModelManager()
model_instance = model_manager.get_model_instance(
tenant_id=tenant_id,
model_type=ModelType.LLM,
provider=args.model_config_data.provider,
model=args.model_config_data.name,
provider=model_config.provider,
model=model_config.name,
)
prompt_messages = [
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
UserPromptMessage(content=args.instruction),
UserPromptMessage(content=instruction),
]
model_parameters = args.model_config_data.completion_params
model_parameters = model_config.completion_params
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
llm_result = None
error = None
result = {"output": "", "error": ""}
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
raw_content = llm_result.message.content
if not isinstance(raw_content, str):
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
result = {"output": generated_json_schema, "error": ""}
except InvokeError as e:
error = str(e)
result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", model_config.name)
error = str(e)
result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
if user_id:
cls._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.STRUCTURED_OUTPUT,
instruction=instruction,
generated_output=result.get("output", ""),
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
raw_content = response.message.get_text_content()
try:
parsed_content = json.loads(raw_content)
except json.JSONDecodeError:
parsed_content = json_repair.loads(raw_content)
if not isinstance(parsed_content, dict | list):
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
return {"output": generated_json_schema, "error": ""}
except InvokeError as e:
error = str(e)
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
return result
@staticmethod
def instruction_modify_legacy(
@@ -407,12 +530,14 @@ class LLMGenerator:
instruction: str,
model_config: ModelConfig,
ideal_output: str | None,
user_id: str | None = None,
app_id: str | None = None,
):
last_run: Message | None = (
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
)
if not last_run:
return LLMGenerator.__instruction_modify_common(
result = LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=None,
@@ -421,22 +546,28 @@ class LLMGenerator:
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
return LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
)
else:
last_run_dict = {
"query": last_run.query,
"answer": last_run.answer,
"error": last_run.error,
}
result = LLMGenerator.__instruction_modify_common(
tenant_id=tenant_id,
model_config=model_config,
last_run=last_run_dict,
current=current,
error_message=str(last_run.error),
instruction=instruction,
node_type="llm",
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
return result
@staticmethod
def instruction_modify_workflow(
@@ -448,6 +579,8 @@ class LLMGenerator:
model_config: ModelConfig,
ideal_output: str | None,
workflow_service: WorkflowServiceInterface,
user_id: str | None = None,
app_id: str | None = None,
):
session = db.session()
@@ -478,6 +611,8 @@ class LLMGenerator:
instruction=instruction,
node_type=node_type,
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
@@ -511,6 +646,8 @@ class LLMGenerator:
instruction=instruction,
node_type=last_run.node_type,
ideal_output=ideal_output,
user_id=user_id,
app_id=app_id,
)
@staticmethod
@@ -523,6 +660,8 @@ class LLMGenerator:
instruction: str,
node_type: str,
ideal_output: str | None,
user_id: str | None = None,
app_id: str | None = None,
):
LAST_RUN = "{{#last_run#}}"
CURRENT = "{{#current#}}"
@@ -562,24 +701,120 @@ class LLMGenerator:
]
model_parameters = {"temperature": 0.4}
try:
response: LLMResult = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
llm_result = None
error = None
result = {}
with measure_time() as timer:
try:
llm_result = model_instance.invoke_llm(
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
)
generated_raw = llm_result.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
result = data
except InvokeError as e:
error = str(e)
result = {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
error = str(e)
result = {"error": f"An unexpected error occurred: {str(e)}"}
if user_id:
generated_output = ""
if isinstance(result, dict):
for key in ["prompt", "code", "output", "modified"]:
if result.get(key):
generated_output = str(result[key])
break
LLMGenerator._emit_prompt_generation_trace(
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=OperationType.INSTRUCTION_MODIFY,
instruction=instruction,
generated_output=generated_output,
llm_result=llm_result,
model_config=model_config,
timer=timer,
error=error,
)
generated_raw = response.message.get_text_content()
first_brace = generated_raw.find("{")
last_brace = generated_raw.rfind("}")
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
json_str = generated_raw[first_brace : last_brace + 1]
data = json_repair.loads(json_str)
if not isinstance(data, dict):
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
return data
except InvokeError as e:
error = str(e)
return {"error": f"Failed to generate code. Error: {error}"}
except Exception as e:
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
return {"error": f"An unexpected error occurred: {str(e)}"}
return result
@classmethod
def _emit_prompt_generation_trace(
cls,
tenant_id: str,
user_id: str,
app_id: str | None,
operation_type: OperationType,
instruction: str,
generated_output: str,
llm_result: LLMResult | None,
model_config: ModelConfig | None = None,
timer=None,
error: str | None = None,
):
if llm_result:
prompt_tokens = llm_result.usage.prompt_tokens
completion_tokens = llm_result.usage.completion_tokens
total_tokens = llm_result.usage.total_tokens
model_name = llm_result.model
# Extract provider from model_config if available, otherwise fall back to parsing model name
if model_config and model_config.provider:
model_provider = model_config.provider
else:
model_provider = model_name.split("/")[0] if "/" in model_name else ""
latency = llm_result.usage.latency
total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None
currency = llm_result.usage.currency
else:
prompt_tokens = 0
completion_tokens = 0
total_tokens = 0
model_provider = model_config.provider if model_config else ""
model_name = model_config.name if model_config else ""
latency = 0.0
if timer:
start_time = timer.get("start")
end_time = timer.get("end")
if start_time and end_time:
latency = (end_time - start_time).total_seconds()
total_price = None
currency = None
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.PROMPT_GENERATION_TRACE,
context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id),
payload={
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id,
"operation_type": operation_type,
"instruction": instruction,
"generated_output": generated_output,
"prompt_tokens": prompt_tokens,
"completion_tokens": completion_tokens,
"total_tokens": total_tokens,
"model_provider": model_provider,
"model_name": model_name,
"latency": latency,
"total_price": total_price,
"currency": currency,
"timer": timer,
"error": error,
},
)
)

View File

@@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter):
"""
def filter(self, record: logging.LogRecord) -> bool:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event)
existing_trace_id = getattr(record, "trace_id", "")
if not existing_trace_id:
# Get trace context from OpenTelemetry
trace_id, span_id = self._get_otel_context()
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
# Set trace_id (fallback to ContextVar if no OTEL context)
if trace_id:
record.trace_id = trace_id
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
else:
record.trace_id = get_trace_id()
record.span_id = span_id or ""
# Keep existing trace_id; only fill span_id if missing
if not getattr(record, "span_id", ""):
record.span_id = ""
# For backward compatibility, also set req_id
record.req_id = get_request_id()
@@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter):
def filter(self, record: logging.LogRecord) -> bool:
identity = self._extract_identity()
record.tenant_id = identity.get("tenant_id", "")
record.user_id = identity.get("user_id", "")
record.user_type = identity.get("user_type", "")
if not getattr(record, "tenant_id", ""):
record.tenant_id = identity.get("tenant_id", "")
if not getattr(record, "user_id", ""):
record.user_id = identity.get("user_id", "")
if not getattr(record, "user_type", ""):
record.user_type = identity.get("user_type", "")
return True
def _extract_identity(self) -> dict[str, str]:

View File

@@ -5,9 +5,10 @@ from typing import Any
from core.app.app_config.entities import AppConfig
from core.moderation.base import ModerationAction, ModerationError
from core.moderation.factory import ModerationFactory
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.utils import measure_time
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
logger = logging.getLogger(__name__)
@@ -49,14 +50,18 @@ class InputModeration:
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.MODERATION_TRACE,
message_id=message_id,
moderation_result=moderation_result,
inputs=inputs,
timer=timer,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.MODERATION_TRACE,
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
payload={
"message_id": message_id,
"moderation_result": moderation_result,
"inputs": inputs,
"timer": timer,
},
),
trace_manager=trace_manager,
)
if not moderation_result.flagged:

View File

@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
class BaseTraceInfo(BaseModel):
message_id: str | None = None
message_data: Any | None = None
inputs: Union[str, dict[str, Any], list] | None = None
outputs: Union[str, dict[str, Any], list] | None = None
inputs: Union[str, dict[str, Any], list[Any]] | None = None
outputs: Union[str, dict[str, Any], list[Any]] | None = None
start_time: datetime | None = None
end_time: datetime | None = None
metadata: dict[str, Any]
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
@field_validator("inputs", "outputs")
@classmethod
def ensure_type(cls, v):
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
if v is None:
return None
if isinstance(v, str | dict | list):
@@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
model_config = ConfigDict(protected_namespaces=())
@property
def resolved_trace_id(self) -> str | None:
"""Get trace_id with intelligent fallback.
Priority:
1. External trace_id (from X-Trace-Id header)
2. workflow_run_id (if this trace type has it)
3. message_id (as final fallback)
"""
if self.trace_id:
return self.trace_id
# Try workflow_run_id (only exists on workflow-related traces)
workflow_run_id = getattr(self, "workflow_run_id", None)
if workflow_run_id:
return workflow_run_id
# Final fallback to message_id
return str(self.message_id) if self.message_id else None
@property
def resolved_parent_context(self) -> tuple[str | None, str | None]:
"""Resolve cross-workflow parent linking from metadata.
Extracts typed parent IDs from the untyped ``parent_trace_context``
metadata dict (set by tool_node when invoking nested workflows).
Returns:
(trace_correlation_override, parent_span_id_source) where
trace_correlation_override is the outer workflow_run_id and
parent_span_id_source is the outer node_execution_id.
"""
parent_ctx = self.metadata.get("parent_trace_context")
if not isinstance(parent_ctx, dict):
return None, None
trace_override = parent_ctx.get("parent_workflow_run_id")
parent_span = parent_ctx.get("parent_node_execution_id")
return (
trace_override if isinstance(trace_override, str) else None,
parent_span if isinstance(parent_span, str) else None,
)
@field_serializer("start_time", "end_time")
def serialize_datetime(self, dt: datetime | None) -> str | None:
if dt is None:
@@ -48,10 +90,14 @@ class WorkflowTraceInfo(BaseTraceInfo):
workflow_run_version: str
error: str | None = None
total_tokens: int
prompt_tokens: int | None = None
completion_tokens: int | None = None
file_list: list[str]
query: str
metadata: dict[str, Any]
invoked_by: str | None = None
class MessageTraceInfo(BaseTraceInfo):
conversation_model: str
@@ -59,7 +105,7 @@ class MessageTraceInfo(BaseTraceInfo):
answer_tokens: int
total_tokens: int
error: str | None = None
file_list: Union[str, dict[str, Any], list] | None = None
file_list: Union[str, dict[str, Any], list[Any]] | None = None
message_file_data: Any | None = None
conversation_mode: str
gen_ai_server_time_to_first_token: float | None = None
@@ -106,7 +152,7 @@ class ToolTraceInfo(BaseTraceInfo):
tool_config: dict[str, Any]
time_cost: Union[int, float]
tool_parameters: dict[str, Any]
file_url: Union[str, None, list] = None
file_url: Union[str, None, list[str]] = None
class GenerateNameTraceInfo(BaseTraceInfo):
@@ -114,6 +160,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
tenant_id: str
class PromptGenerationTraceInfo(BaseTraceInfo):
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
tenant_id: str
user_id: str
app_id: str | None = None
operation_type: str
instruction: str
prompt_tokens: int
completion_tokens: int
total_tokens: int
model_provider: str
model_name: str
latency: float
total_price: float | None = None
currency: str | None = None
error: str | None = None
model_config = ConfigDict(protected_namespaces=())
class WorkflowNodeTraceInfo(BaseTraceInfo):
workflow_id: str
workflow_run_id: str
tenant_id: str
node_execution_id: str
node_id: str
node_type: str
title: str
status: str
error: str | None = None
elapsed_time: float
index: int
predecessor_node_id: str | None = None
total_tokens: int = 0
total_price: float = 0.0
currency: str | None = None
model_provider: str | None = None
model_name: str | None = None
prompt_tokens: int | None = None
completion_tokens: int | None = None
tool_name: str | None = None
iteration_id: str | None = None
iteration_index: int | None = None
loop_id: str | None = None
loop_index: int | None = None
parallel_id: str | None = None
node_inputs: Mapping[str, Any] | None = None
node_outputs: Mapping[str, Any] | None = None
process_data: Mapping[str, Any] | None = None
invoked_by: str | None = None
model_config = ConfigDict(protected_namespaces=())
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
pass
class TaskData(BaseModel):
app_id: str
trace_info_type: str
@@ -128,16 +247,38 @@ trace_info_info_map = {
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
"ToolTraceInfo": ToolTraceInfo,
"GenerateNameTraceInfo": GenerateNameTraceInfo,
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
}
class OperationType(StrEnum):
"""Operation type for token metric labels.
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
counters so consumers can break down token usage by operation.
"""
WORKFLOW = "workflow"
NODE_EXECUTION = "node_execution"
MESSAGE = "message"
RULE_GENERATE = "rule_generate"
CODE_GENERATE = "code_generate"
STRUCTURED_OUTPUT = "structured_output"
INSTRUCTION_MODIFY = "instruction_modify"
class TraceTaskName(StrEnum):
CONVERSATION_TRACE = "conversation"
WORKFLOW_TRACE = "workflow"
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
MESSAGE_TRACE = "message"
MODERATION_TRACE = "moderation"
SUGGESTED_QUESTION_TRACE = "suggested_question"
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
TOOL_TRACE = "tool"
GENERATE_NAME_TRACE = "generate_conversation_name"
PROMPT_GENERATION_TRACE = "prompt_generation"
DATASOURCE_TRACE = "datasource"
NODE_EXECUTION_TRACE = "node_execution"

View File

@@ -3,6 +3,7 @@ import os
from datetime import datetime, timedelta
from langfuse import Langfuse
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType
from extensions.ext_database import db
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom
from models.enums import MessageStatus
logger = logging.getLogger(__name__)
@@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
# Check for parent_trace_context to detect nested workflow
parent_trace_context = trace_info.metadata.get("parent_trace_context")
if parent_trace_context:
# Nested workflow: create span under outer trace
outer_trace_id = parent_trace_context.get("trace_id")
parent_node_execution_id = parent_trace_context.get("parent_node_execution_id")
parent_conversation_id = parent_trace_context.get("parent_conversation_id")
parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id")
# Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id
if parent_conversation_id:
session_factory = sessionmaker(bind=db.engine)
with session_factory() as session:
message_data_stmt = select(Message.id).where(
Message.conversation_id == parent_conversation_id,
Message.workflow_run_id == parent_workflow_run_id,
)
resolved_message_id = session.scalar(message_data_stmt)
if resolved_message_id:
outer_trace_id = resolved_message_id
else:
outer_trace_id = parent_workflow_run_id
else:
outer_trace_id = parent_workflow_run_id
# Create inner workflow span under outer trace
workflow_span_data = LangfuseSpan(
id=trace_info.workflow_run_id,
name=TraceTaskName.WORKFLOW_TRACE,
input=dict(trace_info.workflow_run_inputs),
output=dict(trace_info.workflow_run_outputs),
trace_id=outer_trace_id,
parent_observation_id=parent_node_execution_id,
start_time=trace_info.start_time,
end_time=trace_info.end_time,
metadata=metadata,
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
status_message=trace_info.error or "",
)
self.add_span(langfuse_span_data=workflow_span_data)
# Use outer_trace_id for all node spans/generations
trace_id = outer_trace_id
elif trace_info.message_id:
trace_id = trace_info.trace_id or trace_info.message_id
name = TraceTaskName.MESSAGE_TRACE
trace_data = LangfuseTrace(
@@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance):
}
)
# Determine parent_observation_id for nested workflows
node_parent_observation_id = None
if parent_trace_context or trace_info.message_id:
node_parent_observation_id = trace_info.workflow_run_id
# add generation span
if process_data and process_data.get("model_mode") == "chat":
total_token = metadata.get("total_tokens", 0)
@@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
parent_observation_id=node_parent_observation_id,
usage=generation_usage,
)
@@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance):
metadata=metadata,
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
status_message=trace_info.error or "",
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
parent_observation_id=node_parent_observation_id,
)
self.add_span(langfuse_span_data=span_data)

View File

@@ -6,6 +6,7 @@ from typing import cast
from langsmith import Client
from langsmith.schemas import RunBase
from sqlalchemy import select
from sqlalchemy.orm import sessionmaker
from core.ops.base_trace_instance import BaseTraceInstance
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order
from core.repositories import DifyCoreRepositoryFactory
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
from extensions.ext_database import db
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom
logger = logging.getLogger(__name__)
@@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance):
self.generate_name_trace(trace_info)
def workflow_trace(self, trace_info: WorkflowTraceInfo):
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
# Check for parent_trace_context for cross-workflow linking
parent_trace_context = trace_info.metadata.get("parent_trace_context")
if parent_trace_context:
# Inner workflow: resolve outer trace_id and link to parent node
outer_trace_id = parent_trace_context.get("parent_workflow_run_id")
# Try to resolve message_id from conversation_id if available
if parent_trace_context.get("parent_conversation_id"):
try:
session_factory = sessionmaker(bind=db.engine)
with session_factory() as session:
message_data_stmt = select(Message.id).where(
Message.conversation_id == parent_trace_context["parent_conversation_id"],
Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"],
)
resolved_message_id = session.scalar(message_data_stmt)
if resolved_message_id:
outer_trace_id = resolved_message_id
except Exception as e:
logger.debug("Failed to resolve message_id from conversation_id: %s", str(e))
trace_id = outer_trace_id
parent_run_id = parent_trace_context.get("parent_node_execution_id")
else:
# Outer workflow: existing behavior
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
parent_run_id = trace_info.message_id or None
if trace_info.start_time is None:
trace_info.start_time = datetime.now()
message_dotted_order = (
@@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
metadata = trace_info.metadata
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
if trace_info.message_id:
# Only create message_run for outer workflows (no parent_trace_context)
if trace_info.message_id and not parent_trace_context:
message_run = LangSmithRunModel(
id=trace_info.message_id,
name=TraceTaskName.MESSAGE_TRACE,
@@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance):
},
error=trace_info.error,
tags=["workflow"],
parent_run_id=trace_info.message_id or None,
parent_run_id=parent_run_id,
trace_id=trace_id,
dotted_order=workflow_dotted_order,
dotted_order=None if parent_trace_context else workflow_dotted_order,
serialized=None,
events=[],
session_id=None,

View File

@@ -21,19 +21,26 @@ from core.ops.entities.config_entity import (
)
from core.ops.entities.trace_entity import (
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
TaskData,
ToolTraceInfo,
TraceTaskName,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from core.ops.utils import get_message_data
from extensions.ext_database import db
from extensions.ext_storage import storage
from models.account import Tenant
from models.dataset import Dataset
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
from models.workflow import WorkflowAppLog
from tasks.ops_trace_task import process_trace_tasks
@@ -43,6 +50,139 @@ if TYPE_CHECKING:
logger = logging.getLogger(__name__)
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
app_name = ""
workspace_name = ""
if not app_id and not tenant_id:
return app_name, workspace_name
with Session(db.engine) as session:
if app_id:
name = session.scalar(select(App.name).where(App.id == app_id))
if name:
app_name = name
if tenant_id:
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
if name:
workspace_name = name
return app_name, workspace_name
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
"builtin": BuiltinToolProvider,
"plugin": BuiltinToolProvider,
"api": ApiToolProvider,
"workflow": WorkflowToolProvider,
"mcp": MCPToolProvider,
}
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
if not credential_id:
return ""
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
if not model_cls:
return ""
with Session(db.engine) as session:
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))
return str(name) if name else ""
def _lookup_llm_credential_info(
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
) -> tuple[str | None, str]:
"""
Lookup LLM credential ID and name for the given provider and model.
Returns (credential_id, credential_name).
Handles async timing issues gracefully - if credential is deleted between lookups,
returns the ID but empty name rather than failing.
"""
if not tenant_id or not provider:
return None, ""
try:
with Session(db.engine) as session:
# Try to find provider-level or model-level configuration
provider_record = session.scalar(
select(Provider).where(
Provider.tenant_id == tenant_id,
Provider.provider_name == provider,
Provider.provider_type == ProviderType.CUSTOM,
)
)
if not provider_record:
return None, ""
# Check if there's a model-specific config
credential_id = None
credential_name = ""
is_model_level = False
if model and provider_record.credential_id:
# Try model-level first
model_record = session.scalar(
select(ProviderModel).where(
ProviderModel.tenant_id == tenant_id,
ProviderModel.provider_name == provider,
ProviderModel.model_name == model,
ProviderModel.model_type == model_type,
)
)
if model_record and model_record.credential_id:
credential_id = model_record.credential_id
is_model_level = True
if not credential_id and provider_record.credential_id:
# Fall back to provider-level credential
credential_id = provider_record.credential_id
is_model_level = False
# Lookup credential_name if we have credential_id
if credential_id:
try:
if is_model_level:
# Query ProviderModelCredential
cred_name = session.scalar(
select(ProviderModelCredential.credential_name).where(
ProviderModelCredential.id == credential_id
)
)
else:
# Query ProviderCredential
cred_name = session.scalar(
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
)
if cred_name:
credential_name = str(cred_name)
except Exception as e:
# Credential might have been deleted between lookups (async timing)
# Return ID but empty name rather than failing
logger.warning(
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
credential_id,
provider,
model,
str(e),
)
return credential_id, credential_name
except Exception as e:
# Database query failed or other unexpected error
# Return empty rather than propagating error to telemetry emission
logger.warning(
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
tenant_id,
provider,
model,
str(e),
)
return None, ""
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
def __getitem__(self, provider: str) -> dict[str, Any]:
match provider:
@@ -317,6 +457,10 @@ class OpsTraceManager:
if app_id is None:
return None
# Handle storage_id format (tenant-{uuid}) - not a real app_id
if isinstance(app_id, str) and app_id.startswith("tenant-"):
return None
app: App | None = db.session.query(App).where(App.id == app_id).first()
if app is None:
@@ -479,6 +623,56 @@ class TraceTask:
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
return cls._workflow_run_repo
@classmethod
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
"""Extract user ID from metadata, prioritizing end_user over account.
Returns the actual user ID (end_user or account) who invoked the workflow,
regardless of invoke_from context.
"""
# Priority 1: End user (external users via API/WebApp)
if user_id := metadata.get("from_end_user_id"):
return f"end_user:{user_id}"
# Priority 2: Account user (internal users via console/debugger)
if user_id := metadata.get("from_account_id"):
return f"account:{user_id}"
# Priority 3: User (internal users via console/debugger)
if user_id := metadata.get("user_id"):
return f"user:{user_id}"
return "anonymous"
@classmethod
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
with Session(db.engine) as session:
node_executions = session.scalars(
select(WorkflowNodeExecutionModel).where(
WorkflowNodeExecutionModel.tenant_id == tenant_id,
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
)
).all()
total_prompt = 0
total_completion = 0
for node_exec in node_executions:
metadata = node_exec.execution_metadata_dict
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
if prompt is not None:
total_prompt += prompt
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
if completion is not None:
total_completion += completion
return (total_prompt, total_completion)
def __init__(
self,
trace_type: Any,
@@ -499,6 +693,8 @@ class TraceTask:
self.app_id = None
self.trace_id = None
self.kwargs = kwargs
if user_id is not None and "user_id" not in self.kwargs:
self.kwargs["user_id"] = user_id
external_trace_id = kwargs.get("external_trace_id")
if external_trace_id:
self.trace_id = external_trace_id
@@ -512,7 +708,7 @@ class TraceTask:
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
message_id=self.message_id, timer=self.timer, **self.kwargs
),
@@ -528,6 +724,9 @@ class TraceTask:
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
),
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
}
return preprocess_map.get(self.trace_type, lambda: None)()
@@ -563,6 +762,10 @@ class TraceTask:
total_tokens = workflow_run.total_tokens
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
workflow_run_id=workflow_run_id, tenant_id=tenant_id
)
file_list = workflow_run_inputs.get("sys.file") or []
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
@@ -583,7 +786,14 @@ class TraceTask:
)
message_id = session.scalar(message_data_stmt)
metadata = {
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata: dict[str, Any] = {
"workflow_id": workflow_id,
"conversation_id": conversation_id,
"workflow_run_id": workflow_run_id,
@@ -596,8 +806,14 @@ class TraceTask:
"triggered_from": workflow_run.triggered_from,
"user_id": user_id,
"app_id": workflow_run.app_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
parent_trace_context = self.kwargs.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
workflow_trace_info = WorkflowTraceInfo(
trace_id=self.trace_id,
workflow_data=workflow_run.to_dict(),
@@ -612,6 +828,8 @@ class TraceTask:
workflow_run_version=workflow_run_version,
error=error,
total_tokens=total_tokens,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
file_list=file_list,
query=query,
metadata=metadata,
@@ -619,10 +837,11 @@ class TraceTask:
message_id=message_id,
start_time=workflow_run.created_at,
end_time=workflow_run.finished_at,
invoked_by=self._get_user_id_from_metadata(metadata),
)
return workflow_trace_info
def message_trace(self, message_id: str | None):
def message_trace(self, message_id: str | None, **kwargs):
if not message_id:
return {}
message_data = get_message_data(message_id)
@@ -645,6 +864,19 @@ class TraceTask:
streaming_metrics = self._extract_streaming_metrics(message_data)
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
metadata = {
"conversation_id": message_data.conversation_id,
"ls_provider": message_data.model_provider,
@@ -656,7 +888,14 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"message_id": message_id,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
message_tokens = message_data.message_tokens
@@ -673,7 +912,9 @@ class TraceTask:
outputs=message_data.answer,
file_list=file_list,
start_time=created_at,
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
end_time=message_data.updated_at
if message_data.updated_at and message_data.updated_at > created_at
else created_at + timedelta(seconds=message_data.provider_response_latency),
metadata=metadata,
message_file_data=message_file_data,
conversation_mode=conversation_mode,
@@ -698,6 +939,8 @@ class TraceTask:
"preset_response": moderation_result.preset_response,
"query": moderation_result.query,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -739,6 +982,8 @@ class TraceTask:
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
# get workflow_app_log_id
workflow_app_log_id = None
@@ -778,6 +1023,52 @@ class TraceTask:
if not message_data:
return {}
tenant_id = ""
with Session(db.engine) as session:
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
if tid:
tenant_id = str(tid)
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
else:
app_name, workspace_name = "", ""
doc_list = [doc.model_dump() for doc in documents] if documents else []
dataset_ids: set[str] = set()
for doc in doc_list:
doc_meta = doc.get("metadata") or {}
did = doc_meta.get("dataset_id")
if did:
dataset_ids.add(did)
embedding_models: dict[str, dict[str, str]] = {}
if dataset_ids:
with Session(db.engine) as session:
rows = session.execute(
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
Dataset.id.in_(list(dataset_ids))
)
).all()
for row in rows:
embedding_models[str(row[0])] = {
"embedding_model": row[1] or "",
"embedding_model_provider": row[2] or "",
}
# Extract rerank model info from retrieval_model kwargs
rerank_model_provider = ""
rerank_model_name = ""
if "retrieval_model" in kwargs:
retrieval_model = kwargs["retrieval_model"]
if isinstance(retrieval_model, dict):
reranking_model = retrieval_model.get("reranking_model")
if isinstance(reranking_model, dict):
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
rerank_model_name = reranking_model.get("reranking_model_name", "")
metadata = {
"message_id": message_id,
"ls_provider": message_data.model_provider,
@@ -788,13 +1079,23 @@ class TraceTask:
"agent_based": message_data.agent_based,
"workflow_run_id": message_data.workflow_run_id,
"from_source": message_data.from_source,
"tenant_id": tenant_id,
"app_id": message_data.app_id,
"user_id": message_data.from_end_user_id or message_data.from_account_id,
"app_name": app_name,
"workspace_name": workspace_name,
"embedding_models": embedding_models,
"rerank_model_provider": rerank_model_provider,
"rerank_model_name": rerank_model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
inputs=message_data.query or message_data.inputs,
documents=[doc.model_dump() for doc in documents] if documents else [],
documents=doc_list,
start_time=timer.get("start"),
end_time=timer.get("end"),
metadata=metadata,
@@ -837,6 +1138,10 @@ class TraceTask:
"error": error,
"tool_parameters": tool_parameters,
}
if message_data.workflow_run_id:
metadata["workflow_run_id"] = message_data.workflow_run_id
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
file_url = ""
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
@@ -891,6 +1196,8 @@ class TraceTask:
"conversation_id": conversation_id,
"tenant_id": tenant_id,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
generate_name_trace_info = GenerateNameTraceInfo(
trace_id=self.trace_id,
@@ -905,6 +1212,182 @@ class TraceTask:
return generate_name_trace_info
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
tenant_id = kwargs.get("tenant_id", "")
user_id = kwargs.get("user_id", "")
app_id = kwargs.get("app_id")
operation_type = kwargs.get("operation_type", "")
instruction = kwargs.get("instruction", "")
generated_output = kwargs.get("generated_output", "")
prompt_tokens = kwargs.get("prompt_tokens", 0)
completion_tokens = kwargs.get("completion_tokens", 0)
total_tokens = kwargs.get("total_tokens", 0)
model_provider = kwargs.get("model_provider", "")
model_name = kwargs.get("model_name", "")
latency = kwargs.get("latency", 0.0)
timer = kwargs.get("timer")
start_time = timer.get("start") if timer else None
end_time = timer.get("end") if timer else None
total_price = kwargs.get("total_price")
currency = kwargs.get("currency")
error = kwargs.get("error")
app_name = None
workspace_name = None
if app_id:
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
metadata = {
"tenant_id": tenant_id,
"user_id": user_id,
"app_id": app_id or "",
"app_name": app_name,
"workspace_name": workspace_name,
"operation_type": operation_type,
"model_provider": model_provider,
"model_name": model_name,
}
if node_execution_id := kwargs.get("node_execution_id"):
metadata["node_execution_id"] = node_execution_id
return PromptGenerationTraceInfo(
trace_id=self.trace_id,
inputs=instruction,
outputs=generated_output,
start_time=start_time,
end_time=end_time,
metadata=metadata,
tenant_id=tenant_id,
user_id=user_id,
app_id=app_id,
operation_type=operation_type,
instruction=instruction,
prompt_tokens=prompt_tokens,
completion_tokens=completion_tokens,
total_tokens=total_tokens,
model_provider=model_provider,
model_name=model_name,
latency=latency,
total_price=total_price,
currency=currency,
error=error,
)
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
node_data: dict = kwargs.get("node_execution_data", {})
if not node_data:
return {}
from core.telemetry.gateway import is_enterprise_telemetry_enabled
if is_enterprise_telemetry_enabled():
app_name, workspace_name = _lookup_app_and_workspace_names(
node_data.get("app_id"), node_data.get("tenant_id")
)
else:
app_name, workspace_name = "", ""
# Try tool credential lookup first
credential_id = node_data.get("credential_id")
if is_enterprise_telemetry_enabled():
credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
# If no credential_id found (e.g., LLM nodes), try LLM credential lookup
if not credential_id:
llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
tenant_id=node_data.get("tenant_id"),
provider=node_data.get("model_provider"),
model=node_data.get("model_name"),
model_type="llm",
)
if llm_cred_id:
credential_id = llm_cred_id
credential_name = llm_cred_name
else:
credential_name = ""
metadata: dict[str, Any] = {
"tenant_id": node_data.get("tenant_id"),
"app_id": node_data.get("app_id"),
"app_name": app_name,
"workspace_name": workspace_name,
"user_id": node_data.get("user_id"),
"invoke_from": node_data.get("invoke_from"),
"credential_id": node_data.get("credential_id"),
"credential_name": credential_name,
"dataset_ids": node_data.get("dataset_ids"),
"dataset_names": node_data.get("dataset_names"),
"plugin_name": node_data.get("plugin_name"),
}
parent_trace_context = node_data.get("parent_trace_context")
if parent_trace_context:
metadata["parent_trace_context"] = parent_trace_context
message_id: str | None = None
conversation_id = node_data.get("conversation_id")
workflow_execution_id = node_data.get("workflow_execution_id")
if conversation_id and workflow_execution_id and not parent_trace_context:
with Session(db.engine) as session:
msg_id = session.scalar(
select(Message.id).where(
Message.conversation_id == conversation_id,
Message.workflow_run_id == workflow_execution_id,
)
)
if msg_id:
message_id = str(msg_id)
metadata["message_id"] = message_id
if conversation_id:
metadata["conversation_id"] = conversation_id
return WorkflowNodeTraceInfo(
trace_id=self.trace_id,
message_id=message_id,
start_time=node_data.get("created_at"),
end_time=node_data.get("finished_at"),
metadata=metadata,
workflow_id=node_data.get("workflow_id", ""),
workflow_run_id=node_data.get("workflow_execution_id", ""),
tenant_id=node_data.get("tenant_id", ""),
node_execution_id=node_data.get("node_execution_id", ""),
node_id=node_data.get("node_id", ""),
node_type=node_data.get("node_type", ""),
title=node_data.get("title", ""),
status=node_data.get("status", ""),
error=node_data.get("error"),
elapsed_time=node_data.get("elapsed_time", 0.0),
index=node_data.get("index", 0),
predecessor_node_id=node_data.get("predecessor_node_id"),
total_tokens=node_data.get("total_tokens", 0),
total_price=node_data.get("total_price", 0.0),
currency=node_data.get("currency"),
model_provider=node_data.get("model_provider"),
model_name=node_data.get("model_name"),
prompt_tokens=node_data.get("prompt_tokens"),
completion_tokens=node_data.get("completion_tokens"),
tool_name=node_data.get("tool_name"),
iteration_id=node_data.get("iteration_id"),
iteration_index=node_data.get("iteration_index"),
loop_id=node_data.get("loop_id"),
loop_index=node_data.get("loop_index"),
parallel_id=node_data.get("parallel_id"),
node_inputs=node_data.get("node_inputs"),
node_outputs=node_data.get("node_outputs"),
process_data=node_data.get("process_data"),
invoked_by=self._get_user_id_from_metadata(metadata),
)
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
node_trace = self.node_execution_trace(**kwargs)
if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo):
return node_trace
return DraftNodeExecutionTrace(**node_trace.model_dump())
def _extract_streaming_metrics(self, message_data) -> dict:
if not message_data.message_metadata:
return {}
@@ -938,13 +1421,17 @@ class TraceQueueManager:
self.user_id = user_id
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
self.flask_app = current_app._get_current_object() # type: ignore
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
if trace_manager_timer is None:
self.start_timer()
def add_trace_task(self, trace_task: TraceTask):
global trace_manager_timer, trace_manager_queue
try:
if self.trace_instance:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
except Exception:
@@ -980,20 +1467,27 @@ class TraceQueueManager:
def send_to_celery(self, tasks: list[TraceTask]):
with self.flask_app.app_context():
for task in tasks:
if task.app_id is None:
continue
storage_id = task.app_id
if storage_id is None:
tenant_id = task.kwargs.get("tenant_id")
if tenant_id:
storage_id = f"tenant-{tenant_id}"
else:
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
continue
file_id = uuid4().hex
trace_info = task.execute()
task_data = TaskData(
app_id=task.app_id,
app_id=storage_id,
trace_info_type=type(trace_info).__name__,
trace_info=trace_info.model_dump() if trace_info else None,
)
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
file_info = {
"file_id": file_id,
"app_id": task.app_id,
"app_id": storage_id,
}
process_trace_tasks.delay(file_info) # type: ignore

View File

@@ -27,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.ops_trace_manager import TraceQueueManager
from core.ops.utils import measure_time
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
@@ -56,6 +55,8 @@ from core.rag.retrieval.template_prompts import (
METADATA_FILTER_USER_PROMPT_2,
METADATA_FILTER_USER_PROMPT_3,
)
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.tools.signature import sign_upload_file
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
from extensions.ext_database import db
@@ -728,10 +729,21 @@ class DatasetRetrieval:
self.application_generate_entity.trace_manager if self.application_generate_entity else None
)
if trace_manager:
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
)
app_config = self.application_generate_entity.app_config if self.application_generate_entity else None
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
context=TelemetryContext(
tenant_id=app_config.tenant_id if app_config else None,
app_id=app_config.app_id if app_config else None,
),
payload={
"message_id": message_id,
"documents": documents,
"timer": timer,
},
),
trace_manager=trace_manager,
)
def _on_query(

View File

@@ -0,0 +1,43 @@
"""Telemetry facade.
Thin public API for emitting telemetry events. All routing logic
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
from core.telemetry.gateway import emit as gateway_emit
from core.telemetry.gateway import get_trace_task_to_case
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
"""Emit a telemetry event.
Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a
``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``.
"""
case = get_trace_task_to_case().get(event.name)
if case is None:
return
context: dict[str, object] = {
"tenant_id": event.context.tenant_id,
"user_id": event.context.user_id,
"app_id": event.context.app_id,
}
gateway_emit(case, context, event.payload, trace_manager)
__all__ = [
"TelemetryContext",
"TelemetryEvent",
"TraceTaskName",
"emit",
]

View File

@@ -0,0 +1,21 @@
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from core.ops.entities.trace_entity import TraceTaskName
@dataclass(frozen=True)
class TelemetryContext:
tenant_id: str | None = None
user_id: str | None = None
app_id: str | None = None
@dataclass(frozen=True)
class TelemetryEvent:
name: TraceTaskName
context: TelemetryContext
payload: dict[str, Any]

View File

@@ -0,0 +1,229 @@
"""Telemetry gateway — single routing layer for all editions.
Maps ``TelemetryCase`` → ``CaseRoute`` and dispatches events to either
the CE/EE trace pipeline (``TraceQueueManager``) or the enterprise-only
metric/log Celery queue.
This module lives in ``core/`` so both CE and EE share one routing table
and one ``emit()`` entry point. No separate enterprise gateway module is
needed — enterprise-specific dispatch (Celery task, payload offloading)
is handled here behind lazy imports that no-op in CE.
"""
from __future__ import annotations
import json
import logging
import uuid
from typing import TYPE_CHECKING, Any
from core.ops.entities.trace_entity import TraceTaskName
from extensions.ext_storage import storage
if TYPE_CHECKING:
from core.ops.ops_trace_manager import TraceQueueManager
from enterprise.telemetry.contracts import TelemetryCase
logger = logging.getLogger(__name__)
PAYLOAD_SIZE_THRESHOLD_BYTES = 1 * 1024 * 1024
# ---------------------------------------------------------------------------
# Routing table — authoritative mapping for all editions
# ---------------------------------------------------------------------------
_case_to_trace_task: dict | None = None
_case_routing: dict | None = None
def _get_case_to_trace_task() -> dict:
global _case_to_trace_task
if _case_to_trace_task is None:
from enterprise.telemetry.contracts import TelemetryCase
_case_to_trace_task = {
TelemetryCase.WORKFLOW_RUN: TraceTaskName.WORKFLOW_TRACE,
TelemetryCase.MESSAGE_RUN: TraceTaskName.MESSAGE_TRACE,
TelemetryCase.NODE_EXECUTION: TraceTaskName.NODE_EXECUTION_TRACE,
TelemetryCase.DRAFT_NODE_EXECUTION: TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TelemetryCase.PROMPT_GENERATION: TraceTaskName.PROMPT_GENERATION_TRACE,
TelemetryCase.TOOL_EXECUTION: TraceTaskName.TOOL_TRACE,
TelemetryCase.MODERATION_CHECK: TraceTaskName.MODERATION_TRACE,
TelemetryCase.SUGGESTED_QUESTION: TraceTaskName.SUGGESTED_QUESTION_TRACE,
TelemetryCase.DATASET_RETRIEVAL: TraceTaskName.DATASET_RETRIEVAL_TRACE,
TelemetryCase.GENERATE_NAME: TraceTaskName.GENERATE_NAME_TRACE,
}
return _case_to_trace_task
def get_trace_task_to_case() -> dict:
"""Return TraceTaskName → TelemetryCase (inverse of _get_case_to_trace_task)."""
return {v: k for k, v in _get_case_to_trace_task().items()}
def _get_case_routing() -> dict:
global _case_routing
if _case_routing is None:
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase
_case_routing = {
# TRACE — CE-eligible (flow in both CE and EE)
TelemetryCase.WORKFLOW_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MESSAGE_RUN: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.TOOL_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.MODERATION_CHECK: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.SUGGESTED_QUESTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.DATASET_RETRIEVAL: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
TelemetryCase.GENERATE_NAME: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True),
# TRACE — enterprise-only
TelemetryCase.NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.DRAFT_NODE_EXECUTION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
TelemetryCase.PROMPT_GENERATION: CaseRoute(signal_type=SignalType.TRACE, ce_eligible=False),
# METRIC_LOG — enterprise-only (signal-driven, not trace)
TelemetryCase.APP_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_UPDATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.APP_DELETED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
TelemetryCase.FEEDBACK_CREATED: CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False),
}
return _case_routing
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def is_enterprise_telemetry_enabled() -> bool:
try:
from enterprise.telemetry.exporter import is_enterprise_telemetry_enabled
return is_enterprise_telemetry_enabled()
except Exception:
return False
def _handle_payload_sizing(
payload: dict[str, Any],
tenant_id: str,
event_id: str,
) -> tuple[dict[str, Any], str | None]:
"""Inline or offload payload based on size.
Returns ``(payload_for_envelope, storage_key | None)``. Payloads
exceeding ``PAYLOAD_SIZE_THRESHOLD_BYTES`` are written to object
storage and replaced with an empty dict in the envelope.
"""
try:
payload_json = json.dumps(payload)
payload_size = len(payload_json.encode("utf-8"))
except (TypeError, ValueError):
logger.warning("Failed to serialize payload for sizing: event_id=%s", event_id)
return payload, None
if payload_size <= PAYLOAD_SIZE_THRESHOLD_BYTES:
return payload, None
storage_key = f"telemetry/{tenant_id}/{event_id}.json"
try:
storage.save(storage_key, payload_json.encode("utf-8"))
logger.debug("Stored large payload to storage: key=%s, size=%d", storage_key, payload_size)
return {}, storage_key
except Exception:
logger.warning("Failed to store large payload, inlining instead: event_id=%s", event_id, exc_info=True)
return payload, None
# ---------------------------------------------------------------------------
# Public API
# ---------------------------------------------------------------------------
def emit(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None = None,
) -> None:
"""Route a telemetry event to the correct pipeline.
TRACE events are enqueued into ``TraceQueueManager`` (works in both CE
and EE). Enterprise-only traces are silently dropped when EE is
disabled.
METRIC_LOG events are dispatched to the enterprise Celery queue;
silently dropped when enterprise telemetry is unavailable.
"""
route = _get_case_routing().get(case)
if route is None:
logger.warning("Unknown telemetry case: %s, dropping event", case)
return
if not route.ce_eligible and not is_enterprise_telemetry_enabled():
logger.debug("Dropping EE-only event: case=%s (EE disabled)", case)
return
if route.signal_type == "trace":
_emit_trace(case, context, payload, trace_manager)
else:
_emit_metric_log(case, context, payload)
def _emit_trace(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
trace_manager: TraceQueueManager | None,
) -> None:
from core.ops.ops_trace_manager import TraceQueueManager as LocalTraceQueueManager
from core.ops.ops_trace_manager import TraceTask
trace_task_name = _get_case_to_trace_task().get(case)
if trace_task_name is None:
logger.warning("No TraceTaskName mapping for case: %s", case)
return
queue_manager = trace_manager or LocalTraceQueueManager(
app_id=context.get("app_id"),
user_id=context.get("user_id"),
)
queue_manager.add_trace_task(TraceTask(trace_task_name, **payload))
logger.debug("Enqueued trace task: case=%s, app_id=%s", case, context.get("app_id"))
def _emit_metric_log(
case: TelemetryCase,
context: dict[str, Any],
payload: dict[str, Any],
) -> None:
"""Build envelope and dispatch to enterprise Celery queue.
No-ops when the enterprise telemetry task is not importable (CE mode).
"""
try:
from tasks.enterprise_telemetry_task import process_enterprise_telemetry
except ImportError:
logger.debug("Enterprise metric/log dispatch unavailable, dropping: case=%s", case)
return
tenant_id = context.get("tenant_id", "")
event_id = str(uuid.uuid4())
payload_for_envelope, payload_ref = _handle_payload_sizing(payload, tenant_id, event_id)
from enterprise.telemetry.contracts import TelemetryEnvelope
envelope = TelemetryEnvelope(
case=case,
tenant_id=tenant_id,
event_id=event_id,
payload=payload_for_envelope,
metadata={"payload_ref": payload_ref} if payload_ref else None,
)
process_enterprise_telemetry.delay(envelope.model_dump_json())
logger.debug(
"Enqueued metric/log event: case=%s, tenant_id=%s, event_id=%s",
case,
tenant_id,
event_id,
)

View File

@@ -50,6 +50,7 @@ class WorkflowTool(Tool):
self.workflow_call_depth = workflow_call_depth
self.label = label
self._latest_usage = LLMUsage.empty_usage()
self.parent_trace_context: dict[str, str] | None = None
super().__init__(entity=entity, runtime=runtime)
@@ -90,11 +91,15 @@ class WorkflowTool(Tool):
self._latest_usage = LLMUsage.empty_usage()
args: dict[str, Any] = {"inputs": tool_parameters, "files": files}
if self.parent_trace_context:
args["_parent_trace_context"] = self.parent_trace_context
result = generator.generate(
app_model=app,
workflow=workflow,
user=user,
args={"inputs": tool_parameters, "files": files},
args=args,
invoke_from=self.runtime.invoke_from,
streaming=False,
call_depth=self.workflow_call_depth + 1,

View File

@@ -232,6 +232,8 @@ class WorkflowNodeExecutionMetadataKey(StrEnum):
"""
TOTAL_TOKENS = "total_tokens"
PROMPT_TOKENS = "prompt_tokens"
COMPLETION_TOKENS = "completion_tokens"
TOTAL_PRICE = "total_price"
CURRENCY = "currency"
TOOL_INFO = "tool_info"

View File

@@ -322,6 +322,8 @@ class LLMNode(Node[LLMNodeData]):
outputs=outputs,
metadata={
WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS: usage.total_tokens,
WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS: usage.prompt_tokens,
WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS: usage.completion_tokens,
WorkflowNodeExecutionMetadataKey.TOTAL_PRICE: usage.total_price,
WorkflowNodeExecutionMetadataKey.CURRENCY: usage.currency,
},

View File

@@ -60,7 +60,9 @@ class ToolNode(Node[ToolNodeData]):
tool_info = {
"provider_type": self.node_data.provider_type.value,
"provider_id": self.node_data.provider_id,
"tool_name": self.node_data.tool_name,
"plugin_unique_identifier": self.node_data.plugin_unique_identifier,
"credential_id": self.node_data.credential_id,
}
# get tool runtime
@@ -105,6 +107,20 @@ class ToolNode(Node[ToolNodeData]):
# get conversation id
conversation_id = self.graph_runtime_state.variable_pool.get(["sys", SystemVariableKey.CONVERSATION_ID])
from core.tools.workflow_as_tool.tool import WorkflowTool
if isinstance(tool_runtime, WorkflowTool):
workflow_run_id_var = self.graph_runtime_state.variable_pool.get(
["sys", SystemVariableKey.WORKFLOW_EXECUTION_ID]
)
tool_runtime.parent_trace_context = {
"trace_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
"parent_node_execution_id": self.execution_id,
"parent_workflow_run_id": str(workflow_run_id_var.text) if workflow_run_id_var else "",
"parent_app_id": self.app_id,
"parent_conversation_id": conversation_id.text if conversation_id else None,
}
try:
message_stream = ToolEngine.generic_invoke(
tool=tool_runtime,
@@ -431,6 +447,8 @@ class ToolNode(Node[ToolNodeData]):
}
if isinstance(usage.total_tokens, int) and usage.total_tokens > 0:
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS] = usage.total_tokens
metadata[WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS] = usage.prompt_tokens
metadata[WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS] = usage.completion_tokens
metadata[WorkflowNodeExecutionMetadataKey.TOTAL_PRICE] = usage.total_price
metadata[WorkflowNodeExecutionMetadataKey.CURRENCY] = usage.currency

View File

View File

@@ -0,0 +1,522 @@
# Dify Enterprise Telemetry Data Dictionary
Quick reference for all telemetry signals emitted by Dify Enterprise. For configuration and architecture details, see [README.md](./README.md).
## Resource Attributes
Attached to every signal (Span, Metric, Log).
| Attribute | Type | Example |
|-----------|------|---------|
| `service.name` | string | `dify` |
| `host.name` | string | `dify-api-7f8b` |
## Traces (Spans)
### `dify.workflow.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID (Workflow Run ID) |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Unique ID for this run |
| `dify.workflow.status` | string | `succeeded`, `failed`, `stopped`, etc. |
| `dify.workflow.error` | string | Error message if failed |
| `dify.workflow.elapsed_time` | float | Total execution time (seconds) |
| `dify.invoke_from` | string | `api`, `webapp`, `debug` |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.message.id` | string | Message ID (optional) |
| `dify.invoked_by` | string | User ID who triggered the run |
| `gen_ai.usage.total_tokens` | int | Total tokens across all nodes (optional) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
| `dify.parent.trace_id` | string | Parent workflow trace ID (optional) |
| `dify.parent.workflow.run_id` | string | Parent workflow run ID (optional) |
| `dify.parent.node.execution_id` | string | Parent node execution ID (optional) |
| `dify.parent.app.id` | string | Parent app ID (optional) |
### `dify.node.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.trace_id` | string | Business trace ID |
| `dify.tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.workflow.id` | string | Workflow definition ID |
| `dify.workflow.run_id` | string | Workflow Run ID |
| `dify.message.id` | string | Message ID (optional) |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.node.execution_id` | string | Unique node execution ID |
| `dify.node.id` | string | Node ID in workflow graph |
| `dify.node.type` | string | Node type (see appendix) |
| `dify.node.title` | string | Display title |
| `dify.node.status` | string | `succeeded`, `failed` |
| `dify.node.error` | string | Error message if failed |
| `dify.node.elapsed_time` | float | Execution time (seconds) |
| `dify.node.index` | int | Execution order index |
| `dify.node.predecessor_node_id` | string | Triggering node ID |
| `dify.node.iteration_id` | string | Iteration ID (optional) |
| `dify.node.loop_id` | string | Loop ID (optional) |
| `dify.node.parallel_id` | string | Parallel branch ID (optional) |
| `dify.node.invoked_by` | string | User ID who triggered execution |
| `gen_ai.usage.input_tokens` | int | Prompt tokens (LLM nodes only) |
| `gen_ai.usage.output_tokens` | int | Completion tokens (LLM nodes only) |
| `gen_ai.usage.total_tokens` | int | Total tokens (LLM nodes only) |
| `gen_ai.request.model` | string | LLM model name (LLM nodes only) |
| `gen_ai.provider.name` | string | LLM provider name (LLM nodes only) |
| `gen_ai.user.id` | string | End-user identifier (optional) |
### `dify.node.execution.draft`
Same attributes as `dify.node.execution`. Emitted during Preview/Debug runs.
## Counters
All counters are cumulative and emitted at 100% accuracy.
### Token Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.tokens.total` | `{token}` | Total tokens consumed |
| `dify.tokens.input` | `{token}` | Input (prompt) tokens |
| `dify.tokens.output` | `{token}` | Output (completion) tokens |
**Labels:**
- `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type` (if node_execution)
⚠️ **Warning:** `dify.tokens.total` at workflow level includes all node tokens. Filter by `operation_type` to avoid double-counting.
#### Token Hierarchy & Query Patterns
Token metrics are emitted at multiple layers. Understanding the hierarchy prevents double-counting:
```
App-level total
├── workflow ← sum of all node_execution tokens (DO NOT add both)
│ └── node_execution ← per-node breakdown
├── message ← independent (non-workflow chat apps only)
├── rule_generate ← independent helper LLM call
├── code_generate ← independent helper LLM call
├── structured_output ← independent helper LLM call
└── instruction_modify← independent helper LLM call
```
**Key rule:** `workflow` tokens already include all `node_execution` tokens. Never sum both.
**Available labels on token metrics:** `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `node_type`.
App name is only available on span attributes (`dify.app.name`), not metric labels — use `app_id` for metric queries.
**Common queries** (PromQL):
```promql
# ── Totals ──────────────────────────────────────────────────
# App-level total (exclude node_execution to avoid double-counting)
sum by (app_id) (dify_tokens_total{operation_type!="node_execution"})
# Single app total
sum (dify_tokens_total{app_id="<app_id>", operation_type!="node_execution"})
# Per-tenant totals
sum by (tenant_id) (dify_tokens_total{operation_type!="node_execution"})
# ── Drill-down ──────────────────────────────────────────────
# Workflow-level tokens for an app
sum (dify_tokens_total{app_id="<app_id>", operation_type="workflow"})
# Node-level breakdown within an app
sum by (node_type) (dify_tokens_total{app_id="<app_id>", operation_type="node_execution"})
# Model breakdown for an app
sum by (model_provider, model_name) (dify_tokens_total{app_id="<app_id>"})
# Input vs output per model
sum by (model_name) (dify_tokens_input_total{app_id="<app_id>"})
sum by (model_name) (dify_tokens_output_total{app_id="<app_id>"})
# ── Rates ───────────────────────────────────────────────────
# Token consumption rate (per hour)
sum(rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
# Per-app consumption rate
sum by (app_id) (rate(dify_tokens_total{operation_type!="node_execution"}[1h]))
```
**Finding `app_id` from app name** (trace query — Tempo / Jaeger):
```
{ resource.dify.app.name = "My Chatbot" } | select(resource.dify.app.id)
```
### Request Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.requests.total` | `{request}` | Total operations count |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id`, `status`, `invoke_from` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `status` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name`, `status`, `invoke_from` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `moderation` | `tenant_id`, `app_id` |
| `suggested_question` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dataset_retrieval` | `tenant_id`, `app_id` |
| `generate_name` | `tenant_id`, `app_id` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name`, `status` |
### Error Counters
| Metric | Unit | Description |
|--------|------|-------------|
| `dify.errors.total` | `{error}` | Total failed operations |
**Labels by type:**
| `type` | Additional Labels |
|--------|-------------------|
| `workflow` | `tenant_id`, `app_id` |
| `node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `draft_node` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name` |
| `message` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `tool` | `tenant_id`, `app_id`, `tool_name` |
| `prompt_generation` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
### Other Counters
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.feedback.total` | `{feedback}` | `tenant_id`, `app_id`, `rating` |
| `dify.dataset.retrievals.total` | `{retrieval}` | `tenant_id`, `app_id`, `dataset_id`, `embedding_model_provider`, `embedding_model`, `rerank_model_provider`, `rerank_model` |
| `dify.app.created.total` | `{app}` | `tenant_id`, `app_id`, `mode` |
| `dify.app.updated.total` | `{app}` | `tenant_id`, `app_id` |
| `dify.app.deleted.total` | `{app}` | `tenant_id`, `app_id` |
## Histograms
| Metric | Unit | Labels |
|--------|------|--------|
| `dify.workflow.duration` | `s` | `tenant_id`, `app_id`, `status` |
| `dify.node.duration` | `s` | `tenant_id`, `app_id`, `node_type`, `model_provider`, `model_name`, `plugin_name` |
| `dify.message.duration` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.message.time_to_first_token` | `s` | `tenant_id`, `app_id`, `model_provider`, `model_name` |
| `dify.tool.duration` | `s` | `tenant_id`, `app_id`, `tool_name` |
| `dify.prompt_generation.duration` | `s` | `tenant_id`, `app_id`, `operation_type`, `model_provider`, `model_name` |
## Structured Logs
### Span Companion Logs
Logs that accompany spans. Signal type: `span_detail`
#### `dify.workflow.run` Companion Log
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.workflow.version` | string | Yes | Workflow definition version |
| `dify.workflow.inputs` | string/JSON | Yes | Input parameters (content-gated) |
| `dify.workflow.outputs` | string/JSON | Yes | Output results (content-gated) |
| `dify.workflow.query` | string | No | User query text (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.workflow.run"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
#### `dify.node.execution` and `dify.node.execution.draft` Companion Logs
**Common attributes:** All span attributes (see Traces section) plus:
| Additional Attribute | Type | Always Present | Description |
|---------------------|------|----------------|-------------|
| `dify.app.name` | string | No | Application display name |
| `dify.workspace.name` | string | No | Workspace display name |
| `dify.invoke_from` | string | No | Invocation source |
| `gen_ai.tool.name` | string | No | Tool name (tool nodes only) |
| `dify.node.total_price` | float | No | Cost (LLM nodes only) |
| `dify.node.currency` | string | No | Currency code (LLM nodes only) |
| `dify.node.iteration_index` | int | No | Iteration index (iteration nodes) |
| `dify.node.loop_index` | int | No | Loop index (loop nodes) |
| `dify.plugin.name` | string | No | Plugin name (tool/knowledge nodes) |
| `dify.credential.name` | string | No | Credential name (plugin nodes) |
| `dify.credential.id` | string | No | Credential ID (plugin nodes) |
| `dify.dataset.ids` | JSON array | No | Dataset IDs (knowledge nodes) |
| `dify.dataset.names` | JSON array | No | Dataset names (knowledge nodes) |
| `dify.node.inputs` | string/JSON | Yes | Node inputs (content-gated) |
| `dify.node.outputs` | string/JSON | Yes | Node outputs (content-gated) |
| `dify.node.process_data` | string/JSON | No | Processing data (content-gated) |
**Event attributes:**
- `dify.event.name`: `"dify.node.execution"` or `"dify.node.execution.draft"`
- `dify.event.signal`: `"span_detail"`
- `trace_id`, `span_id`, `tenant_id`, `user_id`
### Standalone Logs
Logs without structural spans. Signal type: `metric_only`
#### `dify.message.run`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.message.run"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID (32-char hex) |
| `span_id` | string | OTEL span ID (16-char hex) |
| `tenant_id` | string | Tenant identifier |
| `user_id` | string | User identifier (optional) |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.conversation.id` | string | Conversation ID (optional) |
| `dify.workflow.run_id` | string | Workflow run ID (optional) |
| `dify.invoke_from` | string | `service-api`, `web-app`, `debugger`, `explore` |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.message.status` | string | `succeeded`, `failed` |
| `dify.message.error` | string | Error message (if failed) |
| `dify.message.duration` | float | Duration (seconds) |
| `dify.message.time_to_first_token` | float | TTFT (seconds) |
| `dify.message.inputs` | string/JSON | Inputs (content-gated) |
| `dify.message.outputs` | string/JSON | Outputs (content-gated) |
#### `dify.tool.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.tool.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.tool.name` | string | Tool name |
| `dify.tool.duration` | float | Duration (seconds) |
| `dify.tool.status` | string | `succeeded`, `failed` |
| `dify.tool.error` | string | Error message (if failed) |
| `dify.tool.inputs` | string/JSON | Inputs (content-gated) |
| `dify.tool.outputs` | string/JSON | Outputs (content-gated) |
| `dify.tool.parameters` | string/JSON | Parameters (content-gated) |
| `dify.tool.config` | string/JSON | Configuration (content-gated) |
#### `dify.moderation.check`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.moderation.check"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.moderation.type` | string | `input`, `output` |
| `dify.moderation.action` | string | `pass`, `block`, `flag` |
| `dify.moderation.flagged` | boolean | Whether flagged |
| `dify.moderation.categories` | JSON array | Flagged categories |
| `dify.moderation.query` | string | Content (content-gated) |
#### `dify.suggested_question.generation`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.suggested_question.generation"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.suggested_question.count` | int | Number of questions |
| `dify.suggested_question.duration` | float | Duration (seconds) |
| `dify.suggested_question.status` | string | `succeeded`, `failed` |
| `dify.suggested_question.error` | string | Error message (if failed) |
| `dify.suggested_question.questions` | JSON array | Questions (content-gated) |
#### `dify.dataset.retrieval`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.dataset.retrieval"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.dataset.id` | string | Dataset identifier |
| `dify.dataset.name` | string | Dataset name |
| `dify.dataset.embedding_providers` | JSON array | Embedding model providers (one per dataset) |
| `dify.dataset.embedding_models` | JSON array | Embedding models (one per dataset) |
| `dify.retrieval.rerank_provider` | string | Rerank model provider |
| `dify.retrieval.rerank_model` | string | Rerank model name |
| `dify.retrieval.query` | string | Search query (content-gated) |
| `dify.retrieval.document_count` | int | Documents retrieved |
| `dify.retrieval.duration` | float | Duration (seconds) |
| `dify.retrieval.status` | string | `succeeded`, `failed` |
| `dify.retrieval.error` | string | Error message (if failed) |
| `dify.dataset.documents` | JSON array | Documents (content-gated) |
#### `dify.generate_name.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.generate_name.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.conversation.id` | string | Conversation identifier |
| `dify.generate_name.duration` | float | Duration (seconds) |
| `dify.generate_name.status` | string | `succeeded`, `failed` |
| `dify.generate_name.error` | string | Error message (if failed) |
| `dify.generate_name.inputs` | string/JSON | Inputs (content-gated) |
| `dify.generate_name.outputs` | string | Generated name (content-gated) |
#### `dify.prompt_generation.execution`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.prompt_generation.execution"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.prompt_generation.operation_type` | string | Operation type (see appendix) |
| `gen_ai.provider.name` | string | LLM provider |
| `gen_ai.request.model` | string | LLM model |
| `gen_ai.usage.input_tokens` | int | Input tokens |
| `gen_ai.usage.output_tokens` | int | Output tokens |
| `gen_ai.usage.total_tokens` | int | Total tokens |
| `dify.prompt_generation.duration` | float | Duration (seconds) |
| `dify.prompt_generation.status` | string | `succeeded`, `failed` |
| `dify.prompt_generation.error` | string | Error message (if failed) |
| `dify.prompt_generation.instruction` | string | Instruction (content-gated) |
| `dify.prompt_generation.output` | string/JSON | Output (content-gated) |
#### `dify.app.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.mode` | string | `chat`, `completion`, `agent-chat`, `workflow` |
| `dify.app.created_at` | string | Timestamp (ISO 8601) |
#### `dify.app.updated`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.updated"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.updated_at` | string | Timestamp (ISO 8601) |
#### `dify.app.deleted`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.app.deleted"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.app.deleted_at` | string | Timestamp (ISO 8601) |
#### `dify.feedback.created`
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.feedback.created"` |
| `dify.event.signal` | string | `"metric_only"` |
| `trace_id` | string | OTEL trace ID |
| `span_id` | string | OTEL span ID |
| `tenant_id` | string | Tenant identifier |
| `dify.app_id` | string | Application identifier |
| `dify.message.id` | string | Message identifier |
| `dify.feedback.rating` | string | `like`, `dislike`, `null` |
| `dify.feedback.content` | string | Feedback text (content-gated) |
| `dify.feedback.created_at` | string | Timestamp (ISO 8601) |
#### `dify.telemetry.rehydration_failed`
Diagnostic event for telemetry system health monitoring.
| Attribute | Type | Description |
|-----------|------|-------------|
| `dify.event.name` | string | `"dify.telemetry.rehydration_failed"` |
| `dify.event.signal` | string | `"metric_only"` |
| `tenant_id` | string | Tenant identifier |
| `dify.telemetry.error` | string | Error message |
| `dify.telemetry.payload_type` | string | Payload type (see appendix) |
| `dify.telemetry.correlation_id` | string | Correlation ID |
## Content-Gated Attributes
When `ENTERPRISE_INCLUDE_CONTENT=false`, these attributes are replaced with reference strings (`ref:{id_type}={uuid}`).
| Attribute | Signal |
|-----------|--------|
| `dify.workflow.inputs` | `dify.workflow.run` |
| `dify.workflow.outputs` | `dify.workflow.run` |
| `dify.workflow.query` | `dify.workflow.run` |
| `dify.node.inputs` | `dify.node.execution` |
| `dify.node.outputs` | `dify.node.execution` |
| `dify.node.process_data` | `dify.node.execution` |
| `dify.message.inputs` | `dify.message.run` |
| `dify.message.outputs` | `dify.message.run` |
| `dify.tool.inputs` | `dify.tool.execution` |
| `dify.tool.outputs` | `dify.tool.execution` |
| `dify.tool.parameters` | `dify.tool.execution` |
| `dify.tool.config` | `dify.tool.execution` |
| `dify.moderation.query` | `dify.moderation.check` |
| `dify.suggested_question.questions` | `dify.suggested_question.generation` |
| `dify.retrieval.query` | `dify.dataset.retrieval` |
| `dify.dataset.documents` | `dify.dataset.retrieval` |
| `dify.generate_name.inputs` | `dify.generate_name.execution` |
| `dify.generate_name.outputs` | `dify.generate_name.execution` |
| `dify.prompt_generation.instruction` | `dify.prompt_generation.execution` |
| `dify.prompt_generation.output` | `dify.prompt_generation.execution` |
| `dify.feedback.content` | `dify.feedback.created` |
## Appendix
### Operation Types
- `workflow`, `node_execution`, `message`, `rule_generate`, `code_generate`, `structured_output`, `instruction_modify`
### Node Types
- `start`, `end`, `answer`, `llm`, `knowledge-retrieval`, `knowledge-index`, `if-else`, `code`, `template-transform`, `question-classifier`, `http-request`, `tool`, `datasource`, `variable-aggregator`, `loop`, `iteration`, `parameter-extractor`, `assigner`, `document-extractor`, `list-operator`, `agent`, `trigger-webhook`, `trigger-schedule`, `trigger-plugin`, `human-input`
### Workflow Statuses
- `running`, `succeeded`, `failed`, `stopped`, `partial-succeeded`, `paused`
### Payload Types
- `workflow`, `node`, `message`, `tool`, `moderation`, `suggested_question`, `dataset_retrieval`, `generate_name`, `prompt_generation`, `app`, `feedback`
### Null Value Behavior
**Spans:** Attributes with `null` values are omitted.
**Logs:** Attributes with `null` values appear as `null` in JSON.
**Content-Gated:** Replaced with reference strings, not set to `null`.

View File

@@ -0,0 +1,116 @@
# Dify Enterprise Telemetry
This document provides an overview of the Dify Enterprise OpenTelemetry (OTEL) exporter and how to configure it for integration with observability stacks like Prometheus, Grafana, Jaeger, or Honeycomb.
## Overview
Dify Enterprise uses a "slim span + rich companion log" architecture to provide high-fidelity observability without overwhelming trace storage.
- **Traces (Spans)**: Capture the structure, identity, and timing of high-level operations (Workflows and Nodes).
- **Structured Logs**: Provide deep context (inputs, outputs, metadata) for every event, correlated to spans via `trace_id` and `span_id`.
- **Metrics**: Provide 100% accurate counters and histograms for usage, performance, and error tracking.
### Signal Architecture
```mermaid
graph TD
A[Workflow Run] -->|Span| B(dify.workflow.run)
A -->|Log| C(dify.workflow.run detail)
B ---|trace_id| C
D[Node Execution] -->|Span| E(dify.node.execution)
D -->|Log| F(dify.node.execution detail)
E ---|span_id| F
G[Message/Tool/etc] -->|Log| H(dify.* event)
G -->|Metric| I(dify.* counter/histogram)
```
## Configuration
The Enterprise OTEL exporter is configured via environment variables.
| Variable | Description | Default |
|----------|-------------|---------|
| `ENTERPRISE_ENABLED` | Master switch for all enterprise features. | `false` |
| `ENTERPRISE_TELEMETRY_ENABLED` | Master switch for enterprise telemetry. | `false` |
| `ENTERPRISE_OTLP_ENDPOINT` | OTLP collector endpoint (e.g., `http://otel-collector:4318`). | - |
| `ENTERPRISE_OTLP_HEADERS` | Custom headers for OTLP requests (e.g., `x-scope-orgid=tenant1`). | - |
| `ENTERPRISE_OTLP_PROTOCOL` | OTLP transport protocol (`http` or `grpc`). | `http` |
| `ENTERPRISE_OTLP_API_KEY` | Bearer token for authentication. | - |
| `ENTERPRISE_INCLUDE_CONTENT` | Whether to include sensitive content (inputs/outputs) in logs. | `true` |
| `ENTERPRISE_SERVICE_NAME` | Service name reported to OTEL. | `dify` |
| `ENTERPRISE_OTEL_SAMPLING_RATE` | Sampling rate for traces (0.0 to 1.0). Metrics are always 100%. | `1.0` |
## Correlation Model
Dify uses deterministic ID generation to ensure signals are correlated across different services and asynchronous tasks.
### ID Generation Rules
- `trace_id`: Derived from the correlation ID (workflow_run_id or node_execution_id for drafts) using `int(UUID(correlation_id))`
- `span_id`: Derived from the source ID using `SHA256(source_id)[:8]`
### Scenario A: Simple Workflow
A single workflow run with multiple nodes. All spans and logs share the same `trace_id` (derived from `workflow_run_id`).
```
trace_id = UUID(workflow_run_id)
├── [root span] dify.workflow.run (span_id = hash(workflow_run_id))
│ ├── [child] dify.node.execution - "Start" (span_id = hash(node_exec_id_1))
│ ├── [child] dify.node.execution - "LLM" (span_id = hash(node_exec_id_2))
│ └── [child] dify.node.execution - "End" (span_id = hash(node_exec_id_3))
```
### Scenario B: Nested Sub-Workflow
A workflow calling another workflow via a Tool or Sub-workflow node. The child workflow's spans are linked to the parent via `parent_span_id`. Both workflows share the same trace_id.
```
trace_id = UUID(outer_workflow_run_id) ← shared across both workflows
├── [root] dify.workflow.run (outer) (span_id = hash(outer_workflow_run_id))
│ ├── dify.node.execution - "Start Node"
│ ├── dify.node.execution - "Tool Node" (triggers sub-workflow)
│ │ └── [child] dify.workflow.run (inner) (span_id = hash(inner_workflow_run_id))
│ │ ├── dify.node.execution - "Inner Start"
│ │ └── dify.node.execution - "Inner End"
│ └── dify.node.execution - "End Node"
```
**Key attributes for nested workflows:**
- Inner workflow's `dify.parent.trace_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.node.execution_id` = tool node's `execution_id`
- Inner workflow's `dify.parent.workflow.run_id` = outer `workflow_run_id`
- Inner workflow's `dify.parent.app.id` = outer `app_id`
### Scenario C: Draft Node Execution
A single node run in isolation (debugger/preview mode). It creates its own trace where the node span is the root.
```
trace_id = UUID(node_execution_id) ← own trace, NOT part of any workflow
└── dify.node.execution.draft (span_id = hash(node_execution_id))
```
**Key difference:** Draft executions use `node_execution_id` as the correlation_id, so they are NOT children of any workflow trace.
## Content Gating
When `ENTERPRISE_INCLUDE_CONTENT` is set to `false`, sensitive content attributes (inputs, outputs, queries) are replaced with reference strings (e.g., `ref:workflow_run_id=...`) to prevent data leakage to the OTEL collector.
**Reference String Format:**
```
ref:{id_type}={uuid}
```
**Examples:**
```
ref:workflow_run_id=550e8400-e29b-41d4-a716-446655440000
ref:node_execution_id=660e8400-e29b-41d4-a716-446655440001
ref:message_id=770e8400-e29b-41d4-a716-446655440002
```
To retrieve actual content when gating is enabled, query the Dify database using the provided UUID.
## Reference
For a complete list of telemetry signals, attributes, and data structures, see [DATA_DICTIONARY.md](./DATA_DICTIONARY.md).

View File

View File

@@ -0,0 +1,73 @@
"""Telemetry gateway contracts and data structures.
This module defines the envelope format for telemetry events and the routing
configuration that determines how each event type is processed.
"""
from __future__ import annotations
from enum import StrEnum
from typing import Any
from pydantic import BaseModel, ConfigDict
class TelemetryCase(StrEnum):
"""Enumeration of all known telemetry event cases."""
WORKFLOW_RUN = "workflow_run"
NODE_EXECUTION = "node_execution"
DRAFT_NODE_EXECUTION = "draft_node_execution"
MESSAGE_RUN = "message_run"
TOOL_EXECUTION = "tool_execution"
MODERATION_CHECK = "moderation_check"
SUGGESTED_QUESTION = "suggested_question"
DATASET_RETRIEVAL = "dataset_retrieval"
GENERATE_NAME = "generate_name"
PROMPT_GENERATION = "prompt_generation"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
FEEDBACK_CREATED = "feedback_created"
class SignalType(StrEnum):
"""Signal routing type for telemetry cases."""
TRACE = "trace"
METRIC_LOG = "metric_log"
class CaseRoute(BaseModel):
"""Routing configuration for a telemetry case.
Attributes:
signal_type: The type of signal (trace or metric_log).
ce_eligible: Whether this case is eligible for community edition tracing.
"""
signal_type: SignalType
ce_eligible: bool
class TelemetryEnvelope(BaseModel):
"""Envelope for telemetry events.
Attributes:
case: The telemetry case type.
tenant_id: The tenant identifier.
event_id: Unique event identifier for deduplication.
payload: The main event payload (inline for small payloads,
empty when offloaded to storage via ``payload_ref``).
metadata: Optional metadata dictionary. When the gateway
offloads a large payload to object storage, this contains
``{"payload_ref": "<storage_key>"}``.
"""
model_config = ConfigDict(extra="forbid", use_enum_values=False)
case: TelemetryCase
tenant_id: str
event_id: str
payload: dict[str, Any]
metadata: dict[str, Any] | None = None

View File

@@ -0,0 +1,77 @@
from __future__ import annotations
from collections.abc import Mapping
from typing import Any
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from models.workflow import WorkflowNodeExecutionModel
def enqueue_draft_node_execution_trace(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
user_id: str,
) -> None:
node_data = _build_node_execution_data(
execution=execution,
outputs=outputs,
workflow_execution_id=workflow_execution_id,
)
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id=execution.tenant_id,
user_id=user_id,
app_id=execution.app_id,
),
payload={"node_execution_data": node_data},
)
)
def _build_node_execution_data(
*,
execution: WorkflowNodeExecutionModel,
outputs: Mapping[str, Any] | None,
workflow_execution_id: str | None,
) -> dict[str, Any]:
metadata = execution.execution_metadata_dict
node_outputs = outputs if outputs is not None else execution.outputs_dict
execution_id = workflow_execution_id or execution.workflow_run_id or execution.id
return {
"workflow_id": execution.workflow_id,
"workflow_execution_id": execution_id,
"tenant_id": execution.tenant_id,
"app_id": execution.app_id,
"node_execution_id": execution.id,
"node_id": execution.node_id,
"node_type": execution.node_type,
"title": execution.title,
"status": execution.status,
"error": execution.error,
"elapsed_time": execution.elapsed_time,
"index": execution.index,
"predecessor_node_id": execution.predecessor_node_id,
"created_at": execution.created_at,
"finished_at": execution.finished_at,
"total_tokens": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
"total_price": metadata.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
"currency": metadata.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
"tool_name": (metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
if isinstance(metadata.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
else None,
"iteration_id": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
"iteration_index": metadata.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
"loop_id": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
"loop_index": metadata.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
"parallel_id": metadata.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
"node_inputs": execution.inputs_dict,
"node_outputs": node_outputs,
"process_data": execution.process_data_dict,
}

View File

@@ -0,0 +1,938 @@
"""Enterprise trace handler — duck-typed, NOT a BaseTraceInstance subclass.
Invoked directly in the Celery task, not through OpsTraceManager dispatch.
Only requires a matching ``trace(trace_info)`` method signature.
Signal strategy:
- **Traces (spans)**: workflow run, node execution, draft node execution only.
- **Metrics + structured logs**: all other event types.
Token metric labels (unified structure):
All token metrics (dify.tokens.input, dify.tokens.output, dify.tokens.total) use the
same label set for consistent filtering and aggregation:
- tenant_id: Tenant identifier
- app_id: Application identifier
- operation_type: Source of token usage (workflow | node_execution | message | rule_generate | etc.)
- model_provider: LLM provider name (empty string if not applicable)
- model_name: LLM model name (empty string if not applicable)
- node_type: Workflow node type (empty string if not node_execution)
This unified structure allows filtering by operation_type to separate:
- Workflow-level aggregates (operation_type=workflow)
- Individual node executions (operation_type=node_execution)
- Direct message calls (operation_type=message)
- Prompt generation operations (operation_type=rule_generate, code_generate, etc.)
Without this, tokens are double-counted when querying totals (workflow totals include
node totals, since workflow.total_tokens is the sum of all node tokens).
"""
from __future__ import annotations
import json
import logging
from typing import Any, cast
from opentelemetry.util.types import AttributeValue
from core.ops.entities.trace_entity import (
BaseTraceInfo,
DatasetRetrievalTraceInfo,
DraftNodeExecutionTrace,
GenerateNameTraceInfo,
MessageTraceInfo,
ModerationTraceInfo,
OperationType,
PromptGenerationTraceInfo,
SuggestedQuestionTraceInfo,
ToolTraceInfo,
WorkflowNodeTraceInfo,
WorkflowTraceInfo,
)
from enterprise.telemetry.entities import (
EnterpriseTelemetryCounter,
EnterpriseTelemetryEvent,
EnterpriseTelemetryHistogram,
EnterpriseTelemetrySpan,
TokenMetricLabels,
)
from enterprise.telemetry.telemetry_log import emit_metric_only_event, emit_telemetry_log
logger = logging.getLogger(__name__)
class EnterpriseOtelTrace:
"""Duck-typed enterprise trace handler.
``*_trace`` methods emit spans (workflow/node only) or structured logs
(all other events), plus metrics at 100 % accuracy.
"""
def __init__(self) -> None:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if exporter is None:
raise RuntimeError("EnterpriseOtelTrace instantiated but exporter is not initialized")
self._exporter = exporter
def trace(self, trace_info: BaseTraceInfo) -> None:
if isinstance(trace_info, WorkflowTraceInfo):
self._workflow_trace(trace_info)
elif isinstance(trace_info, MessageTraceInfo):
self._message_trace(trace_info)
elif isinstance(trace_info, ToolTraceInfo):
self._tool_trace(trace_info)
elif isinstance(trace_info, DraftNodeExecutionTrace):
self._draft_node_execution_trace(trace_info)
elif isinstance(trace_info, WorkflowNodeTraceInfo):
self._node_execution_trace(trace_info)
elif isinstance(trace_info, ModerationTraceInfo):
self._moderation_trace(trace_info)
elif isinstance(trace_info, SuggestedQuestionTraceInfo):
self._suggested_question_trace(trace_info)
elif isinstance(trace_info, DatasetRetrievalTraceInfo):
self._dataset_retrieval_trace(trace_info)
elif isinstance(trace_info, GenerateNameTraceInfo):
self._generate_name_trace(trace_info)
elif isinstance(trace_info, PromptGenerationTraceInfo):
self._prompt_generation_trace(trace_info)
def _common_attrs(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
metadata = self._metadata(trace_info)
tenant_id, app_id, user_id = self._context_ids(trace_info, metadata)
return {
"dify.trace_id": trace_info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"dify.message.id": trace_info.message_id,
}
def _metadata(self, trace_info: BaseTraceInfo) -> dict[str, Any]:
return trace_info.metadata
def _context_ids(
self,
trace_info: BaseTraceInfo,
metadata: dict[str, Any],
) -> tuple[str | None, str | None, str | None]:
tenant_id = getattr(trace_info, "tenant_id", None) or metadata.get("tenant_id")
app_id = getattr(trace_info, "app_id", None) or metadata.get("app_id")
user_id = getattr(trace_info, "user_id", None) or metadata.get("user_id")
return tenant_id, app_id, user_id
def _labels(self, **values: AttributeValue) -> dict[str, AttributeValue]:
return dict(values)
def _safe_payload_value(self, value: Any) -> str | dict[str, Any] | list[object] | None:
if isinstance(value, str):
return value
if isinstance(value, dict):
return cast(dict[str, Any], value)
if isinstance(value, list):
items: list[object] = []
for item in cast(list[object], value):
items.append(item)
return items
return None
def _content_or_ref(self, value: Any, ref: str) -> Any:
if self._exporter.include_content:
return self._maybe_json(value)
return ref
def _maybe_json(self, value: Any) -> str | None:
if value is None:
return None
if isinstance(value, str):
return value
try:
return json.dumps(value, default=str)
except (TypeError, ValueError):
return str(value)
# ------------------------------------------------------------------
# SPAN-emitting handlers (workflow, node execution, draft node)
# ------------------------------------------------------------------
def _workflow_trace(self, info: WorkflowTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.workflow.status": info.workflow_run_status,
"dify.workflow.error": info.error,
"dify.workflow.elapsed_time": info.workflow_run_elapsed_time,
"dify.invoke_from": metadata.get("triggered_from"),
"dify.conversation.id": info.conversation_id,
"dify.message.id": info.message_id,
"dify.invoked_by": info.invoked_by,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.user.id": user_id,
}
trace_correlation_override, parent_span_id_source = info.resolved_parent_context
parent_ctx = metadata.get("parent_trace_context")
if isinstance(parent_ctx, dict):
parent_ctx_dict = cast(dict[str, Any], parent_ctx)
span_attrs["dify.parent.trace_id"] = parent_ctx_dict.get("trace_id")
span_attrs["dify.parent.node.execution_id"] = parent_ctx_dict.get("parent_node_execution_id")
span_attrs["dify.parent.workflow.run_id"] = parent_ctx_dict.get("parent_workflow_run_id")
span_attrs["dify.parent.app.id"] = parent_ctx_dict.get("parent_app_id")
self._exporter.export_span(
EnterpriseTelemetrySpan.WORKFLOW_RUN,
span_attrs,
correlation_id=info.workflow_run_id,
span_id_source=info.workflow_run_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
parent_span_id_source=parent_span_id_source,
)
# -- Companion log: ALL attrs (span + detail) for full picture --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.workflow.version": info.workflow_run_version,
}
)
ref = f"ref:workflow_run_id={info.workflow_run_id}"
log_attrs["dify.workflow.inputs"] = self._content_or_ref(info.workflow_run_inputs, ref)
log_attrs["dify.workflow.outputs"] = self._content_or_ref(info.workflow_run_outputs, ref)
log_attrs["dify.workflow.query"] = self._content_or_ref(info.query, ref)
emit_telemetry_log(
event_name=EnterpriseTelemetryEvent.WORKFLOW_RUN,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.workflow_run_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
invoke_from = metadata.get("triggered_from", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="workflow",
status=info.workflow_run_status,
invoke_from=invoke_from,
),
)
# Prefer wall-clock timestamps over the elapsed_time field: elapsed_time defaults
# to 0 in the DB and can be stale if the Celery write races with the trace task.
# start_time = workflow_run.created_at, end_time = workflow_run.finished_at.
if info.start_time and info.end_time:
workflow_duration = (info.end_time - info.start_time).total_seconds()
elif info.workflow_run_elapsed_time:
workflow_duration = float(info.workflow_run_elapsed_time)
else:
workflow_duration = 0.0
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.WORKFLOW_DURATION,
workflow_duration,
self._labels(
**labels,
status=info.workflow_run_status,
),
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="workflow",
),
)
def _node_execution_trace(self, info: WorkflowNodeTraceInfo) -> None:
self._emit_node_execution_trace(info, EnterpriseTelemetrySpan.NODE_EXECUTION, "node")
def _draft_node_execution_trace(self, info: DraftNodeExecutionTrace) -> None:
self._emit_node_execution_trace(
info,
EnterpriseTelemetrySpan.DRAFT_NODE_EXECUTION,
"draft_node",
correlation_id_override=info.node_execution_id,
trace_correlation_override_param=info.workflow_run_id,
)
def _emit_node_execution_trace(
self,
info: WorkflowNodeTraceInfo,
span_name: EnterpriseTelemetrySpan,
request_type: str,
correlation_id_override: str | None = None,
trace_correlation_override_param: str | None = None,
) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
# -- Span attrs: identity + structure + status + timing + gen_ai scalars --
span_attrs: dict[str, Any] = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.app_id": app_id,
"dify.workflow.id": info.workflow_id,
"dify.workflow.run_id": info.workflow_run_id,
"dify.message.id": info.message_id,
"dify.conversation.id": metadata.get("conversation_id"),
"dify.node.execution_id": info.node_execution_id,
"dify.node.id": info.node_id,
"dify.node.type": info.node_type,
"dify.node.title": info.title,
"dify.node.status": info.status,
"dify.node.error": info.error,
"dify.node.elapsed_time": info.elapsed_time,
"dify.node.index": info.index,
"dify.node.predecessor_node_id": info.predecessor_node_id,
"dify.node.iteration_id": info.iteration_id,
"dify.node.loop_id": info.loop_id,
"dify.node.parallel_id": info.parallel_id,
"dify.node.invoked_by": info.invoked_by,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"gen_ai.request.model": info.model_name,
"gen_ai.provider.name": info.model_provider,
"gen_ai.user.id": user_id,
}
resolved_override, _ = info.resolved_parent_context
trace_correlation_override = trace_correlation_override_param or resolved_override
effective_correlation_id = correlation_id_override or info.workflow_run_id
self._exporter.export_span(
span_name,
span_attrs,
correlation_id=effective_correlation_id,
span_id_source=info.node_execution_id,
start_time=info.start_time,
end_time=info.end_time,
trace_correlation_override=trace_correlation_override,
)
# -- Companion log: ALL attrs (span + detail) --
log_attrs: dict[str, Any] = {**span_attrs}
log_attrs.update(
{
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.invoke_from": metadata.get("invoke_from"),
"gen_ai.user.id": user_id,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.node.total_price": info.total_price,
"dify.node.currency": info.currency,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.tool.name": info.tool_name,
"dify.node.iteration_index": info.iteration_index,
"dify.node.loop_index": info.loop_index,
"dify.plugin.name": metadata.get("plugin_name"),
"dify.credential.name": metadata.get("credential_name"),
"dify.credential.id": metadata.get("credential_id"),
"dify.dataset.ids": self._maybe_json(metadata.get("dataset_ids")),
"dify.dataset.names": self._maybe_json(metadata.get("dataset_names")),
}
)
ref = f"ref:node_execution_id={info.node_execution_id}"
log_attrs["dify.node.inputs"] = self._content_or_ref(info.node_inputs, ref)
log_attrs["dify.node.outputs"] = self._content_or_ref(info.node_outputs, ref)
log_attrs["dify.node.process_data"] = self._content_or_ref(info.process_data, ref)
emit_telemetry_log(
event_name=span_name.value,
attributes=log_attrs,
signal="span_detail",
trace_id_source=info.workflow_run_id,
span_id_source=info.node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
# -- Metrics --
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
node_type=info.node_type,
model_provider=info.model_provider or "",
)
if info.total_tokens:
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.NODE_EXECUTION,
model_provider=info.model_provider or "",
model_name=info.model_name or "",
node_type=info.node_type,
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens is not None and info.prompt_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels
)
if info.completion_tokens is not None and info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type=request_type,
status=info.status,
model_name=info.model_name or "",
),
)
duration_labels = dict(labels)
duration_labels["model_name"] = info.model_name or ""
plugin_name = metadata.get("plugin_name")
if plugin_name and info.node_type in {"tool", "knowledge-retrieval"}:
duration_labels["plugin_name"] = plugin_name
self._exporter.record_histogram(EnterpriseTelemetryHistogram.NODE_DURATION, info.elapsed_time, duration_labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type=request_type,
model_name=info.model_name or "",
),
)
# ------------------------------------------------------------------
# METRIC-ONLY handlers (structured log + counters/histograms)
# ------------------------------------------------------------------
def _message_trace(self, info: MessageTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.invoke_from": metadata.get("from_source"),
"dify.conversation.id": metadata.get("conversation_id"),
"dify.conversation.mode": info.conversation_mode,
"gen_ai.provider.name": metadata.get("ls_provider"),
"gen_ai.request.model": metadata.get("ls_model_name"),
"gen_ai.usage.input_tokens": info.message_tokens,
"gen_ai.usage.output_tokens": info.answer_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.message.status": metadata.get("status"),
"dify.message.error": info.error,
"dify.message.from_source": metadata.get("from_source"),
"dify.message.from_end_user_id": metadata.get("from_end_user_id"),
"dify.message.from_account_id": metadata.get("from_account_id"),
"dify.streaming": info.is_streaming_request,
"dify.message.time_to_first_token": info.gen_ai_server_time_to_first_token,
"dify.message.streaming_duration": info.llm_streaming_time_to_generate,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.message.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.message.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MESSAGE_RUN,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=OperationType.MESSAGE,
model_provider=metadata.get("ls_provider") or "",
model_name=metadata.get("ls_model_name") or "",
node_type="",
).to_dict()
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.message_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.message_tokens, token_labels)
if info.answer_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.answer_tokens, token_labels)
invoke_from = metadata.get("from_source", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="message",
status=metadata.get("status", ""),
invoke_from=invoke_from,
),
)
if info.start_time and info.end_time:
duration = (info.end_time - info.start_time).total_seconds()
self._exporter.record_histogram(EnterpriseTelemetryHistogram.MESSAGE_DURATION, duration, labels)
if info.gen_ai_server_time_to_first_token is not None:
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.MESSAGE_TTFT, info.gen_ai_server_time_to_first_token, labels
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="message",
),
)
def _tool_trace(self, info: ToolTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"gen_ai.tool.name": info.tool_name,
"dify.tool.time_cost": info.time_cost,
"dify.tool.error": info.error,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:message_id={info.message_id}"
attrs["dify.tool.inputs"] = self._content_or_ref(info.tool_inputs, ref)
attrs["dify.tool.outputs"] = self._content_or_ref(info.tool_outputs, ref)
attrs["dify.tool.parameters"] = self._content_or_ref(info.tool_parameters, ref)
attrs["dify.tool.config"] = self._content_or_ref(info.tool_config, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.TOOL_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
tool_name=info.tool_name,
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="tool",
),
)
self._exporter.record_histogram(EnterpriseTelemetryHistogram.TOOL_DURATION, float(info.time_cost), labels)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="tool",
),
)
def _moderation_trace(self, info: ModerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"dify.moderation.flagged": info.flagged,
"dify.moderation.action": info.action,
"dify.moderation.preset_response": info.preset_response,
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.moderation.query"] = self._content_or_ref(
info.query,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.MODERATION_CHECK,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="moderation",
),
)
def _suggested_question_trace(self, info: SuggestedQuestionTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs.update(
{
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.suggested_question.status": info.status,
"dify.suggested_question.error": info.error,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_id,
"dify.suggested_question.count": len(info.suggested_question),
"dify.workflow.run_id": metadata.get("workflow_run_id"),
}
)
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
attrs["dify.suggested_question.questions"] = self._content_or_ref(
info.suggested_question,
f"ref:message_id={info.message_id}",
)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.SUGGESTED_QUESTION_GENERATION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="suggested_question",
model_provider=info.model_provider or "",
model_name=info.model_id or "",
),
)
def _dataset_retrieval_trace(self, info: DatasetRetrievalTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.dataset.error"] = info.error
attrs["dify.workflow.run_id"] = metadata.get("workflow_run_id")
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
docs: list[dict[str, Any]] = []
documents_any: Any = info.documents
documents_list: list[Any] = cast(list[Any], documents_any) if isinstance(documents_any, list) else []
for entry in documents_list:
if isinstance(entry, dict):
entry_dict: dict[str, Any] = cast(dict[str, Any], entry)
docs.append(entry_dict)
dataset_ids: list[str] = []
dataset_names: list[str] = []
structured_docs: list[dict[str, Any]] = []
for doc in docs:
meta_raw = doc.get("metadata")
meta: dict[str, Any] = cast(dict[str, Any], meta_raw) if isinstance(meta_raw, dict) else {}
did = meta.get("dataset_id")
dname = meta.get("dataset_name")
if did and did not in dataset_ids:
dataset_ids.append(did)
if dname and dname not in dataset_names:
dataset_names.append(dname)
structured_docs.append(
{
"dataset_id": did,
"document_id": meta.get("document_id"),
"segment_id": meta.get("segment_id"),
"score": meta.get("score"),
}
)
attrs["dify.dataset.ids"] = self._maybe_json(dataset_ids)
attrs["dify.dataset.names"] = self._maybe_json(dataset_names)
attrs["dify.retrieval.document_count"] = len(docs)
embedding_models_raw: Any = metadata.get("embedding_models")
embedding_models: dict[str, Any] = (
cast(dict[str, Any], embedding_models_raw) if isinstance(embedding_models_raw, dict) else {}
)
if embedding_models:
providers: list[str] = []
models: list[str] = []
for ds_info in embedding_models.values():
if isinstance(ds_info, dict):
ds_info_dict: dict[str, Any] = cast(dict[str, Any], ds_info)
p = ds_info_dict.get("embedding_model_provider", "")
m = ds_info_dict.get("embedding_model", "")
if p and p not in providers:
providers.append(p)
if m and m not in models:
models.append(m)
attrs["dify.dataset.embedding_providers"] = self._maybe_json(providers)
attrs["dify.dataset.embedding_models"] = self._maybe_json(models)
# Add rerank model to logs
rerank_provider = metadata.get("rerank_model_provider", "")
rerank_model = metadata.get("rerank_model_name", "")
if rerank_provider or rerank_model:
attrs["dify.retrieval.rerank_provider"] = rerank_provider
attrs["dify.retrieval.rerank_model"] = rerank_model
ref = f"ref:message_id={info.message_id}"
retrieval_inputs = self._safe_payload_value(info.inputs)
attrs["dify.retrieval.query"] = self._content_or_ref(retrieval_inputs, ref)
attrs["dify.dataset.documents"] = self._content_or_ref(structured_docs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.DATASET_RETRIEVAL,
attributes=attrs,
trace_id_source=metadata.get("workflow_run_id") or str(info.message_id) if info.message_id else None,
span_id_source=node_execution_id or (str(info.message_id) if info.message_id else None),
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="dataset_retrieval",
),
)
for did in dataset_ids:
# Get embedding model for this specific dataset
ds_embedding_info = embedding_models.get(did, {})
embedding_provider = ds_embedding_info.get("embedding_model_provider", "")
embedding_model = ds_embedding_info.get("embedding_model", "")
# Get rerank model (same for all datasets in this retrieval)
rerank_provider = metadata.get("rerank_model_provider", "")
rerank_model = metadata.get("rerank_model_name", "")
self._exporter.increment_counter(
EnterpriseTelemetryCounter.DATASET_RETRIEVALS,
1,
self._labels(
**labels,
dataset_id=did,
embedding_model_provider=embedding_provider,
embedding_model=embedding_model,
rerank_model_provider=rerank_provider,
rerank_model=rerank_model,
),
)
def _generate_name_trace(self, info: GenerateNameTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = self._common_attrs(info)
attrs["dify.conversation.id"] = info.conversation_id
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
ref = f"ref:conversation_id={info.conversation_id}"
inputs = self._safe_payload_value(info.inputs)
outputs = self._safe_payload_value(info.outputs)
attrs["dify.generate_name.inputs"] = self._content_or_ref(inputs, ref)
attrs["dify.generate_name.outputs"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.GENERATE_NAME_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
)
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="generate_name",
),
)
def _prompt_generation_trace(self, info: PromptGenerationTraceInfo) -> None:
metadata = self._metadata(info)
tenant_id, app_id, user_id = self._context_ids(info, metadata)
attrs = {
"dify.trace_id": info.resolved_trace_id,
"dify.tenant_id": tenant_id,
"dify.user.id": user_id,
"dify.app.id": app_id or "",
"dify.app.name": metadata.get("app_name"),
"dify.workspace.name": metadata.get("workspace_name"),
"dify.operation.type": info.operation_type,
"gen_ai.provider.name": info.model_provider,
"gen_ai.request.model": info.model_name,
"gen_ai.usage.input_tokens": info.prompt_tokens,
"gen_ai.usage.output_tokens": info.completion_tokens,
"gen_ai.usage.total_tokens": info.total_tokens,
"dify.prompt_generation.latency": info.latency,
"dify.prompt_generation.error": info.error,
}
node_execution_id = metadata.get("node_execution_id")
if node_execution_id:
attrs["dify.node.execution_id"] = node_execution_id
if info.total_price is not None:
attrs["dify.prompt_generation.total_price"] = info.total_price
attrs["dify.prompt_generation.currency"] = info.currency
ref = f"ref:trace_id={info.trace_id}"
outputs = self._safe_payload_value(info.outputs)
attrs["dify.prompt_generation.instruction"] = self._content_or_ref(info.instruction, ref)
attrs["dify.prompt_generation.output"] = self._content_or_ref(outputs, ref)
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.PROMPT_GENERATION_EXECUTION,
attributes=attrs,
trace_id_source=info.resolved_trace_id,
span_id_source=node_execution_id,
tenant_id=tenant_id,
user_id=user_id,
)
token_labels = TokenMetricLabels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
node_type="",
).to_dict()
labels = self._labels(
tenant_id=tenant_id or "",
app_id=app_id or "",
operation_type=info.operation_type,
model_provider=info.model_provider,
model_name=info.model_name,
)
self._exporter.increment_counter(EnterpriseTelemetryCounter.TOKENS, info.total_tokens, token_labels)
if info.prompt_tokens > 0:
self._exporter.increment_counter(EnterpriseTelemetryCounter.INPUT_TOKENS, info.prompt_tokens, token_labels)
if info.completion_tokens > 0:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.OUTPUT_TOKENS, info.completion_tokens, token_labels
)
status = "failed" if info.error else "success"
self._exporter.increment_counter(
EnterpriseTelemetryCounter.REQUESTS,
1,
self._labels(
**labels,
type="prompt_generation",
status=status,
),
)
self._exporter.record_histogram(
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION,
info.latency,
labels,
)
if info.error:
self._exporter.increment_counter(
EnterpriseTelemetryCounter.ERRORS,
1,
self._labels(
**labels,
type="prompt_generation",
),
)

View File

@@ -0,0 +1,121 @@
from enum import StrEnum
from typing import cast
from opentelemetry.util.types import AttributeValue
from pydantic import BaseModel, ConfigDict
class EnterpriseTelemetrySpan(StrEnum):
WORKFLOW_RUN = "dify.workflow.run"
NODE_EXECUTION = "dify.node.execution"
DRAFT_NODE_EXECUTION = "dify.node.execution.draft"
class EnterpriseTelemetryEvent(StrEnum):
"""Event names for enterprise telemetry logs."""
APP_CREATED = "dify.app.created"
APP_UPDATED = "dify.app.updated"
APP_DELETED = "dify.app.deleted"
FEEDBACK_CREATED = "dify.feedback.created"
WORKFLOW_RUN = "dify.workflow.run"
MESSAGE_RUN = "dify.message.run"
TOOL_EXECUTION = "dify.tool.execution"
MODERATION_CHECK = "dify.moderation.check"
SUGGESTED_QUESTION_GENERATION = "dify.suggested_question.generation"
DATASET_RETRIEVAL = "dify.dataset.retrieval"
GENERATE_NAME_EXECUTION = "dify.generate_name.execution"
PROMPT_GENERATION_EXECUTION = "dify.prompt_generation.execution"
REHYDRATION_FAILED = "dify.telemetry.rehydration_failed"
class EnterpriseTelemetryCounter(StrEnum):
TOKENS = "tokens"
INPUT_TOKENS = "input_tokens"
OUTPUT_TOKENS = "output_tokens"
REQUESTS = "requests"
ERRORS = "errors"
FEEDBACK = "feedback"
DATASET_RETRIEVALS = "dataset_retrievals"
APP_CREATED = "app_created"
APP_UPDATED = "app_updated"
APP_DELETED = "app_deleted"
class EnterpriseTelemetryHistogram(StrEnum):
WORKFLOW_DURATION = "workflow_duration"
NODE_DURATION = "node_duration"
MESSAGE_DURATION = "message_duration"
MESSAGE_TTFT = "message_ttft"
TOOL_DURATION = "tool_duration"
PROMPT_GENERATION_DURATION = "prompt_generation_duration"
class TokenMetricLabels(BaseModel):
"""Unified label structure for all dify.token.* metrics.
All token counters (dify.tokens.input, dify.tokens.output, dify.tokens.total) MUST
use this exact label set to ensure consistent filtering and aggregation across
different operation types.
Attributes:
tenant_id: Tenant identifier.
app_id: Application identifier.
operation_type: Source of token usage (workflow | node_execution | message |
rule_generate | code_generate | structured_output | instruction_modify).
model_provider: LLM provider name. Empty string if not applicable (e.g., workflow-level).
model_name: LLM model name. Empty string if not applicable (e.g., workflow-level).
node_type: Workflow node type. Empty string unless operation_type=node_execution.
Usage:
labels = TokenMetricLabels(
tenant_id="tenant-123",
app_id="app-456",
operation_type=OperationType.WORKFLOW,
model_provider="",
model_name="",
node_type="",
)
exporter.increment_counter(
EnterpriseTelemetryCounter.INPUT_TOKENS,
100,
labels.to_dict()
)
Design rationale:
Without this unified structure, tokens get double-counted when querying totals
because workflow.total_tokens is already the sum of all node tokens. The
operation_type label allows filtering to separate workflow-level aggregates from
node-level detail, while keeping the same label cardinality for consistent queries.
"""
tenant_id: str
app_id: str
operation_type: str
model_provider: str
model_name: str
node_type: str
model_config = ConfigDict(extra="forbid", frozen=True)
def to_dict(self) -> dict[str, AttributeValue]:
return cast(
dict[str, AttributeValue],
{
"tenant_id": self.tenant_id,
"app_id": self.app_id,
"operation_type": self.operation_type,
"model_provider": self.model_provider,
"model_name": self.model_name,
"node_type": self.node_type,
},
)
__all__ = [
"EnterpriseTelemetryCounter",
"EnterpriseTelemetryEvent",
"EnterpriseTelemetryHistogram",
"EnterpriseTelemetrySpan",
"TokenMetricLabels",
]

View File

@@ -0,0 +1,99 @@
"""Blinker signal handlers for enterprise telemetry.
Registered at import time via ``@signal.connect`` decorators.
Import must happen during ``ext_enterprise_telemetry.init_app()`` to
ensure handlers fire. Each handler delegates to ``core.telemetry.gateway``
which handles routing, EE-gating, and dispatch.
All handlers are best-effort: exceptions are caught and logged so that
telemetry failures never break user-facing operations.
"""
from __future__ import annotations
import logging
from events.app_event import app_was_created, app_was_deleted, app_was_updated
from events.feedback_event import feedback_was_created
logger = logging.getLogger(__name__)
__all__ = [
"_handle_app_created",
"_handle_app_deleted",
"_handle_app_updated",
"_handle_feedback_created",
]
@app_was_created.connect
def _handle_app_created(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_CREATED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={
"app_id": getattr(sender, "id", None),
"mode": getattr(sender, "mode", None),
},
)
except Exception:
logger.warning("Failed to emit app_created telemetry", exc_info=True)
@app_was_deleted.connect
def _handle_app_deleted(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_DELETED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={"app_id": getattr(sender, "id", None)},
)
except Exception:
logger.warning("Failed to emit app_deleted telemetry", exc_info=True)
@app_was_updated.connect
def _handle_app_updated(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
gateway_emit(
case=TelemetryCase.APP_UPDATED,
context={"tenant_id": str(getattr(sender, "tenant_id", "") or "")},
payload={"app_id": getattr(sender, "id", None)},
)
except Exception:
logger.warning("Failed to emit app_updated telemetry", exc_info=True)
@feedback_was_created.connect
def _handle_feedback_created(sender: object, **kwargs: object) -> None:
try:
from core.telemetry.gateway import emit as gateway_emit
from enterprise.telemetry.contracts import TelemetryCase
tenant_id = str(kwargs.get("tenant_id", "") or "")
gateway_emit(
case=TelemetryCase.FEEDBACK_CREATED,
context={"tenant_id": tenant_id},
payload={
"message_id": getattr(sender, "message_id", None),
"app_id": getattr(sender, "app_id", None),
"conversation_id": getattr(sender, "conversation_id", None),
"from_end_user_id": getattr(sender, "from_end_user_id", None),
"from_account_id": getattr(sender, "from_account_id", None),
"rating": getattr(sender, "rating", None),
"from_source": getattr(sender, "from_source", None),
"content": getattr(sender, "content", None),
},
)
except Exception:
logger.warning("Failed to emit feedback_created telemetry", exc_info=True)

View File

@@ -0,0 +1,284 @@
"""Enterprise OTEL exporter — shared by EnterpriseOtelTrace, event handlers, and direct instrumentation.
Uses dedicated TracerProvider and MeterProvider instances (configurable sampling,
independent from ext_otel.py infrastructure).
Initialized once during Flask extension init (single-threaded via ext_enterprise_telemetry.py).
Accessed via ``ext_enterprise_telemetry.get_enterprise_exporter()`` from any thread/process.
"""
import logging
import socket
import uuid
from datetime import datetime
from typing import Any, cast
from opentelemetry import trace
from opentelemetry.context import Context
from opentelemetry.exporter.otlp.proto.grpc.metric_exporter import OTLPMetricExporter as GRPCMetricExporter
from opentelemetry.exporter.otlp.proto.grpc.trace_exporter import OTLPSpanExporter as GRPCSpanExporter
from opentelemetry.exporter.otlp.proto.http.metric_exporter import OTLPMetricExporter as HTTPMetricExporter
from opentelemetry.exporter.otlp.proto.http.trace_exporter import OTLPSpanExporter as HTTPSpanExporter
from opentelemetry.sdk.metrics import MeterProvider
from opentelemetry.sdk.metrics.export import PeriodicExportingMetricReader
from opentelemetry.sdk.resources import Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import BatchSpanProcessor
from opentelemetry.sdk.trace.sampling import ParentBasedTraceIdRatio
from opentelemetry.semconv.resource import ResourceAttributes
from opentelemetry.trace import SpanContext, TraceFlags
from opentelemetry.util.types import Attributes, AttributeValue
from configs import dify_config
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryHistogram
from enterprise.telemetry.id_generator import (
CorrelationIdGenerator,
compute_deterministic_span_id,
set_correlation_id,
set_span_id_source,
)
logger = logging.getLogger(__name__)
def is_enterprise_telemetry_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def _parse_otlp_headers(raw: str) -> dict[str, str]:
"""Parse ``key=value,key2=value2`` into a dict."""
if not raw:
return {}
headers: dict[str, str] = {}
for pair in raw.split(","):
if "=" not in pair:
continue
k, v = pair.split("=", 1)
headers[k.strip()] = v.strip()
return headers
def _datetime_to_ns(dt: datetime) -> int:
"""Convert a datetime to nanoseconds since epoch (OTEL convention)."""
return int(dt.timestamp() * 1_000_000_000)
class _ExporterFactory:
def __init__(self, protocol: str, endpoint: str, headers: dict[str, str], insecure: bool):
self._protocol = protocol
self._endpoint = endpoint
self._headers = headers
self._grpc_headers = tuple(headers.items()) if headers else None
self._http_headers = headers or None
self._insecure = insecure
def create_trace_exporter(self) -> HTTPSpanExporter | GRPCSpanExporter:
if self._protocol == "grpc":
return GRPCSpanExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
trace_endpoint = f"{self._endpoint}/v1/traces" if self._endpoint else ""
return HTTPSpanExporter(endpoint=trace_endpoint or None, headers=self._http_headers)
def create_metric_exporter(self) -> HTTPMetricExporter | GRPCMetricExporter:
if self._protocol == "grpc":
return GRPCMetricExporter(
endpoint=self._endpoint or None,
headers=self._grpc_headers,
insecure=self._insecure,
)
metric_endpoint = f"{self._endpoint}/v1/metrics" if self._endpoint else ""
return HTTPMetricExporter(endpoint=metric_endpoint or None, headers=self._http_headers)
class EnterpriseExporter:
"""Shared OTEL exporter for all enterprise telemetry.
``export_span`` creates spans with optional real timestamps, deterministic
span/trace IDs, and cross-workflow parent linking.
``increment_counter`` / ``record_histogram`` emit OTEL metrics at 100% accuracy.
"""
def __init__(self, config: object) -> None:
endpoint: str = getattr(config, "ENTERPRISE_OTLP_ENDPOINT", "")
headers_raw: str = getattr(config, "ENTERPRISE_OTLP_HEADERS", "")
protocol: str = (getattr(config, "ENTERPRISE_OTLP_PROTOCOL", "http") or "http").lower()
service_name: str = getattr(config, "ENTERPRISE_SERVICE_NAME", "dify")
sampling_rate: float = getattr(config, "ENTERPRISE_OTEL_SAMPLING_RATE", 1.0)
self.include_content: bool = getattr(config, "ENTERPRISE_INCLUDE_CONTENT", True)
api_key: str = getattr(config, "ENTERPRISE_OTLP_API_KEY", "")
# Auto-detect TLS: https:// uses secure, everything else is insecure
insecure = not endpoint.startswith("https://")
resource = Resource(
attributes={
ResourceAttributes.SERVICE_NAME: service_name,
ResourceAttributes.HOST_NAME: socket.gethostname(),
}
)
sampler = ParentBasedTraceIdRatio(sampling_rate)
id_generator = CorrelationIdGenerator()
self._tracer_provider = TracerProvider(resource=resource, sampler=sampler, id_generator=id_generator)
headers = _parse_otlp_headers(headers_raw)
if api_key:
if "authorization" in headers:
logger.warning(
"ENTERPRISE_OTLP_API_KEY is set but ENTERPRISE_OTLP_HEADERS also contains "
"'authorization'; the API key will take precedence."
)
headers["authorization"] = f"Bearer {api_key}"
factory = _ExporterFactory(protocol, endpoint, headers, insecure=insecure)
trace_exporter = factory.create_trace_exporter()
self._tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
self._tracer = self._tracer_provider.get_tracer("dify.enterprise")
metric_exporter = factory.create_metric_exporter()
self._meter_provider = MeterProvider(
resource=resource,
metric_readers=[PeriodicExportingMetricReader(metric_exporter)],
)
meter = self._meter_provider.get_meter("dify.enterprise")
self._counters = {
EnterpriseTelemetryCounter.TOKENS: meter.create_counter("dify.tokens.total", unit="{token}"),
EnterpriseTelemetryCounter.INPUT_TOKENS: meter.create_counter("dify.tokens.input", unit="{token}"),
EnterpriseTelemetryCounter.OUTPUT_TOKENS: meter.create_counter("dify.tokens.output", unit="{token}"),
EnterpriseTelemetryCounter.REQUESTS: meter.create_counter("dify.requests.total", unit="{request}"),
EnterpriseTelemetryCounter.ERRORS: meter.create_counter("dify.errors.total", unit="{error}"),
EnterpriseTelemetryCounter.FEEDBACK: meter.create_counter("dify.feedback.total", unit="{feedback}"),
EnterpriseTelemetryCounter.DATASET_RETRIEVALS: meter.create_counter(
"dify.dataset.retrievals.total", unit="{retrieval}"
),
EnterpriseTelemetryCounter.APP_CREATED: meter.create_counter("dify.app.created.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_UPDATED: meter.create_counter("dify.app.updated.total", unit="{app}"),
EnterpriseTelemetryCounter.APP_DELETED: meter.create_counter("dify.app.deleted.total", unit="{app}"),
}
self._histograms = {
EnterpriseTelemetryHistogram.WORKFLOW_DURATION: meter.create_histogram("dify.workflow.duration", unit="s"),
EnterpriseTelemetryHistogram.NODE_DURATION: meter.create_histogram("dify.node.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_DURATION: meter.create_histogram("dify.message.duration", unit="s"),
EnterpriseTelemetryHistogram.MESSAGE_TTFT: meter.create_histogram(
"dify.message.time_to_first_token", unit="s"
),
EnterpriseTelemetryHistogram.TOOL_DURATION: meter.create_histogram("dify.tool.duration", unit="s"),
EnterpriseTelemetryHistogram.PROMPT_GENERATION_DURATION: meter.create_histogram(
"dify.prompt_generation.duration", unit="s"
),
}
def export_span(
self,
name: str,
attributes: dict[str, Any],
correlation_id: str | None = None,
span_id_source: str | None = None,
start_time: datetime | None = None,
end_time: datetime | None = None,
trace_correlation_override: str | None = None,
parent_span_id_source: str | None = None,
) -> None:
"""Export an OTEL span with optional deterministic IDs and real timestamps.
Args:
name: Span operation name.
attributes: Span attributes dict.
correlation_id: Source for trace_id derivation (groups spans in one trace).
span_id_source: Source for deterministic span_id (e.g. workflow_run_id or node_execution_id).
start_time: Real span start time. When None, uses current time.
end_time: Real span end time. When None, span ends immediately.
trace_correlation_override: Override trace_id source (for cross-workflow linking).
When set, trace_id is derived from this instead of ``correlation_id``.
parent_span_id_source: Override parent span_id source (for cross-workflow linking).
When set, parent span_id is derived from this value. When None and
``correlation_id`` is set, parent is the workflow root span.
"""
effective_trace_correlation = trace_correlation_override or correlation_id
set_correlation_id(effective_trace_correlation)
set_span_id_source(span_id_source)
try:
parent_context: Context | None = None
# A span is the "root" of its correlation group when span_id_source == correlation_id
# (i.e. a workflow root span). All other spans are children.
if parent_span_id_source:
# Cross-workflow linking: parent is an explicit span (e.g. tool node in outer workflow)
parent_span_id = compute_deterministic_span_id(parent_span_id_source)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation)) if effective_trace_correlation else 0
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for cross-workflow link: %s, span=%s",
effective_trace_correlation,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
elif correlation_id and correlation_id != span_id_source:
# Child span: parent is the correlation-group root (workflow root span)
parent_span_id = compute_deterministic_span_id(correlation_id)
try:
parent_trace_id = int(uuid.UUID(effective_trace_correlation or correlation_id))
except (ValueError, AttributeError):
logger.warning(
"Invalid trace correlation UUID for child span link: %s, span=%s",
effective_trace_correlation or correlation_id,
name,
)
parent_trace_id = 0
if parent_trace_id:
parent_span_context = SpanContext(
trace_id=parent_trace_id,
span_id=parent_span_id,
is_remote=True,
trace_flags=TraceFlags(TraceFlags.SAMPLED),
)
parent_context = trace.set_span_in_context(trace.NonRecordingSpan(parent_span_context))
span_start_time = _datetime_to_ns(start_time) if start_time is not None else None
span_end_on_exit = end_time is None
with self._tracer.start_as_current_span(
name,
context=parent_context,
start_time=span_start_time,
end_on_exit=span_end_on_exit,
) as span:
for key, value in attributes.items():
if value is not None:
span.set_attribute(key, value)
if end_time is not None:
span.end(end_time=_datetime_to_ns(end_time))
except Exception:
logger.exception("Failed to export span %s", name)
finally:
set_correlation_id(None)
set_span_id_source(None)
def increment_counter(
self, name: EnterpriseTelemetryCounter, value: int, labels: dict[str, AttributeValue]
) -> None:
counter = self._counters.get(name)
if counter:
counter.add(value, cast(Attributes, labels))
def record_histogram(
self, name: EnterpriseTelemetryHistogram, value: float, labels: dict[str, AttributeValue]
) -> None:
histogram = self._histograms.get(name)
if histogram:
histogram.record(value, cast(Attributes, labels))
def shutdown(self) -> None:
self._tracer_provider.shutdown()
self._meter_provider.shutdown()

View File

@@ -0,0 +1,76 @@
"""Custom OTEL ID Generator for correlation-based trace/span ID derivation.
Uses contextvars for thread-safe correlation_id -> trace_id mapping.
When a span_id_source is set, the span_id is derived deterministically
from that value, enabling any span to reference another as parent
without depending on span creation order.
"""
import random
import uuid
from contextvars import ContextVar
from typing import cast
from opentelemetry.sdk.trace.id_generator import IdGenerator
_correlation_id_context: ContextVar[str | None] = ContextVar("correlation_id", default=None)
_span_id_source_context: ContextVar[str | None] = ContextVar("span_id_source", default=None)
def set_correlation_id(correlation_id: str | None) -> None:
_correlation_id_context.set(correlation_id)
def get_correlation_id() -> str | None:
return _correlation_id_context.get()
def set_span_id_source(source_id: str | None) -> None:
"""Set the source for deterministic span_id generation.
When set, ``generate_span_id()`` derives the span_id from this value
(lower 64 bits of the UUID). Pass the ``workflow_run_id`` for workflow
root spans or ``node_execution_id`` for node spans.
"""
_span_id_source_context.set(source_id)
def compute_deterministic_span_id(source_id: str) -> int:
"""Derive a deterministic span_id from any UUID string.
Uses the lower 64 bits of the UUID, guaranteeing non-zero output
(OTEL requires span_id != 0).
"""
span_id = cast(int, uuid.UUID(source_id).int) & ((1 << 64) - 1)
return span_id if span_id != 0 else 1
class CorrelationIdGenerator(IdGenerator):
"""ID generator that derives trace_id and optionally span_id from context.
- trace_id: always derived from correlation_id (groups all spans in one trace)
- span_id: derived from span_id_source when set (enables deterministic
parent-child linking), otherwise random
"""
def generate_trace_id(self) -> int:
correlation_id = _correlation_id_context.get()
if correlation_id:
try:
return cast(int, uuid.UUID(correlation_id).int)
except (ValueError, AttributeError):
pass
return random.getrandbits(128)
def generate_span_id(self) -> int:
source = _span_id_source_context.get()
if source:
try:
return compute_deterministic_span_id(source)
except (ValueError, AttributeError):
pass
span_id = random.getrandbits(64)
while span_id == 0:
span_id = random.getrandbits(64)
return span_id

View File

@@ -0,0 +1,381 @@
"""Enterprise metric/log event handler.
This module processes metric and log telemetry events after they've been
dequeued from the enterprise_telemetry Celery queue. It handles case routing,
idempotency checking, and payload rehydration.
"""
from __future__ import annotations
import json
import logging
from typing import Any
from enterprise.telemetry.contracts import TelemetryCase, TelemetryEnvelope
from extensions.ext_redis import redis_client
from extensions.ext_storage import storage
logger = logging.getLogger(__name__)
class EnterpriseMetricHandler:
"""Handler for enterprise metric and log telemetry events.
Processes envelopes from the enterprise_telemetry queue, routing each
case to the appropriate handler method. Implements idempotency checking
and payload rehydration with fallback.
"""
def _increment_diagnostic_counter(self, counter_name: str, labels: dict[str, str] | None = None) -> None:
"""Increment a diagnostic counter for operational monitoring.
Args:
counter_name: Name of the counter (e.g., 'processed_total', 'deduped_total').
labels: Optional labels for the counter.
"""
try:
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
return
full_counter_name = f"enterprise_telemetry.handler.{counter_name}"
logger.debug(
"Diagnostic counter: %s, labels=%s",
full_counter_name,
labels or {},
)
except Exception:
logger.debug("Failed to increment diagnostic counter: %s", counter_name, exc_info=True)
def handle(self, envelope: TelemetryEnvelope) -> None:
"""Main entry point for processing telemetry envelopes.
Args:
envelope: The telemetry envelope to process.
"""
# Check for duplicate events
if self._is_duplicate(envelope):
logger.debug(
"Skipping duplicate event: tenant_id=%s, event_id=%s",
envelope.tenant_id,
envelope.event_id,
)
self._increment_diagnostic_counter("deduped_total")
return
# Route to appropriate handler based on case
case = envelope.case
if case == TelemetryCase.APP_CREATED:
self._on_app_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_created"})
elif case == TelemetryCase.APP_UPDATED:
self._on_app_updated(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_updated"})
elif case == TelemetryCase.APP_DELETED:
self._on_app_deleted(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "app_deleted"})
elif case == TelemetryCase.FEEDBACK_CREATED:
self._on_feedback_created(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "feedback_created"})
elif case == TelemetryCase.MESSAGE_RUN:
self._on_message_run(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "message_run"})
elif case == TelemetryCase.TOOL_EXECUTION:
self._on_tool_execution(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "tool_execution"})
elif case == TelemetryCase.MODERATION_CHECK:
self._on_moderation_check(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "moderation_check"})
elif case == TelemetryCase.SUGGESTED_QUESTION:
self._on_suggested_question(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "suggested_question"})
elif case == TelemetryCase.DATASET_RETRIEVAL:
self._on_dataset_retrieval(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "dataset_retrieval"})
elif case == TelemetryCase.GENERATE_NAME:
self._on_generate_name(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "generate_name"})
elif case == TelemetryCase.PROMPT_GENERATION:
self._on_prompt_generation(envelope)
self._increment_diagnostic_counter("processed_total", {"case": "prompt_generation"})
else:
logger.warning(
"Unknown telemetry case: %s (tenant_id=%s, event_id=%s)",
case,
envelope.tenant_id,
envelope.event_id,
)
def _is_duplicate(self, envelope: TelemetryEnvelope) -> bool:
"""Check if this event has already been processed.
Uses Redis with TTL for deduplication. Returns True if duplicate,
False if first time seeing this event.
Args:
envelope: The telemetry envelope to check.
Returns:
True if this event_id has been seen before, False otherwise.
"""
dedup_key = f"telemetry:dedup:{envelope.tenant_id}:{envelope.event_id}"
try:
# Atomic set-if-not-exists with 1h TTL
# Returns True if key was set (first time), None if already exists (duplicate)
was_set = redis_client.set(dedup_key, b"1", nx=True, ex=3600)
return was_set is None
except Exception:
# Fail open: if Redis is unavailable, process the event
# (prefer occasional duplicate over lost data)
logger.warning(
"Redis unavailable for deduplication check, processing event anyway: %s",
envelope.event_id,
exc_info=True,
)
return False
def _rehydrate(self, envelope: TelemetryEnvelope) -> dict[str, Any]:
"""Rehydrate payload from storage reference or inline data.
If the envelope payload is empty and metadata contains a
``payload_ref``, the full payload is loaded from object storage
(where the gateway wrote it as JSON). When both the inline
payload and storage resolution fail, a degraded-event marker
is emitted so the gap is observable.
Args:
envelope: The telemetry envelope containing payload data.
Returns:
The rehydrated payload dictionary, or ``{}`` on total failure.
"""
payload = envelope.payload
# Resolve from object storage when the gateway offloaded a large payload.
if not payload and envelope.metadata:
payload_ref = envelope.metadata.get("payload_ref")
if payload_ref:
try:
payload_bytes = storage.load(payload_ref)
payload = json.loads(payload_bytes.decode("utf-8"))
logger.debug("Loaded payload from storage: key=%s", payload_ref)
except Exception:
logger.warning(
"Failed to load payload from storage: key=%s, event_id=%s",
payload_ref,
envelope.event_id,
exc_info=True,
)
if not payload:
# Storage resolution failed or no data available — emit degraded event.
logger.error(
"Payload rehydration failed for event_id=%s, tenant_id=%s, case=%s",
envelope.event_id,
envelope.tenant_id,
envelope.case,
)
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.REHYDRATION_FAILED,
attributes={
"dify.tenant_id": envelope.tenant_id,
"dify.event_id": envelope.event_id,
"dify.case": envelope.case,
"rehydration_failed": True,
},
tenant_id=envelope.tenant_id,
)
self._increment_diagnostic_counter("rehydration_failed_total")
return {}
return payload
# Stub methods for each metric/log case
# These will be implemented in later tasks with actual emission logic
def _on_app_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle app created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app.mode": payload.get("mode"),
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_CREATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"mode": str(payload.get("mode", "")),
},
)
def _on_app_updated(self, envelope: TelemetryEnvelope) -> None:
"""Handle app updated event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_UPDATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_UPDATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_UPDATED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_app_deleted(self, envelope: TelemetryEnvelope) -> None:
"""Handle app deleted event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for APP_DELETED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
attrs = {
"dify.app.id": payload.get("app_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
}
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.APP_DELETED,
attributes=attrs,
tenant_id=envelope.tenant_id,
)
exporter.increment_counter(
EnterpriseTelemetryCounter.APP_DELETED,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
},
)
def _on_feedback_created(self, envelope: TelemetryEnvelope) -> None:
"""Handle feedback created event."""
from enterprise.telemetry.entities import EnterpriseTelemetryCounter, EnterpriseTelemetryEvent
from enterprise.telemetry.telemetry_log import emit_metric_only_event
from extensions.ext_enterprise_telemetry import get_enterprise_exporter
exporter = get_enterprise_exporter()
if not exporter:
logger.debug("No exporter available for FEEDBACK_CREATED: event_id=%s", envelope.event_id)
return
payload = self._rehydrate(envelope)
if not payload:
return
include_content = exporter.include_content
attrs: dict = {
"dify.message.id": payload.get("message_id"),
"dify.tenant_id": envelope.tenant_id,
"dify.event.id": envelope.event_id,
"dify.app_id": payload.get("app_id"),
"dify.conversation.id": payload.get("conversation_id"),
"gen_ai.user.id": payload.get("from_end_user_id") or payload.get("from_account_id"),
"dify.feedback.rating": payload.get("rating"),
"dify.feedback.from_source": payload.get("from_source"),
}
if include_content:
attrs["dify.feedback.content"] = payload.get("content")
user_id = payload.get("from_end_user_id") or payload.get("from_account_id")
emit_metric_only_event(
event_name=EnterpriseTelemetryEvent.FEEDBACK_CREATED,
attributes=attrs,
tenant_id=envelope.tenant_id,
user_id=str(user_id or ""),
)
exporter.increment_counter(
EnterpriseTelemetryCounter.FEEDBACK,
1,
{
"tenant_id": envelope.tenant_id,
"app_id": str(payload.get("app_id", "")),
"rating": str(payload.get("rating", "")),
},
)
def _on_message_run(self, envelope: TelemetryEnvelope) -> None:
"""Handle message run event (stub)."""
logger.debug("Processing MESSAGE_RUN: event_id=%s", envelope.event_id)
def _on_tool_execution(self, envelope: TelemetryEnvelope) -> None:
"""Handle tool execution event (stub)."""
logger.debug("Processing TOOL_EXECUTION: event_id=%s", envelope.event_id)
def _on_moderation_check(self, envelope: TelemetryEnvelope) -> None:
"""Handle moderation check event (stub)."""
logger.debug("Processing MODERATION_CHECK: event_id=%s", envelope.event_id)
def _on_suggested_question(self, envelope: TelemetryEnvelope) -> None:
"""Handle suggested question event (stub)."""
logger.debug("Processing SUGGESTED_QUESTION: event_id=%s", envelope.event_id)
def _on_dataset_retrieval(self, envelope: TelemetryEnvelope) -> None:
"""Handle dataset retrieval event (stub)."""
logger.debug("Processing DATASET_RETRIEVAL: event_id=%s", envelope.event_id)
def _on_generate_name(self, envelope: TelemetryEnvelope) -> None:
"""Handle generate name event (stub)."""
logger.debug("Processing GENERATE_NAME: event_id=%s", envelope.event_id)
def _on_prompt_generation(self, envelope: TelemetryEnvelope) -> None:
"""Handle prompt generation event (stub)."""
logger.debug("Processing PROMPT_GENERATION: event_id=%s", envelope.event_id)

View File

@@ -0,0 +1,122 @@
"""Structured-log emitter for enterprise telemetry events.
Emits structured JSON log lines correlated with OTEL traces via trace_id.
Picked up by ``StructuredJSONFormatter`` → stdout/Loki/Elastic.
"""
from __future__ import annotations
import logging
import uuid
from functools import lru_cache
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from enterprise.telemetry.entities import EnterpriseTelemetryEvent
logger = logging.getLogger("dify.telemetry")
@lru_cache(maxsize=4096)
def compute_trace_id_hex(uuid_str: str | None) -> str:
"""Convert a business UUID string to a 32-hex OTEL-compatible trace_id.
Returns empty string when *uuid_str* is ``None`` or invalid.
"""
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 32 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
return f"{uuid.UUID(normalized).int:032x}"
except (ValueError, AttributeError):
return ""
@lru_cache(maxsize=4096)
def compute_span_id_hex(uuid_str: str | None) -> str:
if not uuid_str:
return ""
normalized = uuid_str.strip().lower()
if len(normalized) == 16 and all(ch in "0123456789abcdef" for ch in normalized):
return normalized
try:
from enterprise.telemetry.id_generator import compute_deterministic_span_id
return f"{compute_deterministic_span_id(normalized):016x}"
except (ValueError, AttributeError):
return ""
def emit_telemetry_log(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
signal: str = "metric_only",
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
"""Emit a structured log line for a telemetry event.
Parameters
----------
event_name:
Canonical event name, e.g. ``"dify.workflow.run"``.
attributes:
All event-specific attributes (already built by the caller).
signal:
``"metric_only"`` for events with no span, ``"span_detail"``
for detail logs accompanying a slim span.
trace_id_source:
A UUID string (e.g. ``workflow_run_id``) used to derive a 32-hex
trace_id for cross-signal correlation.
tenant_id:
Tenant identifier (for the ``IdentityContextFilter``).
user_id:
User identifier (for the ``IdentityContextFilter``).
"""
if not logger.isEnabledFor(logging.INFO):
return
attrs = {
"dify.event.name": event_name,
"dify.event.signal": signal,
**attributes,
}
extra: dict[str, Any] = {"attributes": attrs}
trace_id_hex = compute_trace_id_hex(trace_id_source)
if trace_id_hex:
extra["trace_id"] = trace_id_hex
span_id_hex = compute_span_id_hex(span_id_source)
if span_id_hex:
extra["span_id"] = span_id_hex
if tenant_id:
extra["tenant_id"] = tenant_id
if user_id:
extra["user_id"] = user_id
logger.info("telemetry.%s", signal, extra=extra)
def emit_metric_only_event(
*,
event_name: str | EnterpriseTelemetryEvent,
attributes: dict[str, Any],
trace_id_source: str | None = None,
span_id_source: str | None = None,
tenant_id: str | None = None,
user_id: str | None = None,
) -> None:
emit_telemetry_log(
event_name=event_name,
attributes=attributes,
signal="metric_only",
trace_id_source=trace_id_source,
span_id_source=span_id_source,
tenant_id=tenant_id,
user_id=user_id,
)

View File

@@ -3,6 +3,12 @@ from blinker import signal
# sender: app
app_was_created = signal("app-was-created")
# sender: app
app_was_deleted = signal("app-was-deleted")
# sender: app
app_was_updated = signal("app-was-updated")
# sender: app, kwargs: app_model_config
app_model_config_was_updated = signal("app-model-config-was-updated")

View File

@@ -0,0 +1,4 @@
from blinker import signal
# sender: MessageFeedback, kwargs: tenant_id
feedback_was_created = signal("feedback-was-created")

View File

@@ -184,6 +184,8 @@ def init_app(app: DifyApp) -> Celery:
"task": "schedule.trigger_provider_refresh_task.trigger_provider_refresh",
"schedule": timedelta(minutes=dify_config.TRIGGER_PROVIDER_REFRESH_INTERVAL),
}
if dify_config.ENTERPRISE_TELEMETRY_ENABLED:
imports.append("tasks.enterprise_telemetry_task")
celery_app.conf.update(beat_schedule=beat_schedule, imports=imports)
return celery_app

View File

@@ -0,0 +1,50 @@
"""Flask extension for enterprise telemetry lifecycle management.
Initializes the EnterpriseExporter singleton during ``create_app()``
(single-threaded), registers blinker event handlers, and hooks atexit
for graceful shutdown.
Skipped entirely when ``ENTERPRISE_ENABLED`` and ``ENTERPRISE_TELEMETRY_ENABLED``
are false (``is_enabled()`` gate).
"""
from __future__ import annotations
import atexit
import logging
from typing import TYPE_CHECKING
from configs import dify_config
if TYPE_CHECKING:
from dify_app import DifyApp
from enterprise.telemetry.exporter import EnterpriseExporter
logger = logging.getLogger(__name__)
_exporter: EnterpriseExporter | None = None
def is_enabled() -> bool:
return bool(dify_config.ENTERPRISE_ENABLED and dify_config.ENTERPRISE_TELEMETRY_ENABLED)
def init_app(app: DifyApp) -> None:
global _exporter
if not is_enabled():
return
from enterprise.telemetry.exporter import EnterpriseExporter
_exporter = EnterpriseExporter(dify_config)
atexit.register(_exporter.shutdown)
# Import to trigger @signal.connect decorator registration
import enterprise.telemetry.event_handlers # noqa: F401 # type: ignore[reportUnusedImport]
logger.info("Enterprise telemetry initialized")
def get_enterprise_exporter() -> EnterpriseExporter | None:
return _exporter

View File

@@ -59,16 +59,24 @@ def init_app(app: DifyApp):
protocol = (dify_config.OTEL_EXPORTER_OTLP_PROTOCOL or "").lower()
if dify_config.OTEL_EXPORTER_TYPE == "otlp":
if protocol == "grpc":
# Auto-detect TLS: https:// uses secure, everything else is insecure
endpoint = dify_config.OTLP_BASE_ENDPOINT
insecure = not endpoint.startswith("https://")
exporter = GRPCSpanExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
endpoint=endpoint,
# Header field names must consist of lowercase letters, check RFC7540
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
headers=(
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
),
insecure=insecure,
)
metric_exporter = GRPCMetricExporter(
endpoint=dify_config.OTLP_BASE_ENDPOINT,
headers=(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),),
insecure=True,
endpoint=endpoint,
headers=(
(("authorization", f"Bearer {dify_config.OTLP_API_KEY}"),) if dify_config.OTLP_API_KEY else None
),
insecure=insecure,
)
else:
headers = {"Authorization": f"Bearer {dify_config.OTLP_API_KEY}"} if dify_config.OTLP_API_KEY else None

View File

@@ -5,7 +5,7 @@ This module provides parsers that extract node-specific metadata and set
OpenTelemetry span attributes according to semantic conventions.
"""
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, NodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.parser.llm import LLMNodeOTelParser
from extensions.otel.parser.retrieval import RetrievalNodeOTelParser
from extensions.otel.parser.tool import ToolNodeOTelParser
@@ -17,4 +17,5 @@ __all__ = [
"RetrievalNodeOTelParser",
"ToolNodeOTelParser",
"safe_json_dumps",
"should_include_content",
]

View File

@@ -1,5 +1,10 @@
"""
Base parser interface and utilities for OpenTelemetry node parsers.
Content gating: ``should_include_content()`` controls whether content-bearing
span attributes (inputs, outputs, prompts, completions, documents) are written.
Gate is only active in EE (``ENTERPRISE_ENABLED=True``) when
``ENTERPRISE_INCLUDE_CONTENT=False``; CE behaviour is unchanged.
"""
import json
@@ -9,6 +14,7 @@ from opentelemetry.trace import Span
from opentelemetry.trace.status import Status, StatusCode
from pydantic import BaseModel
from configs import dify_config
from core.file.models import File
from core.variables import Segment
from core.workflow.enums import NodeType
@@ -17,6 +23,17 @@ from core.workflow.nodes.base.node import Node
from extensions.otel.semconv.gen_ai import ChainAttributes, GenAIAttributes
def should_include_content() -> bool:
"""Return True if content should be written to spans.
CE (ENTERPRISE_ENABLED=False): always True — no behaviour change.
EE: follows ENTERPRISE_INCLUDE_CONTENT (default True).
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return dify_config.ENTERPRISE_INCLUDE_CONTENT
def safe_json_dumps(obj: Any, ensure_ascii: bool = False) -> str:
"""
Safely serialize objects to JSON, handling non-serializable types.
@@ -105,10 +122,11 @@ class DefaultNodeOTelParser:
# Extract inputs and outputs from result_event
if result_event and result_event.node_run_result:
node_run_result = result_event.node_run_result
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if should_include_content():
if node_run_result.inputs:
span.set_attribute(ChainAttributes.INPUT_VALUE, safe_json_dumps(node_run_result.inputs))
if node_run_result.outputs:
span.set_attribute(ChainAttributes.OUTPUT_VALUE, safe_json_dumps(node_run_result.outputs))
if error:
span.record_exception(error)

View File

@@ -10,7 +10,7 @@ from opentelemetry.trace import Span
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.semconv.gen_ai import LLMAttributes
logger = logging.getLogger(__name__)
@@ -132,24 +132,19 @@ class LLMNodeOTelParser:
span.set_attribute(LLMAttributes.USAGE_OUTPUT_TOKENS, completion_tokens)
span.set_attribute(LLMAttributes.USAGE_TOTAL_TOKENS, total_tokens)
# Prompts and completion
prompts = process_data.get("prompts", [])
if prompts:
prompts_json = safe_json_dumps(prompts)
span.set_attribute(LLMAttributes.PROMPT, prompts_json)
# Prompts and completion — gated by content policy
if should_include_content():
prompts = process_data.get("prompts", [])
if prompts:
prompts_json = safe_json_dumps(prompts)
span.set_attribute(LLMAttributes.PROMPT, prompts_json)
text_output = str(outputs.get("text", ""))
if text_output:
span.set_attribute(LLMAttributes.COMPLETION, text_output)
text_output = str(outputs.get("text", ""))
if text_output:
span.set_attribute(LLMAttributes.COMPLETION, text_output)
# Finish reason
finish_reason = outputs.get("finish_reason") or ""
if finish_reason:
span.set_attribute(LLMAttributes.RESPONSE_FINISH_REASON, finish_reason)
# Structured input/output messages
gen_ai_input_message = _format_input_messages(process_data)
gen_ai_output_message = _format_output_messages(outputs)
span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)
# Structured input/output messages
gen_ai_input_message = _format_input_messages(process_data)
gen_ai_output_message = _format_output_messages(outputs)
span.set_attribute(LLMAttributes.INPUT_MESSAGE, gen_ai_input_message)
span.set_attribute(LLMAttributes.OUTPUT_MESSAGE, gen_ai_output_message)

View File

@@ -11,7 +11,7 @@ from opentelemetry.trace import Span
from core.variables import Segment
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.semconv.gen_ai import RetrieverAttributes
logger = logging.getLogger(__name__)
@@ -83,23 +83,21 @@ class RetrievalNodeOTelParser:
inputs = node_run_result.inputs or {}
outputs = node_run_result.outputs or {}
# Extract query from inputs
query = str(inputs.get("query", "")) if inputs else ""
if query:
span.set_attribute(RetrieverAttributes.QUERY, query)
# Query and documents — gated by content policy
if should_include_content():
query = str(inputs.get("query", "")) if inputs else ""
if query:
span.set_attribute(RetrieverAttributes.QUERY, query)
# Extract and format retrieval documents from outputs
result_value = outputs.get("result") if outputs else None
retrieval_documents: list[Any] = []
if result_value:
value_to_check = result_value
if isinstance(result_value, Segment):
value_to_check = result_value.value
if isinstance(value_to_check, (list, Sequence)):
retrieval_documents = list(value_to_check)
if retrieval_documents:
semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)
result_value = outputs.get("result") if outputs else None
retrieval_documents: list[Any] = []
if result_value:
value_to_check = result_value
if isinstance(result_value, Segment):
value_to_check = result_value.value
if isinstance(value_to_check, (list, Sequence)):
retrieval_documents = list(value_to_check)
if retrieval_documents:
semantic_retrieval_documents = _format_retrieval_documents(retrieval_documents)
semantic_retrieval_documents_json = safe_json_dumps(semantic_retrieval_documents)
span.set_attribute(RetrieverAttributes.DOCUMENT, semantic_retrieval_documents_json)

View File

@@ -8,7 +8,7 @@ from core.workflow.enums import WorkflowNodeExecutionMetadataKey
from core.workflow.graph_events import GraphNodeEventBase
from core.workflow.nodes.base.node import Node
from core.workflow.nodes.tool.entities import ToolNodeData
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps
from extensions.otel.parser.base import DefaultNodeOTelParser, safe_json_dumps, should_include_content
from extensions.otel.semconv.gen_ai import ToolAttributes
@@ -40,8 +40,14 @@ class ToolNodeOTelParser:
if tool_info:
span.set_attribute(ToolAttributes.TOOL_DESCRIPTION, safe_json_dumps(tool_info))
if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
span.set_attribute(ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs))
# Tool call arguments and result — gated by content policy
if should_include_content():
if result_event and result_event.node_run_result and result_event.node_run_result.inputs:
span.set_attribute(
ToolAttributes.TOOL_CALL_ARGUMENTS, safe_json_dumps(result_event.node_run_result.inputs)
)
if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
span.set_attribute(ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs))
if result_event and result_event.node_run_result and result_event.node_run_result.outputs:
span.set_attribute(
ToolAttributes.TOOL_CALL_RESULT, safe_json_dumps(result_event.node_run_result.outputs)
)

View File

@@ -21,3 +21,15 @@ class DifySpanAttributes:
INVOKE_FROM = "dify.invoke_from"
"""Invocation source, e.g. SERVICE_API, WEB_APP, DEBUGGER."""
INVOKED_BY = "dify.invoked_by"
"""Invoked by, e.g. end_user, account, user."""
USAGE_INPUT_TOKENS = "gen_ai.usage.input_tokens"
"""Number of input tokens (prompt tokens) used."""
USAGE_OUTPUT_TOKENS = "gen_ai.usage.output_tokens"
"""Number of output tokens (completion tokens) generated."""
USAGE_TOTAL_TOKENS = "gen_ai.usage.total_tokens"
"""Total number of tokens used."""

View File

@@ -0,0 +1,213 @@
"""
DB migration Redis lock with heartbeat renewal.
This is intentionally migration-specific. Background renewal is a trade-off that makes sense
for unbounded, blocking operations like DB migrations (DDL/DML) where the main thread cannot
periodically refresh the lock TTL.
Do NOT use this as a general-purpose lock primitive for normal application code. Prefer explicit
lock lifecycle management (e.g. redis-py Lock context manager + `extend()` / `reacquire()` from
the same thread) when execution flow is under control.
"""
from __future__ import annotations
import logging
import threading
from typing import Any
from redis.exceptions import LockNotOwnedError, RedisError
logger = logging.getLogger(__name__)
MIN_RENEW_INTERVAL_SECONDS = 0.1
DEFAULT_RENEW_INTERVAL_DIVISOR = 3
MIN_JOIN_TIMEOUT_SECONDS = 0.5
MAX_JOIN_TIMEOUT_SECONDS = 5.0
JOIN_TIMEOUT_MULTIPLIER = 2.0
class DbMigrationAutoRenewLock:
"""
Redis lock wrapper that automatically renews TTL while held (migration-only).
Notes:
- We force `thread_local=False` when creating the underlying redis-py lock, because the
lock token must be accessible from the heartbeat thread for `reacquire()` to work.
- `release_safely()` is best-effort: it never raises, so it won't mask the caller's
primary error/exit code.
"""
_redis_client: Any
_name: str
_ttl_seconds: float
_renew_interval_seconds: float
_log_context: str | None
_logger: logging.Logger
_lock: Any
_stop_event: threading.Event | None
_thread: threading.Thread | None
_acquired: bool
def __init__(
self,
redis_client: Any,
name: str,
ttl_seconds: float = 60,
renew_interval_seconds: float | None = None,
*,
logger: logging.Logger | None = None,
log_context: str | None = None,
) -> None:
self._redis_client = redis_client
self._name = name
self._ttl_seconds = float(ttl_seconds)
self._renew_interval_seconds = (
float(renew_interval_seconds)
if renew_interval_seconds is not None
else max(MIN_RENEW_INTERVAL_SECONDS, self._ttl_seconds / DEFAULT_RENEW_INTERVAL_DIVISOR)
)
self._logger = logger or logging.getLogger(__name__)
self._log_context = log_context
self._lock = None
self._stop_event = None
self._thread = None
self._acquired = False
@property
def name(self) -> str:
return self._name
def acquire(self, *args: Any, **kwargs: Any) -> bool:
"""
Acquire the lock and start heartbeat renewal on success.
Accepts the same args/kwargs as redis-py `Lock.acquire()`.
"""
# Prevent accidental double-acquire which could leave the previous heartbeat thread running.
if self._acquired:
raise RuntimeError("DB migration lock is already acquired; call release_safely() before acquiring again.")
# Reuse the lock object if we already created one.
if self._lock is None:
self._lock = self._redis_client.lock(
name=self._name,
timeout=self._ttl_seconds,
thread_local=False,
)
acquired = bool(self._lock.acquire(*args, **kwargs))
self._acquired = acquired
if acquired:
self._start_heartbeat()
return acquired
def owned(self) -> bool:
if self._lock is None:
return False
try:
return bool(self._lock.owned())
except Exception:
# Ownership checks are best-effort and must not break callers.
return False
def _start_heartbeat(self) -> None:
if self._lock is None:
return
if self._stop_event is not None:
return
self._stop_event = threading.Event()
self._thread = threading.Thread(
target=self._heartbeat_loop,
args=(self._lock, self._stop_event),
daemon=True,
name=f"DbMigrationAutoRenewLock({self._name})",
)
self._thread.start()
def _heartbeat_loop(self, lock: Any, stop_event: threading.Event) -> None:
while not stop_event.wait(self._renew_interval_seconds):
try:
lock.reacquire()
except LockNotOwnedError:
self._logger.warning(
"DB migration lock is no longer owned during heartbeat; stop renewing. log_context=%s",
self._log_context,
exc_info=True,
)
return
except RedisError:
self._logger.warning(
"Failed to renew DB migration lock due to Redis error; will retry. log_context=%s",
self._log_context,
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while renewing DB migration lock; will retry. log_context=%s",
self._log_context,
exc_info=True,
)
def release_safely(self, *, status: str | None = None) -> None:
"""
Stop heartbeat and release lock. Never raises.
Args:
status: Optional caller-provided status (e.g. 'successful'/'failed') to add context to logs.
"""
lock = self._lock
if lock is None:
return
self._stop_heartbeat()
# Lock release errors should never mask the real error/exit code.
try:
lock.release()
except LockNotOwnedError:
self._logger.warning(
"DB migration lock not owned on release; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
except RedisError:
self._logger.warning(
"Failed to release DB migration lock due to Redis error; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
except Exception:
self._logger.warning(
"Unexpected error while releasing DB migration lock; ignoring. status=%s log_context=%s",
status,
self._log_context,
exc_info=True,
)
finally:
self._acquired = False
self._lock = None
def _stop_heartbeat(self) -> None:
if self._stop_event is None:
return
self._stop_event.set()
if self._thread is not None:
# Best-effort join: if Redis calls are blocked, the daemon thread may remain alive.
join_timeout_seconds = max(
MIN_JOIN_TIMEOUT_SECONDS,
min(MAX_JOIN_TIMEOUT_SECONDS, self._renew_interval_seconds * JOIN_TIMEOUT_MULTIPLIER),
)
self._thread.join(timeout=join_timeout_seconds)
if self._thread.is_alive():
self._logger.warning(
"DB migration lock heartbeat thread did not stop within %.2fs; ignoring. log_context=%s",
join_timeout_seconds,
self._log_context,
)
self._stop_event = None
self._thread = None

View File

@@ -8,7 +8,6 @@ Create Date: 2025-12-25 10:39:15.139304
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = '7df29de0f6be'
@@ -20,7 +19,7 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('tenant_credit_pools',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('pool_type', sa.String(length=40), server_default='trial', nullable=False),
sa.Column('quota_limit', sa.BigInteger(), nullable=False),

View File

@@ -8,7 +8,6 @@ Create Date: 2026-01-017 11:10:18.079355
from alembic import op
import models as models
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql
# revision identifiers, used by Alembic.
revision = 'f9f6d18a37f9'
@@ -20,7 +19,7 @@ depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('account_trial_app_records',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('account_id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('count', sa.Integer(), nullable=False),
@@ -33,17 +32,17 @@ def upgrade():
batch_op.create_index('account_trial_app_record_app_id_idx', ['app_id'], unique=False)
op.create_table('exporle_banners',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('content', sa.JSON(), nullable=False),
sa.Column('link', sa.String(length=255), nullable=False),
sa.Column('sort', sa.Integer(), nullable=False),
sa.Column('status', sa.String(length=255), server_default=sa.text("'enabled'::character varying"), nullable=False),
sa.Column('status', sa.String(length=255), server_default='enabled', nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),
sa.Column('language', sa.String(length=255), server_default=sa.text("'en-US'::character varying"), nullable=False),
sa.Column('language', sa.String(length=255), server_default='en-US', nullable=False),
sa.PrimaryKeyConstraint('id', name='exporler_banner_pkey')
)
op.create_table('trial_apps',
sa.Column('id', models.types.StringUUID(), server_default=sa.text('uuid_generate_v4()'), nullable=False),
sa.Column('id', models.types.StringUUID(), nullable=False),
sa.Column('app_id', models.types.StringUUID(), nullable=False),
sa.Column('tenant_id', models.types.StringUUID(), nullable=False),
sa.Column('created_at', sa.DateTime(), server_default=sa.text('CURRENT_TIMESTAMP'), nullable=False),

View File

@@ -620,7 +620,7 @@ class TrialApp(Base):
sa.UniqueConstraint("app_id", name="unique_trail_app_id"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
app_id = mapped_column(StringUUID, nullable=False)
tenant_id = mapped_column(StringUUID, nullable=False)
created_at = mapped_column(sa.DateTime, nullable=False, server_default=func.current_timestamp())
@@ -640,7 +640,7 @@ class AccountTrialAppRecord(Base):
sa.Index("account_trial_app_record_app_id_idx", "app_id"),
sa.UniqueConstraint("account_id", "app_id", name="unique_account_trial_app_record"),
)
id = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"))
id = mapped_column(StringUUID, default=lambda: str(uuid4()))
account_id = mapped_column(StringUUID, nullable=False)
app_id = mapped_column(StringUUID, nullable=False)
count = mapped_column(sa.Integer, nullable=False, default=0)
@@ -660,19 +660,15 @@ class AccountTrialAppRecord(Base):
class ExporleBanner(TypeBase):
__tablename__ = "exporle_banners"
__table_args__ = (sa.PrimaryKeyConstraint("id", name="exporler_banner_pkey"),)
id: Mapped[str] = mapped_column(StringUUID, server_default=sa.text("uuid_generate_v4()"), init=False)
id: Mapped[str] = mapped_column(StringUUID, default=lambda: str(uuid4()), init=False)
content: Mapped[dict[str, Any]] = mapped_column(sa.JSON, nullable=False)
link: Mapped[str] = mapped_column(String(255), nullable=False)
sort: Mapped[int] = mapped_column(sa.Integer, nullable=False)
status: Mapped[str] = mapped_column(
sa.String(255), nullable=False, server_default=sa.text("'enabled'::character varying"), default="enabled"
)
status: Mapped[str] = mapped_column(sa.String(255), nullable=False, server_default="enabled", default="enabled")
created_at: Mapped[datetime] = mapped_column(
sa.DateTime, nullable=False, server_default=func.current_timestamp(), init=False
)
language: Mapped[str] = mapped_column(
String(255), nullable=False, server_default=sa.text("'en-US'::character varying"), default="en-US"
)
language: Mapped[str] = mapped_column(String(255), nullable=False, server_default="en-US", default="en-US")
class OAuthProviderApp(TypeBase):
@@ -2166,7 +2162,7 @@ class TenantCreditPool(TypeBase):
sa.Index("tenant_credit_pool_pool_type_idx", "pool_type"),
)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, server_default=text("uuid_generate_v4()"), init=False)
id: Mapped[str] = mapped_column(StringUUID, primary_key=True, default=lambda: str(uuid4()), init=False)
tenant_id: Mapped[str] = mapped_column(StringUUID, nullable=False)
pool_type: Mapped[str] = mapped_column(String(40), nullable=False, default="trial", server_default="trial")
quota_limit: Mapped[int] = mapped_column(BigInteger, nullable=False, default=0)

View File

@@ -1,6 +1,6 @@
[project]
name = "dify-api"
version = "1.12.0"
version = "1.12.1"
requires-python = ">=3.11,<3.13"
dependencies = [
@@ -22,14 +22,14 @@ dependencies = [
"flask-sqlalchemy~=3.1.1",
"gevent~=25.9.1",
"gmpy2~=2.2.1",
"google-api-core==2.18.0",
"google-api-core>=2.19.1",
"google-api-python-client==2.90.0",
"google-auth==2.29.0",
"google-auth>=2.47.0",
"google-auth-httplib2==0.2.0",
"google-cloud-aiplatform==1.49.0",
"googleapis-common-protos==1.63.0",
"google-cloud-aiplatform>=1.123.0",
"googleapis-common-protos>=1.65.0",
"gunicorn~=23.0.0",
"httpx[socks]~=0.27.0",
"httpx[socks]~=0.28.0",
"jieba==0.42.1",
"json-repair>=0.55.1",
"jsonschema>=4.25.1",
@@ -41,26 +41,23 @@ dependencies = [
"openpyxl~=3.1.5",
"opik~=1.8.72",
"litellm==1.77.1", # Pinned to avoid madoka dependency issue
"opentelemetry-api==1.27.0",
"opentelemetry-distro==0.48b0",
"opentelemetry-exporter-otlp==1.27.0",
"opentelemetry-exporter-otlp-proto-common==1.27.0",
"opentelemetry-exporter-otlp-proto-grpc==1.27.0",
"opentelemetry-exporter-otlp-proto-http==1.27.0",
"opentelemetry-instrumentation==0.48b0",
"opentelemetry-instrumentation-celery==0.48b0",
"opentelemetry-instrumentation-flask==0.48b0",
"opentelemetry-instrumentation-httpx==0.48b0",
"opentelemetry-instrumentation-redis==0.48b0",
"opentelemetry-instrumentation-httpx==0.48b0",
"opentelemetry-instrumentation-sqlalchemy==0.48b0",
"opentelemetry-propagator-b3==1.27.0",
# opentelemetry-proto1.28.0 depends on protobuf (>=5.0,<6.0),
# which is conflict with googleapis-common-protos (1.63.0)
"opentelemetry-proto==1.27.0",
"opentelemetry-sdk==1.27.0",
"opentelemetry-semantic-conventions==0.48b0",
"opentelemetry-util-http==0.48b0",
"opentelemetry-api==1.28.0",
"opentelemetry-distro==0.49b0",
"opentelemetry-exporter-otlp==1.28.0",
"opentelemetry-exporter-otlp-proto-common==1.28.0",
"opentelemetry-exporter-otlp-proto-grpc==1.28.0",
"opentelemetry-exporter-otlp-proto-http==1.28.0",
"opentelemetry-instrumentation==0.49b0",
"opentelemetry-instrumentation-celery==0.49b0",
"opentelemetry-instrumentation-flask==0.49b0",
"opentelemetry-instrumentation-httpx==0.49b0",
"opentelemetry-instrumentation-redis==0.49b0",
"opentelemetry-instrumentation-sqlalchemy==0.49b0",
"opentelemetry-propagator-b3==1.28.0",
"opentelemetry-proto==1.28.0",
"opentelemetry-sdk==1.28.0",
"opentelemetry-semantic-conventions==0.49b0",
"opentelemetry-util-http==0.49b0",
"pandas[excel,output-formatting,performance]~=2.2.2",
"psycogreen~=1.0.2",
"psycopg2-binary~=2.9.6",
@@ -81,7 +78,7 @@ dependencies = [
"starlette==0.49.1",
"tiktoken~=0.9.0",
"transformers~=4.56.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.16.1",
"unstructured[docx,epub,md,ppt,pptx]~=0.18.18",
"yarl~=1.18.3",
"webvtt-py~=0.5.1",
"sseclient-py~=1.8.0",

View File

@@ -74,6 +74,16 @@ from tasks.mail_reset_password_task import (
logger = logging.getLogger(__name__)
def _try_join_enterprise_default_workspace(account_id: str) -> None:
"""Best-effort join to enterprise default workspace."""
if not dify_config.ENTERPRISE_ENABLED:
return
from services.enterprise.enterprise_service import try_join_default_workspace
try_join_default_workspace(account_id)
class TokenPair(BaseModel):
access_token: str
refresh_token: str
@@ -287,7 +297,14 @@ class AccountService:
email=email, name=name, interface_language=interface_language, password=password
)
TenantService.create_owner_tenant_if_not_exist(account=account)
try:
TenantService.create_owner_tenant_if_not_exist(account=account)
except Exception:
# Enterprise-only side-effect should run independently from personal workspace creation.
_try_join_enterprise_default_workspace(str(account.id))
raise
_try_join_enterprise_default_workspace(str(account.id))
return account
@@ -327,6 +344,12 @@ class AccountService:
@staticmethod
def delete_account(account: Account):
"""Delete account. This method only adds a task to the queue for deletion."""
# Queue account deletion sync tasks for all workspaces BEFORE account deletion (enterprise only)
from services.enterprise.account_deletion_sync import sync_account_deletion
sync_account_deletion(account_id=account.id, source="account_deleted")
# Now proceed with async account deletion
delete_account_task.delay(account.id)
@staticmethod
@@ -1230,6 +1253,11 @@ class TenantService:
if dify_config.BILLING_ENABLED:
BillingService.clean_billing_info_cache(tenant.id)
# Queue account deletion sync task for enterprise backend to reassign resources (enterprise only)
from services.enterprise.account_deletion_sync import sync_workspace_member_removal
sync_workspace_member_removal(workspace_id=tenant.id, member_id=account.id, source="workspace_member_removed")
@staticmethod
def update_member_role(tenant: Tenant, member: Account, new_role: str, operator: Account):
"""Update member role"""
@@ -1350,12 +1378,18 @@ class RegisterService:
and create_workspace_required
and FeatureService.get_system_features().license.workspaces.is_available()
):
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
try:
tenant = TenantService.create_tenant(f"{account.name}'s Workspace")
TenantService.create_tenant_member(tenant, account, role="owner")
account.current_tenant = tenant
tenant_was_created.send(tenant)
except Exception:
_try_join_enterprise_default_workspace(str(account.id))
raise
db.session.commit()
_try_join_enterprise_default_workspace(str(account.id))
except WorkSpaceNotAllowedCreateError:
db.session.rollback()
logger.exception("Register failed")

View File

@@ -14,7 +14,7 @@ from core.model_runtime.entities.model_entities import ModelPropertyKey, ModelTy
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
from core.tools.tool_manager import ToolManager
from core.tools.utils.configuration import ToolParameterConfigurationManager
from events.app_event import app_was_created
from events.app_event import app_was_created, app_was_deleted
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from libs.login import current_user
@@ -340,6 +340,8 @@ class AppService:
db.session.delete(app)
db.session.commit()
app_was_deleted.send(app)
# clean up web app settings
if FeatureService.get_system_features().webapp_auth.enabled:
EnterpriseService.WebAppAuth.cleanup_webapp(app.id)

View File

@@ -0,0 +1,115 @@
import json
import logging
import uuid
from datetime import UTC, datetime
from redis import RedisError
from configs import dify_config
from extensions.ext_database import db
from extensions.ext_redis import redis_client
from models.account import TenantAccountJoin
logger = logging.getLogger(__name__)
ACCOUNT_DELETION_SYNC_QUEUE = "enterprise:member:sync:queue"
ACCOUNT_DELETION_SYNC_TASK_TYPE = "sync_member_deletion_from_workspace"
def _queue_task(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Queue an account deletion sync task to Redis.
Internal helper function. Do not call directly - use the public functions instead.
Args:
workspace_id: The workspace/tenant ID to sync
member_id: The member/account ID that was removed
source: Source of the sync request (for debugging/tracking)
Returns:
bool: True if task was queued successfully, False otherwise
"""
try:
task = {
"task_id": str(uuid.uuid4()),
"workspace_id": workspace_id,
"member_id": member_id,
"retry_count": 0,
"created_at": datetime.now(UTC).isoformat(),
"source": source,
"type": ACCOUNT_DELETION_SYNC_TASK_TYPE,
}
# Push to Redis list (queue) - LPUSH adds to the head, worker consumes from tail with RPOP
redis_client.lpush(ACCOUNT_DELETION_SYNC_QUEUE, json.dumps(task))
logger.info(
"Queued account deletion sync task for workspace %s, member %s, task_id: %s, source: %s",
workspace_id,
member_id,
task["task_id"],
source,
)
return True
except (RedisError, TypeError) as e:
logger.error(
"Failed to queue account deletion sync for workspace %s, member %s: %s",
workspace_id,
member_id,
str(e),
exc_info=True,
)
# Don't raise - we don't want to fail member deletion if queueing fails
return False
def sync_workspace_member_removal(workspace_id: str, member_id: str, *, source: str) -> bool:
"""
Sync a single workspace member removal (enterprise only).
Queues a task for the enterprise backend to reassign resources from the removed member.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
workspace_id: The workspace/tenant ID
member_id: The member/account ID that was removed
source: Source of the sync request (e.g., "workspace_member_removed")
Returns:
bool: True if task was queued (or skipped in community), False if queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
return _queue_task(workspace_id=workspace_id, member_id=member_id, source=source)
def sync_account_deletion(account_id: str, *, source: str) -> bool:
"""
Sync full account deletion across all workspaces (enterprise only).
Fetches all workspace memberships for the account and queues a sync task for each.
Handles enterprise edition check internally. Safe to call in community edition (no-op).
Args:
account_id: The account ID being deleted
source: Source of the sync request (e.g., "account_deleted")
Returns:
bool: True if all tasks were queued (or skipped in community), False if any queueing failed
"""
if not dify_config.ENTERPRISE_ENABLED:
return True
# Fetch all workspaces the account belongs to
workspace_joins = db.session.query(TenantAccountJoin).filter_by(account_id=account_id).all()
# Queue sync task for each workspace
success = True
for join in workspace_joins:
if not _queue_task(workspace_id=join.tenant_id, member_id=account_id, source=source):
success = False
return success

View File

@@ -39,6 +39,9 @@ class BaseRequest:
endpoint: str,
json: Any | None = None,
params: Mapping[str, Any] | None = None,
*,
timeout: float | httpx.Timeout | None = None,
raise_for_status: bool = False,
) -> Any:
headers = {"Content-Type": "application/json", cls.secret_key_header: cls.secret_key}
url = f"{cls.base_url}{endpoint}"
@@ -53,7 +56,16 @@ class BaseRequest:
logger.debug("Failed to generate traceparent header", exc_info=True)
with httpx.Client(mounts=mounts) as client:
response = client.request(method, url, json=json, params=params, headers=headers)
# IMPORTANT:
# - In httpx, passing timeout=None disables timeouts (infinite) and overrides the library default.
# - To preserve httpx's default timeout behavior for existing call sites, only pass the kwarg when set.
request_kwargs: dict[str, Any] = {"json": json, "params": params, "headers": headers}
if timeout is not None:
request_kwargs["timeout"] = timeout
response = client.request(method, url, **request_kwargs)
if raise_for_status:
response.raise_for_status()
return response.json()

View File

@@ -1,9 +1,17 @@
import logging
import uuid
from datetime import datetime
from pydantic import BaseModel, Field
from pydantic import BaseModel, ConfigDict, Field, model_validator
from configs import dify_config
from services.enterprise.base import EnterpriseRequest
logger = logging.getLogger(__name__)
DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS = 1.0
ALLOWED_ACCESS_MODES = ["public", "private", "private_all", "sso_verified"]
class WebAppSettings(BaseModel):
access_mode: str = Field(
@@ -30,6 +38,55 @@ class WorkspacePermission(BaseModel):
)
class DefaultWorkspaceJoinResult(BaseModel):
"""
Result of ensuring an account is a member of the enterprise default workspace.
- joined=True is idempotent (already a member also returns True)
- joined=False means enterprise default workspace is not configured or invalid/archived
"""
workspace_id: str = Field(default="", alias="workspaceId")
joined: bool
message: str
model_config = ConfigDict(extra="forbid", populate_by_name=True)
@model_validator(mode="after")
def _check_workspace_id_when_joined(self) -> "DefaultWorkspaceJoinResult":
if self.joined and not self.workspace_id:
raise ValueError("workspace_id must be non-empty when joined is True")
return self
def try_join_default_workspace(account_id: str) -> None:
"""
Enterprise-only side-effect: ensure account is a member of the default workspace.
This is a best-effort integration. Failures must not block user registration.
"""
if not dify_config.ENTERPRISE_ENABLED:
return
try:
result = EnterpriseService.join_default_workspace(account_id=account_id)
if result.joined:
logger.info(
"Joined enterprise default workspace for account %s (workspace_id=%s)",
account_id,
result.workspace_id,
)
else:
logger.info(
"Skipped joining enterprise default workspace for account %s (message=%s)",
account_id,
result.message,
)
except Exception:
logger.warning("Failed to join enterprise default workspace for account %s", account_id, exc_info=True)
class EnterpriseService:
@classmethod
def get_info(cls):
@@ -39,6 +96,34 @@ class EnterpriseService:
def get_workspace_info(cls, tenant_id: str):
return EnterpriseRequest.send_request("GET", f"/workspace/{tenant_id}/info")
@classmethod
def join_default_workspace(cls, *, account_id: str) -> DefaultWorkspaceJoinResult:
"""
Call enterprise inner API to add an account to the default workspace.
NOTE: EnterpriseRequest.base_url is expected to already include the `/inner/api` prefix,
so the endpoint here is `/default-workspace/members`.
"""
# Ensure we are sending a UUID-shaped string (enterprise side validates too).
try:
uuid.UUID(account_id)
except ValueError as e:
raise ValueError(f"account_id must be a valid UUID: {account_id}") from e
data = EnterpriseRequest.send_request(
"POST",
"/default-workspace/members",
json={"account_id": account_id},
timeout=DEFAULT_WORKSPACE_JOIN_TIMEOUT_SECONDS,
raise_for_status=True,
)
if not isinstance(data, dict):
raise ValueError("Invalid response format from enterprise default workspace API")
if "joined" not in data or "message" not in data:
raise ValueError("Invalid response payload from enterprise default workspace API")
return DefaultWorkspaceJoinResult.model_validate(data)
@classmethod
def get_app_sso_settings_last_update_time(cls) -> datetime:
data = EnterpriseRequest.send_request("GET", "/sso/app/last-update-time")
@@ -123,8 +208,8 @@ class EnterpriseService:
def update_app_access_mode(cls, app_id: str, access_mode: str):
if not app_id:
raise ValueError("app_id must be provided.")
if access_mode not in ["public", "private", "private_all"]:
raise ValueError("access_mode must be either 'public', 'private', or 'private_all'")
if access_mode not in ALLOWED_ACCESS_MODES:
raise ValueError(f"access_mode must be one of: {', '.join(ALLOWED_ACCESS_MODES)}")
data = {"appId": app_id, "accessMode": access_mode}

View File

@@ -7,9 +7,10 @@ from core.llm_generator.llm_generator import LLMGenerator
from core.memory.token_buffer_memory import TokenBufferMemory
from core.model_manager import ModelManager
from core.model_runtime.entities.model_entities import ModelType
from core.ops.entities.trace_entity import TraceTaskName
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
from core.ops.utils import measure_time
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
from core.telemetry import emit as telemetry_emit
from events.feedback_event import feedback_was_created
from extensions.ext_database import db
from libs.infinite_scroll_pagination import InfiniteScrollPagination
from models import Account
@@ -179,6 +180,9 @@ class MessageService:
db.session.commit()
if feedback and rating:
feedback_was_created.send(feedback, tenant_id=app_model.tenant_id)
return feedback
@classmethod
@@ -294,10 +298,15 @@ class MessageService:
questions: list[str] = list(questions_sequence)
# get tracing instance
trace_manager = TraceQueueManager(app_id=app_model.id)
trace_manager.add_trace_task(
TraceTask(
TraceTaskName.SUGGESTED_QUESTION_TRACE, message_id=message_id, suggested_question=questions, timer=timer
telemetry_emit(
TelemetryEvent(
name=TraceTaskName.SUGGESTED_QUESTION_TRACE,
context=TelemetryContext(tenant_id=app_model.tenant_id, app_id=app_model.id),
payload={
"message_id": message_id,
"suggested_question": questions,
"timer": timer,
},
)
)

View File

@@ -1,3 +1,4 @@
import logging
from typing import Any
from core.ops.entities.config_entity import BaseTracingConfig
@@ -5,6 +6,8 @@ from core.ops.ops_trace_manager import OpsTraceManager, provider_config_map
from extensions.ext_database import db
from models.model import App, TraceAppConfig
logger = logging.getLogger(__name__)
class OpsService:
@classmethod
@@ -135,12 +138,13 @@ class OpsService:
return trace_config_data.to_dict()
@classmethod
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def create_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
"""
Create tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:param account_id: account id of the user creating the config
:return:
"""
try:
@@ -207,15 +211,19 @@ class OpsService:
db.session.add(trace_config_data)
db.session.commit()
# Log the creation with modifier information
logger.info("Trace config created: app_id=%s, provider=%s, created_by=%s", app_id, tracing_provider, account_id)
return {"result": "success"}
@classmethod
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict):
def update_tracing_app_config(cls, app_id: str, tracing_provider: str, tracing_config: dict, account_id: str):
"""
Update tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param tracing_config: tracing config
:param account_id: account id of the user updating the config
:return:
"""
try:
@@ -251,14 +259,18 @@ class OpsService:
current_trace_config.tracing_config = tracing_config
db.session.commit()
# Log the update with modifier information
logger.info("Trace config updated: app_id=%s, provider=%s, updated_by=%s", app_id, tracing_provider, account_id)
return current_trace_config.to_dict()
@classmethod
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str):
def delete_tracing_app_config(cls, app_id: str, tracing_provider: str, account_id: str):
"""
Delete tracing app config
:param app_id: app id
:param tracing_provider: tracing provider
:param account_id: account id of the user deleting the config
:return:
"""
trace_config = (
@@ -270,6 +282,9 @@ class OpsService:
if not trace_config:
return None
# Log the deletion with modifier information
logger.info("Trace config deleted: app_id=%s, provider=%s, deleted_by=%s", app_id, tracing_provider, account_id)
db.session.delete(trace_config)
db.session.commit()

View File

@@ -2,7 +2,10 @@ import json
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from models.account import Account
from sqlalchemy import exists, select
from sqlalchemy.orm import Session
@@ -406,20 +409,37 @@ class BuiltinToolManageService:
return {"result": "success"}
@staticmethod
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str):
def set_default_provider(tenant_id: str, user_id: str, provider: str, id: str, account: "Account | None" = None):
"""
set default provider
"""
with Session(db.engine) as session:
# get provider
target_provider = session.query(BuiltinToolProvider).filter_by(id=id).first()
# get provider (verify tenant ownership to prevent IDOR)
target_provider = session.query(BuiltinToolProvider).filter_by(id=id, tenant_id=tenant_id).first()
if target_provider is None:
raise ValueError("provider not found")
# clear default provider
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
if dify_config.ENTERPRISE_ENABLED:
# Enterprise: verify admin permission for tenant-wide operation
from models.account import TenantAccountRole
if account is None:
# In enterprise mode, an account context is required to perform permission checks
raise ValueError("Account is required to set default credentials in enterprise mode")
if not TenantAccountRole.is_privileged_role(account.current_role):
raise ValueError("Only workspace admins/owners can set default credentials in enterprise mode")
# Enterprise: clear ALL defaults for this provider in the tenant
# (regardless of user_id, since enterprise credentials may have different user_id)
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, provider=provider, is_default=True
).update({"is_default": False})
else:
# Non-enterprise: only clear defaults for the current user
session.query(BuiltinToolProvider).filter_by(
tenant_id=tenant_id, user_id=user_id, provider=provider, is_default=True
).update({"is_default": False})
# set new default provider
target_provider.is_default = True

View File

@@ -647,6 +647,7 @@ class WorkflowService:
node_config = draft_workflow.get_node_config_by_id(node_id)
node_type = Workflow.get_node_type_from_node_config(node_config)
node_data = node_config.get("data", {})
workflow_execution_id: str | None = None
if node_type.is_start_node:
with Session(bind=db.engine) as session, session.begin():
draft_var_srv = WorkflowDraftVariableService(session)
@@ -672,10 +673,13 @@ class WorkflowService:
node_type=node_type,
conversation_id=conversation_id,
)
workflow_execution_id = variable_pool.system_variables.workflow_execution_id
else:
workflow_execution_id = str(uuid.uuid4())
system_variable = SystemVariable(workflow_execution_id=workflow_execution_id)
variable_pool = VariablePool(
system_variables=SystemVariable.default(),
system_variables=system_variable,
user_inputs=user_inputs,
environment_variables=draft_workflow.environment_variables,
conversation_variables=[],
@@ -729,6 +733,15 @@ class WorkflowService:
with Session(db.engine) as session:
outputs = workflow_node_execution.load_full_outputs(session, storage)
from enterprise.telemetry.draft_trace import enqueue_draft_node_execution_trace
enqueue_draft_node_execution_trace(
execution=workflow_node_execution,
outputs=outputs,
workflow_execution_id=workflow_execution_id,
user_id=account.id,
)
with Session(bind=db.engine) as session, session.begin():
draft_var_saver = DraftVariableSaver(
session=session,
@@ -784,19 +797,20 @@ class WorkflowService:
Returns:
WorkflowNodeExecution: The execution result
"""
created_at = naive_utc_now()
node, node_run_result, run_succeeded, error = self._execute_node_safely(invoke_node_fn)
finished_at = naive_utc_now()
# Create base node execution
node_execution = WorkflowNodeExecution(
id=str(uuid.uuid4()),
id=node.execution_id or str(uuid.uuid4()),
workflow_id="", # Single-step execution has no workflow ID
index=1,
node_id=node_id,
node_type=node.node_type,
title=node.title,
elapsed_time=time.perf_counter() - start_at,
created_at=naive_utc_now(),
finished_at=naive_utc_now(),
created_at=created_at,
finished_at=finished_at,
)
# Populate execution result data

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -58,5 +57,3 @@ def add_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -5,7 +5,6 @@ import click
from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -40,5 +39,3 @@ def delete_annotation_index_task(annotation_id: str, app_id: str, tenant_id: str
logger.info(click.style(f"App annotations index deleted : {app_id} latency: {end_at - start_at}", fg="green"))
except Exception:
logger.exception("Annotation deleted index failed")
finally:
db.session.close()

View File

@@ -6,7 +6,6 @@ from celery import shared_task
from core.rag.datasource.vdb.vector_factory import Vector
from core.rag.models.document import Document
from extensions.ext_database import db
from models.dataset import Dataset
from services.dataset_service import DatasetCollectionBindingService
@@ -59,5 +58,3 @@ def update_annotation_to_index_task(
)
except Exception:
logger.exception("Build index for annotation failed")
finally:
db.session.close()

View File

@@ -14,6 +14,9 @@ from models.model import UploadFile
logger = logging.getLogger(__name__)
# Batch size for database operations to keep transactions short
BATCH_SIZE = 1000
@shared_task(queue="dataset")
def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form: str | None, file_ids: list[str]):
@@ -31,63 +34,179 @@ def batch_clean_document_task(document_ids: list[str], dataset_id: str, doc_form
if not doc_form:
raise ValueError("doc_form is required")
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
).delete(synchronize_session=False)
storage_keys_to_delete: list[str] = []
index_node_ids: list[str] = []
segment_ids: list[str] = []
total_image_upload_file_ids: list[str] = []
try:
# ============ Step 1: Query segment and file data (short read-only transaction) ============
with session_factory.create_session() as session:
# Get segments info
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
# Collect image file IDs from segment content
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.query(UploadFile).where(UploadFile.id.in_(image_upload_file_ids)).all()
for image_file in image_files:
try:
if image_file and image_file.key:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(stmt)
session.delete(segment)
total_image_upload_file_ids.extend(image_upload_file_ids)
# Query storage keys for image files
if total_image_upload_file_ids:
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(total_image_upload_file_ids))
).all()
storage_keys_to_delete.extend([f.key for f in image_files if f and f.key])
# Query storage keys for document files
if file_ids:
files = session.scalars(select(UploadFile).where(UploadFile.id.in_(file_ids))).all()
for file in files:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file.id)
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
storage_keys_to_delete.extend([f.key for f in files if f and f.key])
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at}",
fg="green",
# ============ Step 2: Clean vector index (external service, fresh session for dataset) ============
if index_node_ids:
try:
# Fetch dataset in a fresh session to avoid DetachedInstanceError
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
logger.warning("Dataset not found for vector index cleanup, dataset_id: %s", dataset_id)
else:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception(
"Failed to clean vector index for dataset_id: %s, document_ids: %s, index_node_ids count: %d",
dataset_id,
document_ids,
len(index_node_ids),
)
)
# ============ Step 3: Delete metadata binding (separate short transaction) ============
try:
with session_factory.create_session() as session:
deleted_count = (
session.query(DatasetMetadataBinding)
.where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id.in_(document_ids),
)
.delete(synchronize_session=False)
)
session.commit()
logger.debug("Deleted %d metadata bindings for dataset_id: %s", deleted_count, dataset_id)
except Exception:
logger.exception("Cleaned documents when documents deleted failed")
logger.exception(
"Failed to delete metadata bindings for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)
# ============ Step 4: Batch delete UploadFile records (multiple short transactions) ============
if total_image_upload_file_ids:
failed_batches = 0
total_batches = (len(total_image_upload_file_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(total_image_upload_file_ids), BATCH_SIZE):
batch = total_image_upload_file_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(batch))
session.execute(stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete image UploadFile batch %d-%d for dataset_id: %s",
i,
i + len(batch),
dataset_id,
)
if failed_batches > 0:
logger.warning(
"Image UploadFile deletion: %d/%d batches failed for dataset_id: %s",
failed_batches,
total_batches,
dataset_id,
)
# ============ Step 5: Batch delete DocumentSegment records (multiple short transactions) ============
if segment_ids:
failed_batches = 0
total_batches = (len(segment_ids) + BATCH_SIZE - 1) // BATCH_SIZE
for i in range(0, len(segment_ids), BATCH_SIZE):
batch = segment_ids[i : i + BATCH_SIZE]
try:
with session_factory.create_session() as session:
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(batch))
session.execute(segment_delete_stmt)
session.commit()
except Exception:
failed_batches += 1
logger.exception(
"Failed to delete DocumentSegment batch %d-%d for dataset_id: %s, document_ids: %s",
i,
i + len(batch),
dataset_id,
document_ids,
)
if failed_batches > 0:
logger.warning(
"DocumentSegment deletion: %d/%d batches failed, document_ids: %s",
failed_batches,
total_batches,
document_ids,
)
# ============ Step 6: Delete document-associated files (separate short transaction) ============
if file_ids:
try:
with session_factory.create_session() as session:
stmt = delete(UploadFile).where(UploadFile.id.in_(file_ids))
session.execute(stmt)
session.commit()
except Exception:
logger.exception(
"Failed to delete document UploadFile records for dataset_id: %s, file_ids: %s",
dataset_id,
file_ids,
)
# ============ Step 7: Delete storage files (I/O operations, no DB transaction) ============
storage_delete_failures = 0
for storage_key in storage_keys_to_delete:
try:
storage.delete(storage_key)
except Exception:
storage_delete_failures += 1
logger.exception("Failed to delete file from storage, key: %s", storage_key)
if storage_delete_failures > 0:
logger.warning(
"Storage file deletion completed with %d failures out of %d total files for dataset_id: %s",
storage_delete_failures,
len(storage_keys_to_delete),
dataset_id,
)
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned documents when documents deleted latency: {end_at - start_at:.2f}s, "
f"dataset_id: {dataset_id}, document_ids: {document_ids}, "
f"segments: {len(segment_ids)}, image_files: {len(total_image_upload_file_ids)}, "
f"storage_files: {len(storage_keys_to_delete)}",
fg="green",
)
)
except Exception:
logger.exception(
"Batch clean documents failed for dataset_id: %s, document_ids: %s",
dataset_id,
document_ids,
)

View File

@@ -48,6 +48,11 @@ def batch_create_segment_to_index_task(
indexing_cache_key = f"segment_batch_import_{job_id}"
# Initialize variables with default values
upload_file_key: str | None = None
dataset_config: dict | None = None
document_config: dict | None = None
with session_factory.create_session() as session:
try:
dataset = session.get(Dataset, dataset_id)
@@ -69,86 +74,115 @@ def batch_create_segment_to_index_task(
if not upload_file:
raise ValueError("UploadFile not found.")
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file.key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file.key, file_path)
dataset_config = {
"id": dataset.id,
"indexing_technique": dataset.indexing_technique,
"tenant_id": dataset.tenant_id,
"embedding_model_provider": dataset.embedding_model_provider,
"embedding_model": dataset.embedding_model,
}
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if dataset_document.doc_form == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
document_config = {
"id": dataset_document.id,
"doc_form": dataset_document.doc_form,
"word_count": dataset_document.word_count or 0,
}
document_segments = []
embedding_model = None
if dataset.indexing_technique == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset.tenant_id,
provider=dataset.embedding_model_provider,
model_type=ModelType.TEXT_EMBEDDING,
model=dataset.embedding_model,
)
upload_file_key = upload_file.key
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(
texts=[segment["content"] for segment in content]
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
return
# Ensure required variables are set before proceeding
if upload_file_key is None or dataset_config is None or document_config is None:
logger.error("Required configuration not set due to session error")
redis_client.setex(indexing_cache_key, 600, "error")
return
with tempfile.TemporaryDirectory() as temp_dir:
suffix = Path(upload_file_key).suffix
file_path = f"{temp_dir}/{next(tempfile._get_candidate_names())}{suffix}" # type: ignore
storage.download(upload_file_key, file_path)
df = pd.read_csv(file_path)
content = []
for _, row in df.iterrows():
if document_config["doc_form"] == "qa_model":
data = {"content": row.iloc[0], "answer": row.iloc[1]}
else:
tokens_list = [0] * len(content)
data = {"content": row.iloc[0]}
content.append(data)
if len(content) == 0:
raise ValueError("The CSV file is empty.")
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == dataset_document.id)
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if dataset_document.doc_form == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
document_segments = []
embedding_model = None
if dataset_config["indexing_technique"] == "high_quality":
model_manager = ModelManager()
embedding_model = model_manager.get_model_instance(
tenant_id=dataset_config["tenant_id"],
provider=dataset_config["embedding_model_provider"],
model_type=ModelType.TEXT_EMBEDDING,
model=dataset_config["embedding_model"],
)
word_count_change = 0
if embedding_model:
tokens_list = embedding_model.get_text_embedding_num_tokens(texts=[segment["content"] for segment in content])
else:
tokens_list = [0] * len(content)
with session_factory.create_session() as session, session.begin():
for segment, tokens in zip(content, tokens_list):
content = segment["content"]
doc_id = str(uuid.uuid4())
segment_hash = helper.generate_text_hash(content)
max_position = (
session.query(func.max(DocumentSegment.position))
.where(DocumentSegment.document_id == document_config["id"])
.scalar()
)
segment_document = DocumentSegment(
tenant_id=tenant_id,
dataset_id=dataset_id,
document_id=document_id,
index_node_id=doc_id,
index_node_hash=segment_hash,
position=max_position + 1 if max_position else 1,
content=content,
word_count=len(content),
tokens=tokens,
created_by=user_id,
indexing_at=naive_utc_now(),
status="completed",
completed_at=naive_utc_now(),
)
if document_config["doc_form"] == "qa_model":
segment_document.answer = segment["answer"]
segment_document.word_count += len(segment["answer"])
word_count_change += segment_document.word_count
session.add(segment_document)
document_segments.append(segment_document)
with session_factory.create_session() as session, session.begin():
dataset_document = session.get(Document, document_id)
if dataset_document:
assert dataset_document.word_count is not None
dataset_document.word_count += word_count_change
session.add(dataset_document)
VectorService.create_segments_vector(None, document_segments, dataset, dataset_document.doc_form)
session.commit()
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)
except Exception:
logger.exception("Segments batch created index failed")
redis_client.setex(indexing_cache_key, 600, "error")
with session_factory.create_session() as session:
dataset = session.get(Dataset, dataset_id)
if dataset:
VectorService.create_segments_vector(None, document_segments, dataset, document_config["doc_form"])
redis_client.setex(indexing_cache_key, 600, "completed")
end_at = time.perf_counter()
logger.info(
click.style(
f"Segment batch created job: {job_id} latency: {end_at - start_at}",
fg="green",
)
)

View File

@@ -28,6 +28,7 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
"""
logger.info(click.style(f"Start clean document when document deleted: {document_id}", fg="green"))
start_at = time.perf_counter()
total_attachment_files = []
with session_factory.create_session() as session:
try:
@@ -47,78 +48,91 @@ def clean_document_task(document_id: str, dataset_id: str, doc_form: str, file_i
SegmentAttachmentBinding.document_id == document_id,
)
).all()
# check segment is exist
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
total_attachment_files.extend([attachment_file.key for _, attachment_file in attachments_with_bindings])
index_node_ids = [segment.index_node_id for segment in segments]
segment_contents = [segment.content for segment in segments]
except Exception:
logger.exception("Cleaned document when document deleted failed")
return
# check segment is exist
if index_node_ids:
index_processor = IndexProcessorFactory(doc_form).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
for segment in segments:
image_upload_file_ids = get_image_upload_file_ids(segment.content)
image_files = session.scalars(
select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
).all()
for image_file in image_files:
if image_file is None:
continue
try:
storage.delete(image_file.key)
except Exception:
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file.id,
)
total_image_files = []
with session_factory.create_session() as session, session.begin():
for segment_content in segment_contents:
image_upload_file_ids = get_image_upload_file_ids(segment_content)
image_files = session.scalars(select(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))).all()
total_image_files.extend([image_file.key for image_file in image_files])
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
image_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(image_upload_file_ids))
session.execute(image_file_delete_stmt)
session.delete(segment)
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
session.commit()
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
# delete segment attachments
if attachments_with_bindings:
attachment_ids = [attachment_file.id for _, attachment_file in attachments_with_bindings]
binding_ids = [binding.id for binding, _ in attachments_with_bindings]
for binding, attachment_file in attachments_with_bindings:
try:
storage.delete(attachment_file.key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
binding.attachment_id,
)
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
binding_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(binding_ids)
)
session.execute(binding_delete_stmt)
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)
for image_file_key in total_image_files:
try:
storage.delete(image_file_key)
except Exception:
logger.exception("Cleaned document when document deleted failed")
logger.exception(
"Delete image_files failed when storage deleted, \
image_upload_file_is: %s",
image_file_key,
)
with session_factory.create_session() as session, session.begin():
if file_id:
file = session.query(UploadFile).where(UploadFile.id == file_id).first()
if file:
try:
storage.delete(file.key)
except Exception:
logger.exception("Delete file failed when document deleted, file_id: %s", file_id)
session.delete(file)
with session_factory.create_session() as session, session.begin():
# delete segment attachments
if attachment_ids:
attachment_file_delete_stmt = delete(UploadFile).where(UploadFile.id.in_(attachment_ids))
session.execute(attachment_file_delete_stmt)
if binding_ids:
binding_delete_stmt = delete(SegmentAttachmentBinding).where(SegmentAttachmentBinding.id.in_(binding_ids))
session.execute(binding_delete_stmt)
for attachment_file_key in total_attachment_files:
try:
storage.delete(attachment_file_key)
except Exception:
logger.exception(
"Delete attachment_file failed when storage deleted, \
attachment_file_id: %s",
attachment_file_key,
)
with session_factory.create_session() as session, session.begin():
# delete dataset metadata binding
session.query(DatasetMetadataBinding).where(
DatasetMetadataBinding.dataset_id == dataset_id,
DatasetMetadataBinding.document_id == document_id,
).delete()
end_at = time.perf_counter()
logger.info(
click.style(
f"Cleaned document when document deleted: {document_id} latency: {end_at - start_at}",
fg="green",
)
)

View File

@@ -23,40 +23,40 @@ def clean_notion_document_task(document_ids: list[str], dataset_id: str):
"""
logger.info(click.style(f"Start clean document when import form notion document deleted: {dataset_id}", fg="green"))
start_at = time.perf_counter()
total_index_node_ids = []
with session_factory.create_session() as session:
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if not dataset:
raise Exception("Document has no dataset")
index_type = dataset.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
document_delete_stmt = delete(Document).where(Document.id.in_(document_ids))
session.execute(document_delete_stmt)
for document_id in document_ids:
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
for document_id in document_ids:
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
total_index_node_ids.extend([segment.index_node_id for segment in segments])
index_processor.clean(
dataset, index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
session.commit()
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(
dataset, total_index_node_ids, with_keywords=True, delete_child_chunks=True, delete_summaries=True
)
except Exception:
logger.exception("Cleaned document when import form notion document deleted failed")
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id.in_(document_ids))
session.execute(segment_delete_stmt)
end_at = time.perf_counter()
logger.info(
click.style(
"Clean document when import form notion document deleted end :: {} latency: {}".format(
dataset_id, end_at - start_at
),
fg="green",
)
)

View File

@@ -3,6 +3,7 @@ import time
import click
from celery import shared_task
from sqlalchemy import delete
from core.db.session_factory import session_factory
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
@@ -67,8 +68,14 @@ def delete_segment_from_index_task(
if segment_attachment_bindings:
attachment_ids = [binding.attachment_id for binding in segment_attachment_bindings]
index_processor.clean(dataset=dataset, node_ids=attachment_ids, with_keywords=False)
for binding in segment_attachment_bindings:
session.delete(binding)
segment_attachment_bind_ids = [i.id for i in segment_attachment_bindings]
for i in range(0, len(segment_attachment_bind_ids), 1000):
segment_attachment_bind_delete_stmt = delete(SegmentAttachmentBinding).where(
SegmentAttachmentBinding.id.in_(segment_attachment_bind_ids[i : i + 1000])
)
session.execute(segment_attachment_bind_delete_stmt)
# delete upload file
session.query(UploadFile).where(UploadFile.id.in_(attachment_ids)).delete(synchronize_session=False)
session.commit()

View File

@@ -27,104 +27,129 @@ def document_indexing_sync_task(dataset_id: str, document_id: str):
"""
logger.info(click.style(f"Start sync document: {document_id}", fg="green"))
start_at = time.perf_counter()
tenant_id = None
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
logger.info(click.style(f"Document not found: {document_id}", fg="red"))
return
if document.indexing_status == "parsing":
logger.info(click.style(f"Document {document_id} is already being processed, skipping", fg="yellow"))
return
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
data_source_info = document.data_source_info_dict
if document.data_source_type == "notion_import":
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
if document.data_source_type != "notion_import":
logger.info(click.style(f"Document {document_id} is not a notion_import, skipping", fg="yellow"))
return
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=document.tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if (
not data_source_info
or "notion_page_id" not in data_source_info
or "notion_workspace_id" not in data_source_info
):
raise ValueError("no notion page found")
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
document.tenant_id,
credential_id,
)
workspace_id = data_source_info["notion_workspace_id"]
page_id = data_source_info["notion_page_id"]
page_type = data_source_info["type"]
page_edited_time = data_source_info["last_edited_time"]
credential_id = data_source_info.get("credential_id")
tenant_id = document.tenant_id
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
# Get credentials from datasource provider
datasource_provider_service = DatasourceProviderService()
credential = datasource_provider_service.get_datasource_credentials(
tenant_id=tenant_id,
credential_id=credential_id,
provider="notion_datasource",
plugin_id="langgenius/notion_datasource",
)
if not credential:
logger.error(
"Datasource credential not found for document %s, tenant_id: %s, credential_id: %s",
document_id,
tenant_id,
credential_id,
)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = "Datasource credential not found. Please reconnect your Notion workspace."
document.stopped_at = naive_utc_now()
session.commit()
return
return
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=document.tenant_id,
)
loader = NotionExtractor(
notion_workspace_id=workspace_id,
notion_obj_id=page_id,
notion_page_type=page_type,
notion_access_token=credential.get("integration_secret"),
tenant_id=tenant_id,
)
last_edited_time = loader.get_notion_last_edited_time()
last_edited_time = loader.get_notion_last_edited_time()
if last_edited_time == page_edited_time:
logger.info(click.style(f"Document {document_id} content unchanged, skipping sync", fg="yellow"))
return
# check the page is updated
if last_edited_time != page_edited_time:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
logger.info(click.style(f"Document {document_id} content changed, starting sync", fg="green"))
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
with session_factory.create_session() as session:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if dataset:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
logger.info(click.style(f"Cleaned vector index for document {document_id}", fg="green"))
except Exception:
logger.exception("Failed to clean vector index for document %s", document_id)
segments = session.scalars(
select(DocumentSegment).where(DocumentSegment.document_id == document_id)
).all()
index_node_ids = [segment.index_node_id for segment in segments]
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if not document:
logger.warning(click.style(f"Document {document_id} not found during sync", fg="yellow"))
return
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
data_source_info = document.data_source_info_dict
data_source_info["last_edited_time"] = last_edited_time
document.data_source_info = data_source_info
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
end_at = time.perf_counter()
logger.info(
click.style(
"Cleaned document when document update data source or process rule: {} latency: {}".format(
document_id, end_at - start_at
),
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_sync_task failed, document_id: %s", document_id)
logger.info(click.style(f"Deleted segments for document {document_id}", fg="green"))
try:
indexing_runner = IndexingRunner()
with session_factory.create_session() as session:
document = session.query(Document).filter_by(id=document_id).first()
if document:
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"Sync completed for document {document_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception as e:
logger.exception("document_indexing_sync_task failed for document_id: %s", document_id)
with session_factory.create_session() as session, session.begin():
document = session.query(Document).filter_by(id=document_id).first()
if document:
document.indexing_status = "error"
document.error = str(e)
document.stopped_at = naive_utc_now()

View File

@@ -81,26 +81,35 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
session.commit()
return
for document_id in document_ids:
logger.info(click.style(f"Start process document: {document_id}", fg="green"))
document = (
session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
)
# Phase 1: Update status to parsing (short transaction)
with session_factory.create_session() as session, session.begin():
documents = (
session.query(Document).where(Document.id.in_(document_ids), Document.dataset_id == dataset_id).all()
)
for document in documents:
if document:
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
documents.append(document)
session.add(document)
session.commit()
# Transaction committed and closed
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
# Phase 2: Execute indexing (no transaction - IndexingRunner creates its own sessions)
has_error = False
try:
indexing_runner = IndexingRunner()
indexing_runner.run(documents)
end_at = time.perf_counter()
logger.info(click.style(f"Processed dataset: {dataset_id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
has_error = True
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
has_error = True
if not has_error:
with session_factory.create_session() as session:
# Trigger summary index generation for completed documents if enabled
# Only generate for high_quality indexing technique and when summary_index_setting is enabled
# Re-query dataset to get latest summary_index_setting (in case it was updated)
@@ -115,17 +124,18 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
# expire all session to get latest document's indexing status
session.expire_all()
# Check each document's indexing status and trigger summary generation if completed
for document_id in document_ids:
# Re-query document to get latest status (IndexingRunner may have updated it)
document = (
session.query(Document)
.where(Document.id == document_id, Document.dataset_id == dataset_id)
.first()
)
documents = (
session.query(Document)
.where(Document.id.in_(document_ids), Document.dataset_id == dataset_id)
.all()
)
for document in documents:
if document:
logger.info(
"Checking document %s for summary generation: status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
@@ -136,46 +146,36 @@ def _document_indexing(dataset_id: str, document_ids: Sequence[str]):
and document.need_summary is True
):
try:
generate_summary_index_task.delay(dataset.id, document_id, None)
generate_summary_index_task.delay(dataset.id, document.id, None)
logger.info(
"Queued summary index generation task for document %s in dataset %s "
"after indexing completed",
document_id,
document.id,
dataset.id,
)
except Exception:
logger.exception(
"Failed to queue summary index generation task for document %s",
document_id,
document.id,
)
# Don't fail the entire indexing process if summary task queuing fails
else:
logger.info(
"Skipping summary generation for document %s: "
"status=%s, doc_form=%s, need_summary=%s",
document_id,
document.id,
document.indexing_status,
document.doc_form,
document.need_summary,
)
else:
logger.warning("Document %s not found after indexing", document_id)
else:
logger.info(
"Summary index generation skipped for dataset %s: summary_index_setting.enable=%s",
dataset.id,
summary_index_setting.get("enable") if summary_index_setting else None,
)
logger.warning("Document %s not found after indexing", document.id)
else:
logger.info(
"Summary index generation skipped for dataset %s: indexing_technique=%s (not 'high_quality')",
dataset.id,
dataset.indexing_technique,
)
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("Document indexing task failed, dataset_id: %s", dataset_id)
def _document_indexing_with_tenant_queue(

View File

@@ -8,7 +8,6 @@ from sqlalchemy import delete, select
from core.db.session_factory import session_factory
from core.indexing_runner import DocumentIsPausedError, IndexingRunner
from core.rag.index_processor.index_processor_factory import IndexProcessorFactory
from extensions.ext_database import db
from libs.datetime_utils import naive_utc_now
from models.dataset import Dataset, Document, DocumentSegment
@@ -27,7 +26,7 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
logger.info(click.style(f"Start update document: {document_id}", fg="green"))
start_at = time.perf_counter()
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
document = session.query(Document).where(Document.id == document_id, Document.dataset_id == dataset_id).first()
if not document:
@@ -36,27 +35,20 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
document.indexing_status = "parsing"
document.processing_started_at = naive_utc_now()
session.commit()
# delete all document segment and index
try:
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
raise Exception("Dataset not found")
dataset = session.query(Dataset).where(Dataset.id == dataset_id).first()
if not dataset:
return
index_type = document.doc_form
index_processor = IndexProcessorFactory(index_type).init_index_processor()
index_type = document.doc_form
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
index_node_ids = [segment.index_node_id for segment in segments]
segments = session.scalars(select(DocumentSegment).where(DocumentSegment.document_id == document_id)).all()
if segments:
index_node_ids = [segment.index_node_id for segment in segments]
# delete from vector index
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
segment_ids = [segment.id for segment in segments]
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.id.in_(segment_ids))
session.execute(segment_delete_stmt)
db.session.commit()
clean_success = False
try:
index_processor = IndexProcessorFactory(index_type).init_index_processor()
if index_node_ids:
index_processor.clean(dataset, index_node_ids, with_keywords=True, delete_child_chunks=True)
end_at = time.perf_counter()
logger.info(
click.style(
@@ -66,15 +58,21 @@ def document_indexing_update_task(dataset_id: str, document_id: str):
fg="green",
)
)
except Exception:
logger.exception("Cleaned document when document update data source or process rule failed")
clean_success = True
except Exception:
logger.exception("Failed to clean document index during update, document_id: %s", document_id)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)
if clean_success:
with session_factory.create_session() as session, session.begin():
segment_delete_stmt = delete(DocumentSegment).where(DocumentSegment.document_id == document_id)
session.execute(segment_delete_stmt)
try:
indexing_runner = IndexingRunner()
indexing_runner.run([document])
end_at = time.perf_counter()
logger.info(click.style(f"update document: {document.id} latency: {end_at - start_at}", fg="green"))
except DocumentIsPausedError as ex:
logger.info(click.style(str(ex), fg="yellow"))
except Exception:
logger.exception("document_indexing_update_task failed, document_id: %s", document_id)

View File

@@ -0,0 +1,52 @@
"""Celery worker for enterprise metric/log telemetry events.
This module defines the Celery task that processes telemetry envelopes
from the enterprise_telemetry queue. It deserializes envelopes and
dispatches them to the EnterpriseMetricHandler.
"""
import json
import logging
from celery import shared_task
from enterprise.telemetry.contracts import TelemetryEnvelope
from enterprise.telemetry.metric_handler import EnterpriseMetricHandler
logger = logging.getLogger(__name__)
@shared_task(queue="enterprise_telemetry")
def process_enterprise_telemetry(envelope_json: str) -> None:
"""Process enterprise metric/log telemetry envelope.
This task is enqueued by the TelemetryGateway for metric/log-only
events. It deserializes the envelope and dispatches to the handler.
Best-effort processing: logs errors but never raises, to avoid
failing user requests due to telemetry issues.
Args:
envelope_json: JSON-serialized TelemetryEnvelope.
"""
try:
# Deserialize envelope
envelope_dict = json.loads(envelope_json)
envelope = TelemetryEnvelope.model_validate(envelope_dict)
# Process through handler
handler = EnterpriseMetricHandler()
handler.handle(envelope)
logger.debug(
"Successfully processed telemetry envelope: tenant_id=%s, event_id=%s, case=%s",
envelope.tenant_id,
envelope.event_id,
envelope.case,
)
except Exception:
# Best-effort: log and drop on error, never fail user request
logger.warning(
"Failed to process enterprise telemetry envelope, dropping event",
exc_info=True,
)

View File

@@ -39,12 +39,24 @@ def process_trace_tasks(file_info):
trace_info["documents"] = [Document.model_validate(doc) for doc in trace_info["documents"]]
try:
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
from extensions.ext_enterprise_telemetry import is_enabled as is_ee_telemetry_enabled
if is_ee_telemetry_enabled():
from enterprise.telemetry.enterprise_trace import EnterpriseOtelTrace
try:
EnterpriseOtelTrace().trace(trace_info)
except Exception:
logger.exception("Enterprise trace failed for app_id: %s", app_id)
if trace_instance:
with current_app.app_context():
trace_type = trace_info_info_map.get(trace_info_type)
if trace_type:
trace_info = trace_type(**trace_info)
trace_instance.trace(trace_info)
logger.info("Processing trace tasks success, app_id: %s", app_id)
except Exception as e:
logger.info("error:\n\n\n%s\n\n\n\n", e)
@@ -52,4 +64,12 @@ def process_trace_tasks(file_info):
redis_client.incr(failed_key)
logger.info("Processing trace tasks failed, app_id: %s", app_id)
finally:
storage.delete(file_path)
try:
storage.delete(file_path)
except Exception as e:
logger.warning(
"Failed to delete trace file %s for app_id %s: %s",
file_path,
app_id,
e,
)

View File

@@ -259,8 +259,8 @@ def _delete_app_workflow_app_logs(tenant_id: str, app_id: str):
def _delete_app_workflow_archive_logs(tenant_id: str, app_id: str):
def del_workflow_archive_log(workflow_archive_log_id: str):
db.session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
def del_workflow_archive_log(session, workflow_archive_log_id: str):
session.query(WorkflowArchiveLog).where(WorkflowArchiveLog.id == workflow_archive_log_id).delete(
synchronize_session=False
)
@@ -420,7 +420,7 @@ def delete_draft_variables_batch(app_id: str, batch_size: int = 1000) -> int:
total_files_deleted = 0
while True:
with session_factory.create_session() as session:
with session_factory.create_session() as session, session.begin():
# Get a batch of draft variable IDs along with their file_ids
query_sql = """
SELECT id, file_id FROM workflow_draft_variables

View File

@@ -6,9 +6,8 @@ improving performance by offloading storage operations to background workers.
"""
from celery import shared_task # type: ignore[import-untyped]
from sqlalchemy.orm import Session
from extensions.ext_database import db
from core.db.session_factory import session_factory
from services.workflow_draft_variable_service import DraftVarFileDeletion, WorkflowDraftVariableService
@@ -17,6 +16,6 @@ def save_workflow_execution_task(
self,
deletions: list[DraftVarFileDeletion],
):
with Session(bind=db.engine) as session, session.begin():
with session_factory.create_session() as session, session.begin():
srv = WorkflowDraftVariableService(session=session)
srv.delete_workflow_draft_variable_file(deletions=deletions)

View File

@@ -10,7 +10,10 @@ from models import Tenant
from models.enums import CreatorUserRole
from models.model import App, UploadFile
from models.workflow import WorkflowDraftVariable, WorkflowDraftVariableFile
from tasks.remove_app_and_related_data_task import _delete_draft_variables, delete_draft_variables_batch
from tasks.remove_app_and_related_data_task import (
_delete_draft_variables,
delete_draft_variables_batch,
)
@pytest.fixture
@@ -297,12 +300,18 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_with_offload_data(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.return_value = None
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = session.query(WorkflowDraftVariableFile).count()
upload_files_before = session.query(UploadFile).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
@@ -315,8 +324,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -329,6 +342,8 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
def test_delete_draft_variables_storage_failure_continues_cleanup(self, mock_storage, setup_offload_test_data):
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
variable_file_ids = [vf.id for vf in data["variable_files"]]
mock_storage.delete.side_effect = [Exception("Storage error"), None]
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
@@ -339,8 +354,12 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
assert draft_vars_after == 0
with session_factory.create_session() as session:
var_files_after = session.query(WorkflowDraftVariableFile).count()
upload_files_after = session.query(UploadFile).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_(variable_file_ids))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert var_files_after == 0
assert upload_files_after == 0
@@ -395,3 +414,275 @@ class TestDeleteDraftVariablesWithOffloadIntegration:
if app2_obj:
session.delete(app2_obj)
session.commit()
class TestDeleteDraftVariablesSessionCommit:
"""Test suite to verify session commit behavior in delete_draft_variables_batch."""
@pytest.fixture
def setup_offload_test_data(self, app_and_tenant):
"""Create test data with offload files for session commit tests."""
from core.variables.types import SegmentType
from libs.datetime_utils import naive_utc_now
tenant, app = app_and_tenant
with session_factory.create_session() as session:
upload_file1 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file1.json",
name="file1.json",
size=1024,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
upload_file2 = UploadFile(
tenant_id=tenant.id,
storage_type="local",
key="test/file2.json",
name="file2.json",
size=2048,
extension="json",
mime_type="application/json",
created_by_role=CreatorUserRole.ACCOUNT,
created_by=str(uuid.uuid4()),
created_at=naive_utc_now(),
used=False,
)
session.add(upload_file1)
session.add(upload_file2)
session.flush()
var_file1 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file1.id,
size=1024,
length=10,
value_type=SegmentType.STRING,
)
var_file2 = WorkflowDraftVariableFile(
tenant_id=tenant.id,
app_id=app.id,
user_id=str(uuid.uuid4()),
upload_file_id=upload_file2.id,
size=2048,
length=20,
value_type=SegmentType.OBJECT,
)
session.add(var_file1)
session.add(var_file2)
session.flush()
draft_var1 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_1",
name="large_var_1",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file1.id,
)
draft_var2 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_2",
name="large_var_2",
value=StringSegment(value="truncated..."),
node_execution_id=str(uuid.uuid4()),
file_id=var_file2.id,
)
draft_var3 = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id="node_3",
name="regular_var",
value=StringSegment(value="regular_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(draft_var1)
session.add(draft_var2)
session.add(draft_var3)
session.commit()
data = {
"app": app,
"tenant": tenant,
"upload_files": [upload_file1, upload_file2],
"variable_files": [var_file1, var_file2],
"draft_variables": [draft_var1, draft_var2, draft_var3],
}
yield data
with session_factory.create_session() as session:
for table, ids in [
(WorkflowDraftVariable, [v.id for v in data["draft_variables"]]),
(WorkflowDraftVariableFile, [vf.id for vf in data["variable_files"]]),
(UploadFile, [uf.id for uf in data["upload_files"]]),
]:
cleanup_query = delete(table).where(table.id.in_(ids)).execution_options(synchronize_session=False)
session.execute(cleanup_query)
session.commit()
@pytest.fixture
def setup_commit_test_data(self, app_and_tenant):
"""Create test data for session commit tests."""
tenant, app = app_and_tenant
variable_ids: list[str] = []
with session_factory.create_session() as session:
variables = []
for i in range(10):
var = WorkflowDraftVariable.new_node_variable(
app_id=app.id,
node_id=f"node_{i}",
name=f"var_{i}",
value=StringSegment(value="test_value"),
node_execution_id=str(uuid.uuid4()),
)
session.add(var)
variables.append(var)
session.commit()
variable_ids = [v.id for v in variables]
yield {
"app": app,
"tenant": tenant,
"variable_ids": variable_ids,
}
with session_factory.create_session() as session:
cleanup_query = (
delete(WorkflowDraftVariable)
.where(WorkflowDraftVariable.id.in_(variable_ids))
.execution_options(synchronize_session=False)
)
session.execute(cleanup_query)
session.commit()
def test_session_commit_is_called_after_each_batch(self, setup_commit_test_data):
"""Test that session.begin() is used for automatic transaction management."""
data = setup_commit_test_data
app_id = data["app"].id
# Since session.begin() is used, the transaction is automatically committed
# when the with block exits successfully. We verify this by checking that
# data is actually persisted.
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
# Verify all data was deleted (proves transaction was committed)
with session_factory.create_session() as session:
remaining_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert deleted_count == 10
assert remaining_count == 0
def test_data_persisted_after_batch_deletion(self, setup_commit_test_data):
"""Test that data is actually persisted to database after batch deletion with commits."""
data = setup_commit_test_data
app_id = data["app"].id
variable_ids = data["variable_ids"]
# Verify initial state
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Perform deletion with small batch size to force multiple commits
deleted_count = delete_draft_variables_batch(app_id, batch_size=3)
assert deleted_count == 10
# Verify all data is deleted in a new session (proves commits worked)
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
# Verify specific IDs are deleted
with session_factory.create_session() as session:
remaining_vars = (
session.query(WorkflowDraftVariable).where(WorkflowDraftVariable.id.in_(variable_ids)).count()
)
assert remaining_vars == 0
def test_session_commit_with_empty_dataset(self, setup_commit_test_data):
"""Test session behavior when deleting from an empty dataset."""
nonexistent_app_id = str(uuid.uuid4())
# Should not raise any errors and should return 0
deleted_count = delete_draft_variables_batch(nonexistent_app_id, batch_size=10)
assert deleted_count == 0
def test_session_commit_with_single_batch(self, setup_commit_test_data):
"""Test that commit happens correctly when all data fits in a single batch."""
data = setup_commit_test_data
app_id = data["app"].id
with session_factory.create_session() as session:
initial_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert initial_count == 10
# Delete all in a single batch
deleted_count = delete_draft_variables_batch(app_id, batch_size=100)
assert deleted_count == 10
# Verify data is persisted
with session_factory.create_session() as session:
final_count = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
assert final_count == 0
def test_invalid_batch_size_raises_error(self, setup_commit_test_data):
"""Test that invalid batch size raises ValueError."""
data = setup_commit_test_data
app_id = data["app"].id
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=0)
with pytest.raises(ValueError, match="batch_size must be positive"):
delete_draft_variables_batch(app_id, batch_size=-1)
@patch("extensions.ext_storage.storage")
def test_session_commit_with_offload_data_cleanup(self, mock_storage, setup_offload_test_data):
"""Test that session commits correctly when cleaning up offload data."""
data = setup_offload_test_data
app_id = data["app"].id
upload_file_ids = [uf.id for uf in data["upload_files"]]
mock_storage.delete.return_value = None
# Verify initial state
with session_factory.create_session() as session:
draft_vars_before = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_before = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_before = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_before == 3
assert var_files_before == 2
assert upload_files_before == 2
# Delete variables with offload data
deleted_count = delete_draft_variables_batch(app_id, batch_size=10)
assert deleted_count == 3
# Verify all data is persisted (deleted) in new session
with session_factory.create_session() as session:
draft_vars_after = session.query(WorkflowDraftVariable).filter_by(app_id=app_id).count()
var_files_after = (
session.query(WorkflowDraftVariableFile)
.where(WorkflowDraftVariableFile.id.in_([vf.id for vf in data["variable_files"]]))
.count()
)
upload_files_after = session.query(UploadFile).where(UploadFile.id.in_(upload_file_ids)).count()
assert draft_vars_after == 0
assert var_files_after == 0
assert upload_files_after == 0
# Verify storage cleanup was called
assert mock_storage.delete.call_count == 2

View File

@@ -0,0 +1,38 @@
"""
Integration tests for DbMigrationAutoRenewLock using real Redis via TestContainers.
"""
import time
import uuid
import pytest
from extensions.ext_redis import redis_client
from libs.db_migration_lock import DbMigrationAutoRenewLock
@pytest.mark.usefixtures("flask_app_with_containers")
def test_db_migration_lock_renews_ttl_and_releases():
lock_name = f"test:db_migration_auto_renew_lock:{uuid.uuid4().hex}"
# Keep base TTL very small, and renew frequently so the test is stable even on slower CI.
lock = DbMigrationAutoRenewLock(
redis_client=redis_client,
name=lock_name,
ttl_seconds=1.0,
renew_interval_seconds=0.2,
log_context="test_db_migration_lock",
)
acquired = lock.acquire(blocking=True, blocking_timeout=5)
assert acquired is True
# Wait beyond the base TTL; key should still exist due to renewal.
time.sleep(1.5)
ttl = redis_client.ttl(lock_name)
assert ttl > 0
lock.release_safely(status="successful")
# After release, the key should not exist.
assert redis_client.exists(lock_name) == 0

View File

@@ -605,26 +605,20 @@ class TestBatchCreateSegmentToIndexTask:
mock_storage.download.side_effect = mock_download
# Execute the task
# Execute the task - should raise ValueError for empty CSV
job_id = str(uuid.uuid4())
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
with pytest.raises(ValueError, match="The CSV file is empty"):
batch_create_segment_to_index_task(
job_id=job_id,
upload_file_id=upload_file.id,
dataset_id=dataset.id,
document_id=document.id,
tenant_id=tenant.id,
user_id=account.id,
)
# Verify error handling
# Check Redis cache was set to error status
from extensions.ext_redis import redis_client
cache_key = f"segment_batch_import_{job_id}"
cache_value = redis_client.get(cache_key)
assert cache_value == b"error"
# Verify no segments were created
# Since exception was raised, no segments should be created
from extensions.ext_database import db
segments = db.session.query(DocumentSegment).all()

View File

@@ -153,8 +153,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task(document_ids, dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(document_ids)).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(document_ids))
@@ -162,9 +161,9 @@ class TestCleanNotionDocumentTask:
== 0
)
# Verify index processor was called for each document
# Verify index processor was called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
assert mock_processor.clean.call_count == len(document_ids)
mock_processor.clean.assert_called_once()
# This test successfully verifies:
# 1. Document records are properly deleted from the database
@@ -186,12 +185,12 @@ class TestCleanNotionDocumentTask:
non_existent_dataset_id = str(uuid.uuid4())
document_ids = [str(uuid.uuid4()), str(uuid.uuid4())]
# Execute cleanup task with non-existent dataset
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Execute cleanup task with non-existent dataset - expect exception
with pytest.raises(Exception, match="Document has no dataset"):
clean_notion_document_task(document_ids, non_existent_dataset_id)
# Verify that the index processor was not called
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
# Verify that the index processor factory was not used
mock_index_processor_factory.return_value.init_index_processor.assert_not_called()
def test_clean_notion_document_task_empty_document_list(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -229,9 +228,13 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task with empty document list
clean_notion_document_task([], dataset.id)
# Verify that the index processor was not called
# Verify that the index processor was called once with empty node list
mock_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_processor.clean.assert_not_called()
assert mock_processor.clean.call_count == 1
args, kwargs = mock_processor.clean.call_args
# args: (dataset, total_index_node_ids)
assert isinstance(args[0], Dataset)
assert args[1] == []
def test_clean_notion_document_task_with_different_index_types(
self, db_session_with_containers, mock_index_processor_factory, mock_external_service_dependencies
@@ -315,8 +318,7 @@ class TestCleanNotionDocumentTask:
# Note: This test successfully verifies cleanup with different document types.
# The task properly handles various index types and document configurations.
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == document.id)
@@ -404,8 +406,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0
@@ -508,8 +509,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task(documents_to_clean, dataset.id)
# Verify only specified documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id.in_(documents_to_clean)).count() == 0
# Verify only specified documents' segments are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id.in_(documents_to_clean))
@@ -697,11 +697,12 @@ class TestCleanNotionDocumentTask:
db_session_with_containers.commit()
# Mock index processor to raise an exception
mock_index_processor = mock_index_processor_factory.init_index_processor.return_value
mock_index_processor = mock_index_processor_factory.return_value.init_index_processor.return_value
mock_index_processor.clean.side_effect = Exception("Index processor error")
# Execute cleanup task - it should handle the exception gracefully
clean_notion_document_task([document.id], dataset.id)
# Execute cleanup task - current implementation propagates the exception
with pytest.raises(Exception, match="Index processor error"):
clean_notion_document_task([document.id], dataset.id)
# Note: This test demonstrates the task's error handling capability.
# Even with external service errors, the database operations complete successfully.
@@ -803,8 +804,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -914,8 +914,7 @@ class TestCleanNotionDocumentTask:
clean_notion_document_task([target_document.id], target_dataset.id)
# Verify only documents from target dataset are deleted
assert db_session_with_containers.query(Document).filter(Document.id == target_document.id).count() == 0
# Verify only documents' segments from target dataset are deleted
assert (
db_session_with_containers.query(DocumentSegment)
.filter(DocumentSegment.document_id == target_document.id)
@@ -1030,8 +1029,7 @@ class TestCleanNotionDocumentTask:
all_document_ids = [doc.id for doc in documents]
clean_notion_document_task(all_document_ids, dataset.id)
# Verify all documents and segments are deleted regardless of status
assert db_session_with_containers.query(Document).filter(Document.dataset_id == dataset.id).count() == 0
# Verify all segments are deleted regardless of status
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.dataset_id == dataset.id).count()
== 0
@@ -1142,8 +1140,7 @@ class TestCleanNotionDocumentTask:
# Execute cleanup task
clean_notion_document_task([document.id], dataset.id)
# Verify documents and segments are deleted
assert db_session_with_containers.query(Document).filter(Document.id == document.id).count() == 0
# Verify segments are deleted
assert (
db_session_with_containers.query(DocumentSegment).filter(DocumentSegment.document_id == document.id).count()
== 0

View File

@@ -0,0 +1,182 @@
from unittest.mock import MagicMock, patch
import pytest
from faker import Faker
from models import Account, Tenant, TenantAccountJoin, TenantAccountRole
from models.dataset import Dataset, Document, DocumentSegment
from tasks.document_indexing_update_task import document_indexing_update_task
class TestDocumentIndexingUpdateTask:
@pytest.fixture
def mock_external_dependencies(self):
"""Patch external collaborators used by the update task.
- IndexProcessorFactory.init_index_processor().clean(...)
- IndexingRunner.run([...])
"""
with (
patch("tasks.document_indexing_update_task.IndexProcessorFactory") as mock_factory,
patch("tasks.document_indexing_update_task.IndexingRunner") as mock_runner,
):
processor_instance = MagicMock()
mock_factory.return_value.init_index_processor.return_value = processor_instance
runner_instance = MagicMock()
mock_runner.return_value = runner_instance
yield {
"factory": mock_factory,
"processor": processor_instance,
"runner": mock_runner,
"runner_instance": runner_instance,
}
def _create_dataset_document_with_segments(self, db_session_with_containers, *, segment_count: int = 2):
fake = Faker()
# Account and tenant
account = Account(
email=fake.email(),
name=fake.name(),
interface_language="en-US",
status="active",
)
db_session_with_containers.add(account)
db_session_with_containers.commit()
tenant = Tenant(name=fake.company(), status="normal")
db_session_with_containers.add(tenant)
db_session_with_containers.commit()
join = TenantAccountJoin(
tenant_id=tenant.id,
account_id=account.id,
role=TenantAccountRole.OWNER,
current=True,
)
db_session_with_containers.add(join)
db_session_with_containers.commit()
# Dataset and document
dataset = Dataset(
tenant_id=tenant.id,
name=fake.company(),
description=fake.text(max_nb_chars=64),
data_source_type="upload_file",
indexing_technique="high_quality",
created_by=account.id,
)
db_session_with_containers.add(dataset)
db_session_with_containers.commit()
document = Document(
tenant_id=tenant.id,
dataset_id=dataset.id,
position=0,
data_source_type="upload_file",
batch="test_batch",
name=fake.file_name(),
created_from="upload_file",
created_by=account.id,
indexing_status="waiting",
enabled=True,
doc_form="text_model",
)
db_session_with_containers.add(document)
db_session_with_containers.commit()
# Segments
node_ids = []
for i in range(segment_count):
node_id = f"node-{i + 1}"
seg = DocumentSegment(
tenant_id=tenant.id,
dataset_id=dataset.id,
document_id=document.id,
position=i,
content=fake.text(max_nb_chars=32),
answer=None,
word_count=10,
tokens=5,
index_node_id=node_id,
status="completed",
created_by=account.id,
)
db_session_with_containers.add(seg)
node_ids.append(node_id)
db_session_with_containers.commit()
# Refresh to ensure ORM state
db_session_with_containers.refresh(dataset)
db_session_with_containers.refresh(document)
return dataset, document, node_ids
def test_cleans_segments_and_reindexes(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Act
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Assert document status updated before reindex
updated = db_session_with_containers.query(Document).where(Document.id == document.id).first()
assert updated.indexing_status == "parsing"
assert updated.processing_started_at is not None
# Segments should be deleted
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining == 0
# Assert index processor clean was called with expected args
clean_call = mock_external_dependencies["processor"].clean.call_args
assert clean_call is not None
args, kwargs = clean_call
# args[0] is a Dataset instance (from another session) — validate by id
assert getattr(args[0], "id", None) == dataset.id
# args[1] should contain our node_ids
assert set(args[1]) == set(node_ids)
assert kwargs.get("with_keywords") is True
assert kwargs.get("delete_child_chunks") is True
# Assert indexing runner invoked with the updated document
run_call = mock_external_dependencies["runner_instance"].run.call_args
assert run_call is not None
run_docs = run_call[0][0]
assert len(run_docs) == 1
first = run_docs[0]
assert getattr(first, "id", None) == document.id
def test_clean_error_is_logged_and_indexing_continues(self, db_session_with_containers, mock_external_dependencies):
dataset, document, node_ids = self._create_dataset_document_with_segments(db_session_with_containers)
# Force clean to raise; task should continue to indexing
mock_external_dependencies["processor"].clean.side_effect = Exception("boom")
document_indexing_update_task(dataset.id, document.id)
# Ensure we see committed changes from another session
db_session_with_containers.expire_all()
# Indexing should still be triggered
mock_external_dependencies["runner_instance"].run.assert_called_once()
# Segments should remain (since clean failed before DB delete)
remaining = (
db_session_with_containers.query(DocumentSegment).where(DocumentSegment.document_id == document.id).count()
)
assert remaining > 0
def test_document_not_found_noop(self, db_session_with_containers, mock_external_dependencies):
fake = Faker()
# Act with non-existent document id
document_indexing_update_task(dataset_id=fake.uuid4(), document_id=fake.uuid4())
# Neither processor nor runner should be called
mock_external_dependencies["processor"].clean.assert_not_called()
mock_external_dependencies["runner_instance"].run.assert_not_called()

View File

@@ -0,0 +1,146 @@
import sys
import threading
import types
from unittest.mock import MagicMock
import commands
from libs.db_migration_lock import LockNotOwnedError, RedisError
HEARTBEAT_WAIT_TIMEOUT_SECONDS = 5.0
def _install_fake_flask_migrate(monkeypatch, upgrade_impl) -> None:
module = types.ModuleType("flask_migrate")
module.upgrade = upgrade_impl
monkeypatch.setitem(sys.modules, "flask_migrate", module)
def _invoke_upgrade_db() -> int:
try:
commands.upgrade_db.callback()
except SystemExit as e:
return int(e.code or 0)
return 0
def test_upgrade_db_skips_when_lock_not_acquired(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 1234)
lock = MagicMock()
lock.acquire.return_value = False
commands.redis_client.lock.return_value = lock
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration skipped" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=1234, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_not_called()
def test_upgrade_db_failure_not_masked_by_lock_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 321)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
def _upgrade():
raise RuntimeError("boom")
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 1
assert "Database migration failed: boom" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=321, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_success_ignores_lock_not_owned_on_release(monkeypatch, capsys):
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 999)
lock = MagicMock()
lock.acquire.return_value = True
lock.release.side_effect = LockNotOwnedError("simulated")
commands.redis_client.lock.return_value = lock
_install_fake_flask_migrate(monkeypatch, lambda: None)
exit_code = _invoke_upgrade_db()
captured = capsys.readouterr()
assert exit_code == 0
assert "Database migration successful!" in captured.out
commands.redis_client.lock.assert_called_once_with(name="db_upgrade_lock", timeout=999, thread_local=False)
lock.acquire.assert_called_once_with(blocking=False)
lock.release.assert_called_once()
def test_upgrade_db_renews_lock_during_migration(monkeypatch, capsys):
"""
Ensure the lock is renewed while migrations are running, so the base TTL can stay short.
"""
# Use a small TTL so the heartbeat interval triggers quickly.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
renewed = threading.Event()
def _reacquire():
renewed.set()
return True
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert renewed.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1
def test_upgrade_db_ignores_reacquire_errors(monkeypatch, capsys):
# Use a small TTL so heartbeat runs during the upgrade call.
monkeypatch.setattr(commands, "DB_UPGRADE_LOCK_TTL_SECONDS", 0.3)
lock = MagicMock()
lock.acquire.return_value = True
commands.redis_client.lock.return_value = lock
attempted = threading.Event()
def _reacquire():
attempted.set()
raise RedisError("simulated")
lock.reacquire.side_effect = _reacquire
def _upgrade():
assert attempted.wait(HEARTBEAT_WAIT_TIMEOUT_SECONDS)
_install_fake_flask_migrate(monkeypatch, _upgrade)
exit_code = _invoke_upgrade_db()
_ = capsys.readouterr()
assert exit_code == 0
assert lock.reacquire.call_count >= 1

View File

@@ -25,15 +25,19 @@ class TestMessageCycleManagerOptimization:
task_state = Mock()
return MessageCycleManager(application_generate_entity=mock_application_generate_entity, task_state=task_state)
def test_get_message_event_type_with_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has files."""
def test_get_message_event_type_with_assistant_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE_FILE when message has assistant-generated files.
This ensures that AI-generated images (belongs_to='assistant') trigger the MESSAGE_FILE event,
allowing the frontend to properly display generated image files with url field.
"""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.scalar(select(...))
mock_message_file.belongs_to = "assistant"
mock_session.scalar.return_value = mock_message_file
# Execute
@@ -44,6 +48,31 @@ class TestMessageCycleManagerOptimization:
assert result == StreamEvent.MESSAGE_FILE
mock_session.scalar.assert_called_once()
def test_get_message_event_type_with_user_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message only has user-uploaded files.
This is a regression test for the issue where user-uploaded images (belongs_to='user')
caused the LLM text response to be incorrectly tagged with MESSAGE_FILE event,
resulting in broken images in the chat UI. The query filters for belongs_to='assistant',
so when only user files exist, the database query returns None, resulting in MESSAGE event type.
"""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
# Setup mock session and message file
mock_session = Mock()
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
# When querying for assistant files with only user files present, return None
# (simulates database query with belongs_to='assistant' filter returning no results)
mock_session.scalar.return_value = None
# Execute
with current_app.app_context():
result = message_cycle_manager.get_message_event_type("test-message-id")
# Assert
assert result == StreamEvent.MESSAGE
mock_session.scalar.assert_called_once()
def test_get_message_event_type_without_message_file(self, message_cycle_manager):
"""Test get_message_event_type returns MESSAGE when message has no files."""
with patch("core.app.task_pipeline.message_cycle_manager.session_factory") as mock_session_factory:
@@ -69,7 +98,7 @@ class TestMessageCycleManagerOptimization:
mock_session_factory.create_session.return_value.__enter__.return_value = mock_session
mock_message_file = Mock()
# Current implementation uses session.scalar(select(...))
mock_message_file.belongs_to = "assistant"
mock_session.scalar.return_value = mock_message_file
# Execute: compute event type once, then pass to message_to_stream_response

View File

@@ -0,0 +1,200 @@
"""Unit tests for TraceQueueManager telemetry guard.
This test suite verifies that TraceQueueManager correctly drops trace tasks
when telemetry is disabled, proving Bug 1 from code review is a false positive.
The guard logic moved from persistence.py to TraceQueueManager.add_trace_task()
at line 1282 of ops_trace_manager.py:
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
trace_manager_queue.put(trace_task)
Tasks are only enqueued if EITHER:
- Enterprise telemetry is enabled (_enterprise_telemetry_enabled=True), OR
- A third-party trace instance (Langfuse, etc.) is configured
When BOTH are false, tasks are silently dropped (correct behavior).
"""
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
@pytest.fixture
def trace_queue_manager_and_task(monkeypatch):
"""Fixture to provide TraceQueueManager and TraceTask with delayed imports."""
module_name = "core.ops.ops_trace_manager"
if module_name not in sys.modules:
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type):
self.trace_type = trace_type
self.app_id = None
class StubTraceQueueManager:
def __init__(self, app_id=None):
self.app_id = app_id
from core.telemetry.gateway import is_enterprise_telemetry_enabled
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
if self._enterprise_telemetry_enabled or self.trace_instance:
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.ops.entities.trace_entity import TraceTaskName
ops_module = __import__(module_name, fromlist=["TraceQueueManager", "TraceTask"])
TraceQueueManager = ops_module.TraceQueueManager
TraceTask = ops_module.TraceTask
return TraceQueueManager, TraceTask, TraceTaskName
class TestTraceQueueManagerTelemetryGuard:
"""Test TraceQueueManager's telemetry guard in add_trace_task()."""
def test_task_not_enqueued_when_telemetry_disabled_and_no_trace_instance(self, trace_queue_manager_and_task):
"""Verify task is NOT enqueued when telemetry disabled and no trace instance.
This is the core guard: when _enterprise_telemetry_enabled=False AND
trace_instance=None, the task should be silently dropped.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_not_called()
def test_task_enqueued_when_telemetry_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when enterprise telemetry is enabled.
When _enterprise_telemetry_enabled=True, the task should be enqueued
regardless of trace_instance state.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_trace_instance_configured(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when third-party trace instance is configured.
When trace_instance is not None (e.g., Langfuse configured), the task
should be enqueued even if enterprise telemetry is disabled.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_task_enqueued_when_both_telemetry_and_trace_instance_enabled(self, trace_queue_manager_and_task):
"""Verify task IS enqueued when both telemetry and trace instance are enabled.
When both _enterprise_telemetry_enabled=True AND trace_instance is set,
the task should definitely be enqueued.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
mock_trace_instance = MagicMock()
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch(
"core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=mock_trace_instance
),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="test-app-id")
manager.add_trace_task(trace_task)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "test-app-id"
def test_app_id_set_before_enqueue(self, trace_queue_manager_and_task):
"""Verify app_id is set on the task before enqueuing.
The guard logic sets trace_task.app_id = self.app_id before calling
trace_manager_queue.put(trace_task). This test verifies that behavior.
"""
TraceQueueManager, TraceTask, TraceTaskName = trace_queue_manager_and_task
mock_queue = MagicMock(spec=queue.Queue)
trace_task = TraceTask(trace_type=TraceTaskName.WORKFLOW_TRACE)
with (
patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True),
patch("core.ops.ops_trace_manager.OpsTraceManager.get_ops_trace_instance", return_value=None),
patch("core.ops.ops_trace_manager.trace_manager_queue", mock_queue),
):
manager = TraceQueueManager(app_id="expected-app-id")
manager.add_trace_task(trace_task)
called_task = mock_queue.put.call_args[0][0]
assert called_task.app_id == "expected-app-id"

View File

@@ -0,0 +1,181 @@
"""Unit tests for core.telemetry.emit() routing and enterprise-only filtering."""
from __future__ import annotations
import queue
import sys
import types
from unittest.mock import MagicMock, patch
import pytest
from core.ops.entities.trace_entity import TraceTaskName
from core.telemetry.events import TelemetryContext, TelemetryEvent
@pytest.fixture
def telemetry_test_setup(monkeypatch):
module_name = "core.ops.ops_trace_manager"
ops_stub = types.ModuleType(module_name)
class StubTraceTask:
def __init__(self, trace_type, **kwargs):
self.trace_type = trace_type
self.app_id = None
self.kwargs = kwargs
class StubTraceQueueManager:
def __init__(self, app_id=None, user_id=None):
self.app_id = app_id
self.user_id = user_id
self.trace_instance = StubOpsTraceManager.get_ops_trace_instance(app_id)
def add_trace_task(self, trace_task):
trace_task.app_id = self.app_id
from core.ops.ops_trace_manager import trace_manager_queue
trace_manager_queue.put(trace_task)
class StubOpsTraceManager:
@staticmethod
def get_ops_trace_instance(app_id):
return None
ops_stub.TraceQueueManager = StubTraceQueueManager
ops_stub.TraceTask = StubTraceTask
ops_stub.OpsTraceManager = StubOpsTraceManager
ops_stub.trace_manager_queue = MagicMock(spec=queue.Queue)
monkeypatch.setitem(sys.modules, module_name, ops_stub)
from core.telemetry import emit
return emit, ops_stub.trace_manager_queue
class TestTelemetryEmit:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_enterprise_trace_creates_trace_task(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"key": "value"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
def test_emit_community_trace_enqueued(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.WORKFLOW_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
def test_emit_enterprise_only_trace_dropped_when_ee_disabled(self, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_not_called()
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_all_enterprise_only_traces_allowed_when_ee_enabled(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
enterprise_only_traces = [
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
TraceTaskName.NODE_EXECUTION_TRACE,
TraceTaskName.PROMPT_GENERATION_TRACE,
]
for trace_name in enterprise_only_traces:
mock_queue.reset_mock()
event = TelemetryEvent(
name=trace_name,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == trace_name
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_passes_name_directly_to_trace_task(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
event = TelemetryEvent(
name=TraceTaskName.DRAFT_NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={"extra": "data"},
)
emit_fn(event)
mock_queue.put.assert_called_once()
called_task = mock_queue.put.call_args[0][0]
assert called_task.trace_type == TraceTaskName.DRAFT_NODE_EXECUTION_TRACE
assert isinstance(called_task.trace_type, TraceTaskName)
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_emit_with_provided_trace_manager(self, _mock_ee, telemetry_test_setup):
emit_fn, mock_queue = telemetry_test_setup
mock_trace_manager = MagicMock()
mock_trace_manager.add_trace_task = MagicMock()
event = TelemetryEvent(
name=TraceTaskName.NODE_EXECUTION_TRACE,
context=TelemetryContext(
tenant_id="test-tenant",
user_id="test-user",
app_id="test-app",
),
payload={},
)
emit_fn(event, trace_manager=mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
called_task = mock_trace_manager.add_trace_task.call_args[0][0]
assert called_task.trace_type == TraceTaskName.NODE_EXECUTION_TRACE

View File

@@ -0,0 +1,225 @@
from __future__ import annotations
import sys
from unittest.mock import MagicMock, patch
import pytest
from core.telemetry.gateway import emit, is_enterprise_telemetry_enabled
from enterprise.telemetry.contracts import TelemetryCase
class TestTelemetryCoreExports:
def test_is_enterprise_telemetry_enabled_exported(self) -> None:
from core.telemetry.gateway import is_enterprise_telemetry_enabled as exported_func
assert callable(exported_func)
@pytest.fixture
def mock_ops_trace_manager():
mock_module = MagicMock()
mock_trace_task_class = MagicMock()
mock_trace_task_class.return_value = MagicMock()
mock_module.TraceTask = mock_trace_task_class
mock_module.TraceQueueManager = MagicMock()
mock_trace_entity = MagicMock()
mock_trace_task_name = MagicMock()
mock_trace_task_name.return_value = "workflow"
mock_trace_entity.TraceTaskName = mock_trace_task_name
with (
patch.dict(sys.modules, {"core.ops.ops_trace_manager": mock_module}),
patch.dict(sys.modules, {"core.ops.entities.trace_entity": mock_trace_entity}),
):
yield mock_module, mock_trace_entity
class TestGatewayIntegrationTraceRouting:
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_to_trace_manager(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_ce_eligible_trace_routed_when_ee_disabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_dropped_when_ee_disabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_enterprise_only_trace_routed_when_ee_enabled(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationMetricRouting:
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_metric_case_routes_to_celery_task(
self,
_mock_ee_enabled: MagicMock,
) -> None:
from enterprise.telemetry.contracts import TelemetryEnvelope
with patch("tasks.enterprise_telemetry_task.process_enterprise_telemetry.delay") as mock_delay:
context = {"tenant_id": "tenant-123"}
payload = {"app_id": "app-abc", "name": "My App"}
emit(TelemetryCase.APP_CREATED, context, payload)
mock_delay.assert_called_once()
envelope_json = mock_delay.call_args[0][0]
envelope = TelemetryEnvelope.model_validate_json(envelope_json)
assert envelope.case == TelemetryCase.APP_CREATED
assert envelope.tenant_id == "tenant-123"
assert envelope.payload["app_id"] == "app-abc"
@pytest.mark.usefixtures("mock_ops_trace_manager")
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_tool_execution_trace_routed(
self,
_mock_ee_enabled: MagicMock,
) -> None:
mock_trace_manager = MagicMock()
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"tool_name": "test_tool", "tool_inputs": {}, "tool_outputs": "result"}
emit(TelemetryCase.TOOL_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
@patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=True)
def test_moderation_check_trace_routed(
self,
_mock_ee_enabled: MagicMock,
) -> None:
mock_trace_manager = MagicMock()
context = {"tenant_id": "tenant-123", "app_id": "app-123"}
payload = {"message_id": "msg-123", "moderation_result": {"flagged": False}}
emit(TelemetryCase.MODERATION_CHECK, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
class TestGatewayIntegrationCEEligibility:
@pytest.fixture
def mock_trace_manager(self) -> MagicMock:
return MagicMock()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_workflow_run_is_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"workflow_run_id": "run-abc"}
emit(TelemetryCase.WORKFLOW_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_message_run_is_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"message_id": "msg-abc", "conversation_id": "conv-123"}
emit(TelemetryCase.MESSAGE_RUN, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_called_once()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_node_execution_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_id": "node-abc"}
emit(TelemetryCase.NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_draft_node_execution_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456"}
payload = {"node_execution_data": {}}
emit(TelemetryCase.DRAFT_NODE_EXECUTION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
@pytest.mark.usefixtures("mock_ops_trace_manager")
def test_prompt_generation_not_ce_eligible(
self,
mock_trace_manager: MagicMock,
) -> None:
with patch("core.telemetry.gateway.is_enterprise_telemetry_enabled", return_value=False):
context = {"app_id": "app-123", "user_id": "user-456", "tenant_id": "tenant-789"}
payload = {"operation_type": "generate", "instruction": "test"}
emit(TelemetryCase.PROMPT_GENERATION, context, payload, mock_trace_manager)
mock_trace_manager.add_trace_task.assert_not_called()
class TestIsEnterpriseTelemetryEnabled:
def test_returns_false_when_exporter_import_fails(self) -> None:
with patch.dict(sys.modules, {"enterprise.telemetry.exporter": None}):
result = is_enterprise_telemetry_enabled()
assert result is False
def test_function_is_callable(self) -> None:
assert callable(is_enterprise_telemetry_enabled)

View File

@@ -0,0 +1,230 @@
"""Unit tests for telemetry gateway contracts."""
from __future__ import annotations
import pytest
from pydantic import ValidationError
from core.telemetry.gateway import CASE_ROUTING
from enterprise.telemetry.contracts import CaseRoute, SignalType, TelemetryCase, TelemetryEnvelope
class TestTelemetryCase:
"""Tests for TelemetryCase enum."""
def test_all_cases_defined(self) -> None:
"""Verify all 14 telemetry cases are defined."""
expected_cases = {
"WORKFLOW_RUN",
"NODE_EXECUTION",
"DRAFT_NODE_EXECUTION",
"MESSAGE_RUN",
"TOOL_EXECUTION",
"MODERATION_CHECK",
"SUGGESTED_QUESTION",
"DATASET_RETRIEVAL",
"GENERATE_NAME",
"PROMPT_GENERATION",
"APP_CREATED",
"APP_UPDATED",
"APP_DELETED",
"FEEDBACK_CREATED",
}
actual_cases = {case.name for case in TelemetryCase}
assert actual_cases == expected_cases
def test_case_values(self) -> None:
"""Verify case enum values are correct."""
assert TelemetryCase.WORKFLOW_RUN.value == "workflow_run"
assert TelemetryCase.NODE_EXECUTION.value == "node_execution"
assert TelemetryCase.DRAFT_NODE_EXECUTION.value == "draft_node_execution"
assert TelemetryCase.MESSAGE_RUN.value == "message_run"
assert TelemetryCase.TOOL_EXECUTION.value == "tool_execution"
assert TelemetryCase.MODERATION_CHECK.value == "moderation_check"
assert TelemetryCase.SUGGESTED_QUESTION.value == "suggested_question"
assert TelemetryCase.DATASET_RETRIEVAL.value == "dataset_retrieval"
assert TelemetryCase.GENERATE_NAME.value == "generate_name"
assert TelemetryCase.PROMPT_GENERATION.value == "prompt_generation"
assert TelemetryCase.APP_CREATED.value == "app_created"
assert TelemetryCase.APP_UPDATED.value == "app_updated"
assert TelemetryCase.APP_DELETED.value == "app_deleted"
assert TelemetryCase.FEEDBACK_CREATED.value == "feedback_created"
class TestCaseRoute:
"""Tests for CaseRoute model."""
def test_valid_trace_route(self) -> None:
"""Verify valid trace route creation."""
route = CaseRoute(signal_type=SignalType.TRACE, ce_eligible=True)
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_valid_metric_log_route(self) -> None:
"""Verify valid metric_log route creation."""
route = CaseRoute(signal_type=SignalType.METRIC_LOG, ce_eligible=False)
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_invalid_signal_type(self) -> None:
"""Verify invalid signal_type is rejected."""
with pytest.raises(ValidationError):
CaseRoute(signal_type="invalid", ce_eligible=True)
class TestTelemetryEnvelope:
"""Tests for TelemetryEnvelope model."""
def test_valid_envelope_minimal(self) -> None:
"""Verify valid minimal envelope creation."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
assert envelope.case == TelemetryCase.WORKFLOW_RUN
assert envelope.tenant_id == "tenant-123"
assert envelope.event_id == "event-456"
assert envelope.payload == {"key": "value"}
assert envelope.metadata is None
def test_valid_envelope_full(self) -> None:
"""Verify valid envelope with all fields."""
metadata = {"payload_ref": "telemetry/tenant-789/event-012.json"}
envelope = TelemetryEnvelope(
case=TelemetryCase.MESSAGE_RUN,
tenant_id="tenant-789",
event_id="event-012",
payload={"message": "hello"},
metadata=metadata,
)
assert envelope.case == TelemetryCase.MESSAGE_RUN
assert envelope.tenant_id == "tenant-789"
assert envelope.event_id == "event-012"
assert envelope.payload == {"message": "hello"}
assert envelope.metadata == metadata
def test_missing_required_case(self) -> None:
"""Verify missing case field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_tenant_id(self) -> None:
"""Verify missing tenant_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
event_id="event-456",
payload={"key": "value"},
)
def test_missing_required_event_id(self) -> None:
"""Verify missing event_id field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
payload={"key": "value"},
)
def test_missing_required_payload(self) -> None:
"""Verify missing payload field is rejected."""
with pytest.raises(ValidationError):
TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
)
def test_metadata_none(self) -> None:
"""Verify metadata can be None."""
envelope = TelemetryEnvelope(
case=TelemetryCase.WORKFLOW_RUN,
tenant_id="tenant-123",
event_id="event-456",
payload={"key": "value"},
metadata=None,
)
assert envelope.metadata is None
class TestCaseRouting:
"""Tests for CASE_ROUTING table."""
def test_all_cases_routed(self) -> None:
"""Verify all 14 cases have routing entries."""
assert len(CASE_ROUTING) == 14
for case in TelemetryCase:
assert case in CASE_ROUTING
def test_trace_ce_eligible_cases(self) -> None:
"""Verify trace cases with CE eligibility."""
ce_eligible_trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
}
for case in ce_eligible_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is True
def test_trace_enterprise_only_cases(self) -> None:
"""Verify trace cases that are enterprise-only."""
enterprise_only_trace_cases = {
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
}
for case in enterprise_only_trace_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.TRACE
assert route.ce_eligible is False
def test_metric_log_cases(self) -> None:
"""Verify metric/log-only cases."""
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
}
for case in metric_log_cases:
route = CASE_ROUTING[case]
assert route.signal_type == SignalType.METRIC_LOG
assert route.ce_eligible is False
def test_routing_table_completeness(self) -> None:
"""Verify routing table covers all cases with correct types."""
trace_cases = {
TelemetryCase.WORKFLOW_RUN,
TelemetryCase.MESSAGE_RUN,
TelemetryCase.NODE_EXECUTION,
TelemetryCase.DRAFT_NODE_EXECUTION,
TelemetryCase.PROMPT_GENERATION,
TelemetryCase.TOOL_EXECUTION,
TelemetryCase.MODERATION_CHECK,
TelemetryCase.SUGGESTED_QUESTION,
TelemetryCase.DATASET_RETRIEVAL,
TelemetryCase.GENERATE_NAME,
}
metric_log_cases = {
TelemetryCase.APP_CREATED,
TelemetryCase.APP_UPDATED,
TelemetryCase.APP_DELETED,
TelemetryCase.FEEDBACK_CREATED,
}
all_cases = trace_cases | metric_log_cases
assert len(all_cases) == 14
assert all_cases == set(TelemetryCase)
for case in trace_cases:
assert CASE_ROUTING[case].signal_type == SignalType.TRACE
for case in metric_log_cases:
assert CASE_ROUTING[case].signal_type == SignalType.METRIC_LOG

Some files were not shown because too many files have changed in this diff Show More