mirror of
https://github.com/langgenius/dify.git
synced 2026-03-15 04:07:01 +00:00
Compare commits
73 Commits
1.13.0
...
fix/enterp
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
d22f0dfdac | ||
|
|
ef7918764b | ||
|
|
cf616ddc58 | ||
|
|
da06f738f7 | ||
|
|
c0b05db9d7 | ||
|
|
248504acef | ||
|
|
d105a2f568 | ||
|
|
64fcd9859f | ||
|
|
0f938d453c | ||
|
|
6625828246 | ||
|
|
7a1f0e3258 | ||
|
|
9a682f1009 | ||
|
|
877de7fb22 | ||
|
|
d81684d8d1 | ||
|
|
c900460ab3 | ||
|
|
5afb24f461 | ||
|
|
808002fbbd | ||
|
|
757fabda1e | ||
|
|
858ccd8746 | ||
|
|
ea35ee0a3e | ||
|
|
0e9dc86f3b | ||
|
|
0ed39d81e9 | ||
|
|
2b739b9544 | ||
|
|
22e82297c5 | ||
|
|
7ef139cadd | ||
|
|
bf5a327156 | ||
|
|
d94af41f07 | ||
|
|
8d8552cbb9 | ||
|
|
58524fd7fd | ||
|
|
2d7bffcc11 | ||
|
|
5025e29220 | ||
|
|
3cdc9c119e | ||
|
|
18ba367b11 | ||
|
|
d0bd74fccb | ||
|
|
5ccbc00eb9 | ||
|
|
94603b5408 | ||
|
|
8d4bd5636b | ||
|
|
ee0c4a8852 | ||
|
|
6032c598b0 | ||
|
|
afdd5b6c86 | ||
|
|
9acdfbde2f | ||
|
|
1977e68b2d | ||
|
|
e9a7e8f77f | ||
|
|
9e2b28c950 | ||
|
|
affd07ae94 | ||
|
|
111c76b71f | ||
|
|
793d22754e | ||
|
|
b62965034e | ||
|
|
016d72a8c6 | ||
|
|
08b8eff933 | ||
|
|
579cdea820 | ||
|
|
125f7e3ab4 | ||
|
|
400ed2fd72 | ||
|
|
840a8f3fc2 | ||
|
|
b4a5296fd1 | ||
|
|
d7c3ae50dc | ||
|
|
b921711e9e | ||
|
|
fb38ad84e1 | ||
|
|
91c854b5be | ||
|
|
d35b231941 | ||
|
|
849b4b8c40 | ||
|
|
990e8feee8 | ||
|
|
53641019b1 | ||
|
|
d1f10ff301 | ||
|
|
c8027e168b | ||
|
|
aae3f76999 | ||
|
|
2860c72b03 | ||
|
|
fcb53383df | ||
|
|
540e1db83c | ||
|
|
2f75e38c08 | ||
|
|
cd03e0a9ef | ||
|
|
df2421d187 | ||
|
|
0ba321d840 |
1
.codex/skills/component-refactoring
Symbolic link
1
.codex/skills/component-refactoring
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/component-refactoring
|
||||
1
.codex/skills/frontend-code-review
Symbolic link
1
.codex/skills/frontend-code-review
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/frontend-code-review
|
||||
1
.codex/skills/frontend-testing
Symbolic link
1
.codex/skills/frontend-testing
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/frontend-testing
|
||||
1
.codex/skills/orpc-contract-first
Symbolic link
1
.codex/skills/orpc-contract-first
Symbolic link
@@ -0,0 +1 @@
|
||||
../../.agents/skills/orpc-contract-first
|
||||
7
.github/CODEOWNERS
vendored
7
.github/CODEOWNERS
vendored
@@ -24,10 +24,6 @@
|
||||
/api/services/tools/mcp_tools_manage_service.py @Nov1c444
|
||||
/api/controllers/mcp/ @Nov1c444
|
||||
/api/controllers/console/app/mcp_server.py @Nov1c444
|
||||
|
||||
# Backend - Tests
|
||||
/api/tests/ @laipz8200 @QuantumGhost
|
||||
|
||||
/api/tests/**/*mcp* @Nov1c444
|
||||
|
||||
# Backend - Workflow - Engine (Core graph execution engine)
|
||||
@@ -238,9 +234,6 @@
|
||||
# Frontend - Base Components
|
||||
/web/app/components/base/ @iamjoel @zxhlyh
|
||||
|
||||
# Frontend - Base Components Tests
|
||||
/web/app/components/base/**/*.spec.tsx @hyoban @CodingOnStar
|
||||
|
||||
# Frontend - Utils and Hooks
|
||||
/web/utils/classnames.ts @iamjoel @zxhlyh
|
||||
/web/utils/time.ts @iamjoel @zxhlyh
|
||||
|
||||
23
.github/workflows/autofix.yml
vendored
23
.github/workflows/autofix.yml
vendored
@@ -79,6 +79,29 @@ jobs:
|
||||
find . -name "*.py" -type f -exec sed -i.bak -E 's/"([^"]+)" \| None/Optional["\1"]/g; s/'"'"'([^'"'"']+)'"'"' \| None/Optional['"'"'\1'"'"']/g' {} \;
|
||||
find . -name "*.py.bak" -type f -delete
|
||||
|
||||
- name: Install pnpm
|
||||
uses: pnpm/action-setup@v4
|
||||
with:
|
||||
package_json_file: web/package.json
|
||||
run_install: false
|
||||
|
||||
- name: Setup Node.js
|
||||
uses: actions/setup-node@v6
|
||||
with:
|
||||
node-version: 24
|
||||
cache: pnpm
|
||||
cache-dependency-path: ./web/pnpm-lock.yaml
|
||||
|
||||
- name: Install web dependencies
|
||||
run: |
|
||||
cd web
|
||||
pnpm install --frozen-lockfile
|
||||
|
||||
- name: ESLint autofix
|
||||
run: |
|
||||
cd web
|
||||
pnpm lint:fix || true
|
||||
|
||||
# mdformat breaks YAML front matter in markdown files. Add --exclude for directories containing YAML front matter.
|
||||
- name: mdformat
|
||||
run: |
|
||||
|
||||
1
.github/workflows/build-push.yml
vendored
1
.github/workflows/build-push.yml
vendored
@@ -8,7 +8,6 @@ on:
|
||||
- "build/**"
|
||||
- "release/e-*"
|
||||
- "hotfix/**"
|
||||
- "feat/hitl-backend"
|
||||
tags:
|
||||
- "*"
|
||||
|
||||
|
||||
8
.github/workflows/deploy-hitl.yml
vendored
8
.github/workflows/deploy-hitl.yml
vendored
@@ -4,7 +4,8 @@ on:
|
||||
workflow_run:
|
||||
workflows: ["Build and Push API & Web"]
|
||||
branches:
|
||||
- "build/feat/hitl"
|
||||
- "feat/hitl-frontend"
|
||||
- "feat/hitl-backend"
|
||||
types:
|
||||
- completed
|
||||
|
||||
@@ -13,7 +14,10 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
if: |
|
||||
github.event.workflow_run.conclusion == 'success' &&
|
||||
github.event.workflow_run.head_branch == 'build/feat/hitl'
|
||||
(
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-frontend' ||
|
||||
github.event.workflow_run.head_branch == 'feat/hitl-backend'
|
||||
)
|
||||
steps:
|
||||
- name: Deploy to server
|
||||
uses: appleboy/ssh-action@v1
|
||||
|
||||
2
.github/workflows/web-tests.yml
vendored
2
.github/workflows/web-tests.yml
vendored
@@ -39,7 +39,7 @@ jobs:
|
||||
run: pnpm install --frozen-lockfile
|
||||
|
||||
- name: Run tests
|
||||
run: pnpm test:ci
|
||||
run: pnpm test:coverage
|
||||
|
||||
- name: Coverage Summary
|
||||
if: always()
|
||||
|
||||
2
.vscode/launch.json.template
vendored
2
.vscode/launch.json.template
vendored
@@ -37,7 +37,7 @@
|
||||
"-c",
|
||||
"1",
|
||||
"-Q",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,workflow_based_app_execution",
|
||||
"dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention",
|
||||
"--loglevel",
|
||||
"INFO"
|
||||
],
|
||||
|
||||
@@ -715,31 +715,5 @@ ANNOTATION_IMPORT_MAX_CONCURRENT=5
|
||||
# Sandbox expired records clean configuration
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_GRACEFUL_PERIOD=21
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_SIZE=1000
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL=200
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS=30
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_TASK_LOCK_TTL=90000
|
||||
|
||||
|
||||
# Redis URL used for PubSub between API and
|
||||
# celery worker
|
||||
# defaults to url constructed from `REDIS_*`
|
||||
# configurations
|
||||
PUBSUB_REDIS_URL=
|
||||
# Pub/sub channel type for streaming events.
|
||||
# valid options are:
|
||||
#
|
||||
# - pubsub: for normal Pub/Sub
|
||||
# - sharded: for sharded Pub/Sub
|
||||
#
|
||||
# It's highly recommended to use sharded Pub/Sub AND redis cluster
|
||||
# for large deployments.
|
||||
PUBSUB_REDIS_CHANNEL_TYPE=pubsub
|
||||
# Whether to use Redis cluster mode while running
|
||||
# PubSub.
|
||||
# It's highly recommended to enable this for large deployments.
|
||||
PUBSUB_REDIS_USE_CLUSTERS=false
|
||||
|
||||
# Whether to Enable human input timeout check task
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK=true
|
||||
# Human input timeout check interval in minutes
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL=1
|
||||
|
||||
@@ -36,8 +36,6 @@ ignore_imports =
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph
|
||||
core.workflow.nodes.loop.loop_node -> core.workflow.graph_engine.command_channels
|
||||
# TODO(QuantumGhost): fix the import violation later
|
||||
core.workflow.entities.pause_reason -> core.workflow.nodes.human_input.entities
|
||||
|
||||
[importlinter:contract:workflow-infrastructure-dependencies]
|
||||
name = Workflow Infrastructure Dependencies
|
||||
@@ -52,14 +50,14 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.graph_engine.command_channels.redis_channel -> extensions.ext_redis
|
||||
core.workflow.graph_engine.manager -> extensions.ext_redis
|
||||
# TODO(QuantumGhost): use DI to avoid depending on global DB.
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
|
||||
[importlinter:contract:workflow-external-imports]
|
||||
name = Workflow External Imports
|
||||
@@ -124,6 +122,11 @@ ignore_imports =
|
||||
core.workflow.nodes.http_request.node -> core.tools.tool_file_manager
|
||||
core.workflow.nodes.iteration.iteration_node -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.index_processor_factory
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.datasource.retrieval_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.dataset_retrieval
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> models.dataset
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> services.feature_service
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_runtime.model_providers.__base.large_language_model
|
||||
core.workflow.nodes.llm.llm_utils -> configs
|
||||
core.workflow.nodes.llm.llm_utils -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.file.models
|
||||
@@ -133,6 +136,7 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.llm_utils -> models.provider
|
||||
core.workflow.nodes.llm.llm_utils -> services.credit_pool_service
|
||||
core.workflow.nodes.llm.node -> core.tools.signature
|
||||
core.workflow.nodes.template_transform.template_transform_node -> configs
|
||||
core.workflow.nodes.tool.tool_node -> core.callback_handler.workflow_tool_callback_handler
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_engine
|
||||
core.workflow.nodes.tool.tool_node -> core.tools.tool_manager
|
||||
@@ -141,9 +145,9 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.entities
|
||||
core.workflow.nodes.agent.agent_node -> core.agent.plugin_entities
|
||||
core.workflow.nodes.base.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.human_input.human_input_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.app_config.entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.llm.node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.app.entities.app_invoke_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.prompt.advanced_prompt_transform
|
||||
@@ -159,6 +163,9 @@ ignore_imports =
|
||||
core.workflow.workflow_entry -> core.app.workflow.node_factory
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.datasource_manager
|
||||
core.workflow.nodes.datasource.datasource_node -> core.datasource.utils.message_transformer
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.agent_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.entities.model_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.model_manager
|
||||
core.workflow.nodes.llm.llm_utils -> core.entities.provider_entities
|
||||
core.workflow.nodes.parameter_extractor.parameter_extractor_node -> core.model_manager
|
||||
core.workflow.nodes.question_classifier.question_classifier_node -> core.model_manager
|
||||
@@ -207,6 +214,7 @@ ignore_imports =
|
||||
core.workflow.nodes.llm.node -> core.llm_generator.output_parser.structured_output
|
||||
core.workflow.nodes.llm.node -> core.model_manager
|
||||
core.workflow.nodes.agent.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.prompt.simple_prompt_transform
|
||||
core.workflow.nodes.llm.entities -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.llm_utils -> core.prompt.entities.advanced_prompt_entities
|
||||
core.workflow.nodes.llm.node -> core.prompt.entities.advanced_prompt_entities
|
||||
@@ -222,6 +230,7 @@ ignore_imports =
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> services.summary_index_service
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> tasks.generate_summary_index_task
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> core.rag.index_processor.processor.paragraph_index_processor
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> core.rag.retrieval.retrieval_methods
|
||||
core.workflow.nodes.llm.node -> models.dataset
|
||||
core.workflow.nodes.agent.agent_node -> core.tools.utils.message_transformer
|
||||
core.workflow.nodes.llm.file_saver -> core.tools.signature
|
||||
@@ -239,7 +248,6 @@ ignore_imports =
|
||||
core.workflow.nodes.document_extractor.node -> core.variables.segments
|
||||
core.workflow.nodes.http_request.executor -> core.variables.segments
|
||||
core.workflow.nodes.http_request.node -> core.variables.segments
|
||||
core.workflow.nodes.human_input.entities -> core.variables.consts
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.segments
|
||||
core.workflow.nodes.iteration.iteration_node -> core.variables.variables
|
||||
@@ -280,12 +288,12 @@ ignore_imports =
|
||||
core.workflow.nodes.agent.agent_node -> extensions.ext_database
|
||||
core.workflow.nodes.datasource.datasource_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_index.knowledge_index_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_database
|
||||
core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node -> extensions.ext_redis
|
||||
core.workflow.nodes.llm.file_saver -> extensions.ext_database
|
||||
core.workflow.nodes.llm.llm_utils -> extensions.ext_database
|
||||
core.workflow.nodes.llm.node -> extensions.ext_database
|
||||
core.workflow.nodes.tool.tool_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> extensions.ext_database
|
||||
core.workflow.nodes.human_input.human_input_node -> core.repositories.human_input_repository
|
||||
core.workflow.workflow_entry -> extensions.otel.runtime
|
||||
core.workflow.nodes.agent.agent_node -> models
|
||||
core.workflow.nodes.base.node -> models.enums
|
||||
|
||||
@@ -106,10 +106,10 @@ ignore = [
|
||||
"N803", # invalid-argument-name
|
||||
]
|
||||
"tests/*" = [
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
|
||||
"F811", # redefined-while-unused
|
||||
"T201", # allow print in tests,
|
||||
"S110", # allow ignoring exceptions in tests code (currently)
|
||||
"PT019", # @patch-injected params look like unused fixtures
|
||||
]
|
||||
"controllers/console/explore/trial.py" = ["TID251"]
|
||||
"controllers/console/human_input_form.py" = ["TID251"]
|
||||
|
||||
2
api/.vscode/launch.json.example
vendored
2
api/.vscode/launch.json.example
vendored
@@ -54,7 +54,7 @@
|
||||
"--loglevel",
|
||||
"DEBUG",
|
||||
"-Q",
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,workflow_based_app_execution,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
"dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor"
|
||||
]
|
||||
}
|
||||
]
|
||||
|
||||
@@ -122,7 +122,8 @@ These commands assume you start from the repository root.
|
||||
|
||||
```bash
|
||||
cd api
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q api_token,dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention
|
||||
# Note: enterprise_telemetry queue is only used in Enterprise Edition
|
||||
uv run celery -A app.celery worker -P threads -c 2 --loglevel INFO -Q dataset,priority_dataset,priority_pipeline,pipeline,mail,ops_trace,app_deletion,plugin,workflow_storage,conversation,workflow,schedule_poller,schedule_executor,triggered_workflow_dispatcher,trigger_refresh_executor,retention,enterprise_telemetry
|
||||
```
|
||||
|
||||
1. Optional: start Celery Beat (scheduled tasks, in a new terminal).
|
||||
|
||||
@@ -1,16 +1,45 @@
|
||||
import logging
|
||||
import time
|
||||
|
||||
from flask import request
|
||||
from opentelemetry.trace import get_current_span
|
||||
from opentelemetry.trace.span import INVALID_SPAN_ID, INVALID_TRACE_ID
|
||||
|
||||
from configs import dify_config
|
||||
from contexts.wrapper import RecyclableContextVar
|
||||
from controllers.console.error import UnauthorizedAndForceLogout
|
||||
from core.logging.context import init_request_context
|
||||
from dify_app import DifyApp
|
||||
from services.enterprise.enterprise_service import EnterpriseService
|
||||
from services.feature_service import LicenseStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Console bootstrap APIs exempt from license check.
|
||||
# Defined at module level to avoid per-request tuple construction.
|
||||
# - system-features: license status for expiry UI (GlobalPublicStoreProvider)
|
||||
# - setup: install/setup status check (AppInitializer)
|
||||
# - init: init password validation for fresh install (InitPasswordPopup)
|
||||
# - login: auto-login after setup completion (InstallForm)
|
||||
# - features: billing/plan features (ProviderContextProvider)
|
||||
# - account/profile: login check + user profile (AppContextProvider, useIsLogin)
|
||||
# - workspaces/current: workspace + model providers (AppContextProvider)
|
||||
# - version: version check (AppContextProvider)
|
||||
# - activate/check: invitation link validation (signin page)
|
||||
# Without these exemptions, the signin page triggers location.reload()
|
||||
# on unauthorized_and_force_logout, causing an infinite loop.
|
||||
_CONSOLE_EXEMPT_PREFIXES = (
|
||||
"/console/api/system-features",
|
||||
"/console/api/setup",
|
||||
"/console/api/init",
|
||||
"/console/api/login",
|
||||
"/console/api/features",
|
||||
"/console/api/account/profile",
|
||||
"/console/api/workspaces/current",
|
||||
"/console/api/version",
|
||||
"/console/api/activate/check",
|
||||
)
|
||||
|
||||
|
||||
# ----------------------------
|
||||
# Application Factory Function
|
||||
@@ -31,6 +60,39 @@ def create_flask_app_with_configs() -> DifyApp:
|
||||
init_request_context()
|
||||
RecyclableContextVar.increment_thread_recycles()
|
||||
|
||||
# Enterprise license validation for API endpoints (both console and webapp)
|
||||
# When license expires, block all API access except bootstrap endpoints needed
|
||||
# for the frontend to load the license expiration page without infinite reloads.
|
||||
if dify_config.ENTERPRISE_ENABLED:
|
||||
is_console_api = request.path.startswith("/console/api/")
|
||||
is_webapp_api = request.path.startswith("/api/") and not is_console_api
|
||||
|
||||
if is_console_api or is_webapp_api:
|
||||
if is_console_api:
|
||||
is_exempt = any(request.path.startswith(p) for p in _CONSOLE_EXEMPT_PREFIXES)
|
||||
else: # webapp API
|
||||
is_exempt = request.path.startswith("/api/system-features")
|
||||
|
||||
if not is_exempt:
|
||||
try:
|
||||
# Check license status with caching (10 min TTL)
|
||||
license_status = EnterpriseService.get_cached_license_status()
|
||||
if license_status in (LicenseStatus.INACTIVE, LicenseStatus.EXPIRED, LicenseStatus.LOST):
|
||||
# Cookie clearing is handled by register_external_error_handlers
|
||||
# in libs/external_api.py which detects the error code and calls
|
||||
# build_force_logout_cookie_headers(). Frontend then checks
|
||||
# code === 'unauthorized_and_force_logout' and calls location.reload().
|
||||
raise UnauthorizedAndForceLogout(
|
||||
f"Enterprise license is {license_status}. Please contact your administrator."
|
||||
)
|
||||
except UnauthorizedAndForceLogout:
|
||||
raise
|
||||
except Exception:
|
||||
# If license check fails, log but don't block the request.
|
||||
# This prevents service disruption if enterprise API is temporarily
|
||||
# unavailable.
|
||||
logger.exception("Failed to check enterprise license status")
|
||||
|
||||
# add after request hook for injecting trace headers from OpenTelemetry span context
|
||||
# Only adds headers when OTEL is enabled and has valid context
|
||||
@dify_app.after_request
|
||||
@@ -81,6 +143,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_compress,
|
||||
ext_database,
|
||||
ext_enterprise_telemetry,
|
||||
ext_fastopenapi,
|
||||
ext_forward_refs,
|
||||
ext_hosting_provider,
|
||||
@@ -131,6 +194,7 @@ def initialize_extensions(app: DifyApp):
|
||||
ext_commands,
|
||||
ext_fastopenapi,
|
||||
ext_otel,
|
||||
ext_enterprise_telemetry,
|
||||
ext_request_logging,
|
||||
ext_session_factory,
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ from extensions.ext_redis import redis_client
|
||||
from extensions.ext_storage import storage
|
||||
from extensions.storage.opendal_storage import OpenDALStorage
|
||||
from extensions.storage.storage_type import StorageType
|
||||
from libs.db_migration_lock import DbMigrationAutoRenewLock
|
||||
from libs.helper import email as email_validate
|
||||
from libs.password import hash_password, password_pattern, valid_password
|
||||
from libs.rsa import generate_key_pair
|
||||
@@ -54,6 +55,8 @@ from tasks.remove_app_and_related_data_task import delete_draft_variables_batch
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DB_UPGRADE_LOCK_TTL_SECONDS = 60
|
||||
|
||||
|
||||
@click.command("reset-password", help="Reset the account password.")
|
||||
@click.option("--email", prompt=True, help="Account email to reset password for")
|
||||
@@ -727,8 +730,15 @@ def create_tenant(email: str, language: str | None = None, name: str | None = No
|
||||
@click.command("upgrade-db", help="Upgrade the database")
|
||||
def upgrade_db():
|
||||
click.echo("Preparing database migration...")
|
||||
lock = redis_client.lock(name="db_upgrade_lock", timeout=60)
|
||||
lock = DbMigrationAutoRenewLock(
|
||||
redis_client=redis_client,
|
||||
name="db_upgrade_lock",
|
||||
ttl_seconds=DB_UPGRADE_LOCK_TTL_SECONDS,
|
||||
logger=logger,
|
||||
log_context="db_migration",
|
||||
)
|
||||
if lock.acquire(blocking=False):
|
||||
migration_succeeded = False
|
||||
try:
|
||||
click.echo(click.style("Starting database migration.", fg="green"))
|
||||
|
||||
@@ -737,6 +747,7 @@ def upgrade_db():
|
||||
|
||||
flask_migrate.upgrade()
|
||||
|
||||
migration_succeeded = True
|
||||
click.echo(click.style("Database migration successful!", fg="green"))
|
||||
|
||||
except Exception as e:
|
||||
@@ -744,7 +755,8 @@ def upgrade_db():
|
||||
click.echo(click.style(f"Database migration failed: {e}", fg="red"))
|
||||
raise SystemExit(1)
|
||||
finally:
|
||||
lock.release()
|
||||
status = "successful" if migration_succeeded else "failed"
|
||||
lock.release_safely(status=status)
|
||||
else:
|
||||
click.echo("Database migration skipped")
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, Settings
|
||||
from libs.file_utils import search_file_upwards
|
||||
|
||||
from .deploy import DeploymentConfig
|
||||
from .enterprise import EnterpriseFeatureConfig
|
||||
from .enterprise import EnterpriseFeatureConfig, EnterpriseTelemetryConfig
|
||||
from .extra import ExtraServiceConfig
|
||||
from .feature import FeatureConfig
|
||||
from .middleware import MiddlewareConfig
|
||||
@@ -73,6 +73,8 @@ class DifyConfig(
|
||||
# Enterprise feature configs
|
||||
# **Before using, please contact business@dify.ai by email to inquire about licensing matters.**
|
||||
EnterpriseFeatureConfig,
|
||||
# Enterprise telemetry configs
|
||||
EnterpriseTelemetryConfig,
|
||||
):
|
||||
model_config = SettingsConfigDict(
|
||||
# read from dotenv format config file
|
||||
|
||||
@@ -18,3 +18,49 @@ class EnterpriseFeatureConfig(BaseSettings):
|
||||
description="Allow customization of the enterprise logo.",
|
||||
default=False,
|
||||
)
|
||||
|
||||
|
||||
class EnterpriseTelemetryConfig(BaseSettings):
|
||||
"""
|
||||
Configuration for enterprise telemetry.
|
||||
"""
|
||||
|
||||
ENTERPRISE_TELEMETRY_ENABLED: bool = Field(
|
||||
description="Enable enterprise telemetry collection (also requires ENTERPRISE_ENABLED=true).",
|
||||
default=False,
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_ENDPOINT: str = Field(
|
||||
description="Enterprise OTEL collector endpoint.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_HEADERS: str = Field(
|
||||
description="Auth headers for OTLP export (key=value,key2=value2).",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_PROTOCOL: str = Field(
|
||||
description="OTLP protocol: 'http' or 'grpc' (default: http).",
|
||||
default="http",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTLP_API_KEY: str = Field(
|
||||
description="Bearer token for enterprise OTLP export authentication.",
|
||||
default="",
|
||||
)
|
||||
|
||||
ENTERPRISE_INCLUDE_CONTENT: bool = Field(
|
||||
description="Include input/output content in traces (privacy toggle).",
|
||||
default=True,
|
||||
)
|
||||
|
||||
ENTERPRISE_SERVICE_NAME: str = Field(
|
||||
description="Service name for OTEL resource.",
|
||||
default="dify",
|
||||
)
|
||||
|
||||
ENTERPRISE_OTEL_SAMPLING_RATE: float = Field(
|
||||
description="Sampling rate for enterprise traces (0.0 to 1.0, default 1.0 = 100%).",
|
||||
default=1.0,
|
||||
)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum
|
||||
from typing import Literal
|
||||
|
||||
@@ -49,16 +48,6 @@ class SecurityConfig(BaseSettings):
|
||||
default=5,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS: PositiveInt = Field(
|
||||
description="Maximum number of web form submissions allowed per IP within the rate limit window",
|
||||
default=30,
|
||||
)
|
||||
|
||||
WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS: PositiveInt = Field(
|
||||
description="Time window in seconds for web form submission rate limiting",
|
||||
default=60,
|
||||
)
|
||||
|
||||
LOGIN_DISABLED: bool = Field(
|
||||
description="Whether to disable login checks",
|
||||
default=False,
|
||||
@@ -93,12 +82,6 @@ class AppExecutionConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
HUMAN_INPUT_GLOBAL_TIMEOUT_SECONDS: PositiveInt = Field(
|
||||
description="Maximum seconds a workflow run can stay paused waiting for human input before global timeout.",
|
||||
default=int(timedelta(days=7).total_seconds()),
|
||||
ge=1,
|
||||
)
|
||||
|
||||
|
||||
class CodeExecutionSandboxConfig(BaseSettings):
|
||||
"""
|
||||
@@ -1151,14 +1134,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
description="Enable queue monitor task",
|
||||
default=False,
|
||||
)
|
||||
ENABLE_HUMAN_INPUT_TIMEOUT_TASK: bool = Field(
|
||||
description="Enable human input timeout check task",
|
||||
default=True,
|
||||
)
|
||||
HUMAN_INPUT_TIMEOUT_TASK_INTERVAL: PositiveInt = Field(
|
||||
description="Human input timeout check interval in minutes",
|
||||
default=1,
|
||||
)
|
||||
ENABLE_CHECK_UPGRADABLE_PLUGIN_TASK: bool = Field(
|
||||
description="Enable check upgradable plugin task",
|
||||
default=True,
|
||||
@@ -1180,16 +1155,6 @@ class CeleryScheduleTasksConfig(BaseSettings):
|
||||
default=0,
|
||||
)
|
||||
|
||||
# API token last_used_at batch update
|
||||
ENABLE_API_TOKEN_LAST_USED_UPDATE_TASK: bool = Field(
|
||||
description="Enable periodic batch update of API token last_used_at timestamps",
|
||||
default=True,
|
||||
)
|
||||
API_TOKEN_LAST_USED_UPDATE_INTERVAL: int = Field(
|
||||
description="Interval in minutes for batch updating API token last_used_at (default 30)",
|
||||
default=30,
|
||||
)
|
||||
|
||||
# Trigger provider refresh (simple version)
|
||||
ENABLE_TRIGGER_PROVIDER_REFRESH_TASK: bool = Field(
|
||||
description="Enable trigger provider refresh poller",
|
||||
@@ -1344,10 +1309,6 @@ class SandboxExpiredRecordsCleanConfig(BaseSettings):
|
||||
description="Maximum number of records to process in each batch",
|
||||
default=1000,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_CLEAN_BATCH_MAX_INTERVAL: PositiveInt = Field(
|
||||
description="Maximum interval in milliseconds between batches",
|
||||
default=200,
|
||||
)
|
||||
SANDBOX_EXPIRED_RECORDS_RETENTION_DAYS: PositiveInt = Field(
|
||||
description="Retention days for sandbox expired workflow_run records and message records",
|
||||
default=30,
|
||||
|
||||
@@ -6,7 +6,6 @@ from pydantic import Field, NonNegativeFloat, NonNegativeInt, PositiveFloat, Pos
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
from .cache.redis_config import RedisConfig
|
||||
from .cache.redis_pubsub_config import RedisPubSubConfig
|
||||
from .storage.aliyun_oss_storage_config import AliyunOSSStorageConfig
|
||||
from .storage.amazon_s3_storage_config import S3StorageConfig
|
||||
from .storage.azure_blob_storage_config import AzureBlobStorageConfig
|
||||
@@ -259,20 +258,11 @@ class CeleryConfig(DatabaseConfig):
|
||||
description="Password of the Redis Sentinel master.",
|
||||
default=None,
|
||||
)
|
||||
|
||||
CELERY_SENTINEL_SOCKET_TIMEOUT: PositiveFloat | None = Field(
|
||||
description="Timeout for Redis Sentinel socket operations in seconds.",
|
||||
default=0.1,
|
||||
)
|
||||
|
||||
CELERY_TASK_ANNOTATIONS: dict[str, Any] | None = Field(
|
||||
description=(
|
||||
"Annotations for Celery tasks as a JSON mapping of task name -> options "
|
||||
"(for example, rate limits or other task-specific settings)."
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
@computed_field
|
||||
def CELERY_RESULT_BACKEND(self) -> str | None:
|
||||
if self.CELERY_BACKEND in ("database", "rabbitmq"):
|
||||
@@ -327,7 +317,6 @@ class MiddlewareConfig(
|
||||
CeleryConfig, # Note: CeleryConfig already inherits from DatabaseConfig
|
||||
KeywordStoreConfig,
|
||||
RedisConfig,
|
||||
RedisPubSubConfig,
|
||||
# configs of storage and storage providers
|
||||
StorageConfig,
|
||||
AliyunOSSStorageConfig,
|
||||
|
||||
@@ -1,96 +0,0 @@
|
||||
from typing import Literal, Protocol
|
||||
from urllib.parse import quote_plus, urlunparse
|
||||
|
||||
from pydantic import Field
|
||||
from pydantic_settings import BaseSettings
|
||||
|
||||
|
||||
class RedisConfigDefaults(Protocol):
|
||||
REDIS_HOST: str
|
||||
REDIS_PORT: int
|
||||
REDIS_USERNAME: str | None
|
||||
REDIS_PASSWORD: str | None
|
||||
REDIS_DB: int
|
||||
REDIS_USE_SSL: bool
|
||||
REDIS_USE_SENTINEL: bool | None
|
||||
REDIS_USE_CLUSTERS: bool
|
||||
|
||||
|
||||
class RedisConfigDefaultsMixin:
|
||||
def _redis_defaults(self: RedisConfigDefaults) -> RedisConfigDefaults:
|
||||
return self
|
||||
|
||||
|
||||
class RedisPubSubConfig(BaseSettings, RedisConfigDefaultsMixin):
|
||||
"""
|
||||
Configuration settings for Redis pub/sub streaming.
|
||||
"""
|
||||
|
||||
PUBSUB_REDIS_URL: str | None = Field(
|
||||
alias="PUBSUB_REDIS_URL",
|
||||
description=(
|
||||
"Redis connection URL for pub/sub streaming events between API "
|
||||
"and celery worker, defaults to url constructed from "
|
||||
"`REDIS_*` configurations"
|
||||
),
|
||||
default=None,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_USE_CLUSTERS: bool = Field(
|
||||
description=(
|
||||
"Enable Redis Cluster mode for pub/sub streaming. It's highly "
|
||||
"recommended to enable this for large deployments."
|
||||
),
|
||||
default=False,
|
||||
)
|
||||
|
||||
PUBSUB_REDIS_CHANNEL_TYPE: Literal["pubsub", "sharded"] = Field(
|
||||
description=(
|
||||
"Pub/sub channel type for streaming events. "
|
||||
"Valid options are:\n"
|
||||
"\n"
|
||||
" - pubsub: for normal Pub/Sub\n"
|
||||
" - sharded: for sharded Pub/Sub\n"
|
||||
"\n"
|
||||
"It's highly recommended to use sharded Pub/Sub AND redis cluster "
|
||||
"for large deployments."
|
||||
),
|
||||
default="pubsub",
|
||||
)
|
||||
|
||||
def _build_default_pubsub_url(self) -> str:
|
||||
defaults = self._redis_defaults()
|
||||
if not defaults.REDIS_HOST or not defaults.REDIS_PORT:
|
||||
raise ValueError("PUBSUB_REDIS_URL must be set when default Redis URL cannot be constructed")
|
||||
|
||||
scheme = "rediss" if defaults.REDIS_USE_SSL else "redis"
|
||||
username = defaults.REDIS_USERNAME or None
|
||||
password = defaults.REDIS_PASSWORD or None
|
||||
|
||||
userinfo = ""
|
||||
if username:
|
||||
userinfo = quote_plus(username)
|
||||
if password:
|
||||
password_part = quote_plus(password)
|
||||
userinfo = f"{userinfo}:{password_part}" if userinfo else f":{password_part}"
|
||||
if userinfo:
|
||||
userinfo = f"{userinfo}@"
|
||||
|
||||
host = defaults.REDIS_HOST
|
||||
port = defaults.REDIS_PORT
|
||||
db = defaults.REDIS_DB
|
||||
|
||||
netloc = f"{userinfo}{host}:{port}"
|
||||
return urlunparse((scheme, netloc, f"/{db}", "", "", ""))
|
||||
|
||||
@property
|
||||
def normalized_pubsub_redis_url(self) -> str:
|
||||
pubsub_redis_url = self.PUBSUB_REDIS_URL
|
||||
if pubsub_redis_url:
|
||||
cleaned = pubsub_redis_url.strip()
|
||||
pubsub_redis_url = cleaned or None
|
||||
|
||||
if pubsub_redis_url:
|
||||
return pubsub_redis_url
|
||||
|
||||
return self._build_default_pubsub_url()
|
||||
@@ -21,7 +21,6 @@ language_timezone_mapping = {
|
||||
"th-TH": "Asia/Bangkok",
|
||||
"id-ID": "Asia/Jakarta",
|
||||
"ar-TN": "Africa/Tunis",
|
||||
"nl-NL": "Europe/Amsterdam",
|
||||
}
|
||||
|
||||
languages = list(language_timezone_mapping.keys())
|
||||
|
||||
@@ -5,6 +5,8 @@ from enum import StrEnum
|
||||
from flask_restx import Namespace
|
||||
from pydantic import BaseModel, TypeAdapter
|
||||
|
||||
from controllers.console import console_ns
|
||||
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
@@ -22,9 +24,6 @@ def register_schema_models(namespace: Namespace, *models: type[BaseModel]) -> No
|
||||
|
||||
|
||||
def get_or_create_model(model_name: str, field_def):
|
||||
# Import lazily to avoid circular imports between console controllers and schema helpers.
|
||||
from controllers.console import console_ns
|
||||
|
||||
existing = console_ns.models.get(model_name)
|
||||
if existing is None:
|
||||
existing = console_ns.model(model_name, field_def)
|
||||
|
||||
@@ -37,7 +37,6 @@ from . import (
|
||||
apikey,
|
||||
extension,
|
||||
feature,
|
||||
human_input_form,
|
||||
init_validate,
|
||||
ping,
|
||||
setup,
|
||||
@@ -172,7 +171,6 @@ __all__ = [
|
||||
"forgot_password",
|
||||
"generator",
|
||||
"hit_testing",
|
||||
"human_input_form",
|
||||
"init_validate",
|
||||
"installed_app",
|
||||
"load_balancing_config",
|
||||
|
||||
@@ -10,7 +10,6 @@ from libs.helper import TimestampField
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models.dataset import Dataset
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache
|
||||
|
||||
from . import console_ns
|
||||
from .wraps import account_initialization_required, edit_permission_required, setup_required
|
||||
@@ -132,11 +131,6 @@ class BaseApiKeyResource(Resource):
|
||||
if key is None:
|
||||
flask_restx.abort(HTTPStatus.NOT_FOUND, message="API key not found")
|
||||
|
||||
# Invalidate cache before deleting from database
|
||||
# Type assertion: key is guaranteed to be non-None here because abort() raises
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -660,6 +660,19 @@ class AppCopyApi(Resource):
|
||||
)
|
||||
session.commit()
|
||||
|
||||
# Inherit web app permission from original app
|
||||
if result.app_id and FeatureService.get_system_features().webapp_auth.enabled:
|
||||
try:
|
||||
# Get the original app's access mode
|
||||
original_settings = EnterpriseService.WebAppAuth.get_app_access_mode_by_id(app_model.id)
|
||||
access_mode = original_settings.access_mode
|
||||
except Exception:
|
||||
# If original app has no settings (old app), default to public to match fallback behavior
|
||||
access_mode = "public"
|
||||
|
||||
# Apply the same access mode to the copied app
|
||||
EnterpriseService.WebAppAuth.update_app_access_mode(result.app_id, access_mode)
|
||||
|
||||
stmt = select(App).where(App.id == result.app_id)
|
||||
app = session.scalar(stmt)
|
||||
|
||||
|
||||
@@ -89,7 +89,6 @@ status_count_model = console_ns.model(
|
||||
"success": fields.Integer,
|
||||
"failed": fields.Integer,
|
||||
"partial_success": fields.Integer,
|
||||
"paused": fields.Integer,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -599,12 +598,7 @@ def _get_conversation(app_model, conversation_id):
|
||||
db.session.execute(
|
||||
sa.update(Conversation)
|
||||
.where(Conversation.id == conversation_id, Conversation.read_at.is_(None))
|
||||
# Keep updated_at unchanged when only marking a conversation as read.
|
||||
.values(
|
||||
read_at=naive_utc_now(),
|
||||
read_account_id=current_user.id,
|
||||
updated_at=Conversation.updated_at,
|
||||
)
|
||||
.values(read_at=naive_utc_now(), read_account_id=current_user.id)
|
||||
)
|
||||
db.session.commit()
|
||||
db.session.refresh(conversation)
|
||||
|
||||
@@ -35,6 +35,7 @@ class InstructionGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Instruction for generation")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
ideal_output: str = Field(default="", description="Expected ideal output")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class InstructionTemplatePayload(BaseModel):
|
||||
@@ -66,10 +67,17 @@ class RuleGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
rules = LLMGenerator.generate_rule_config(tenant_id=current_tenant_id, args=args)
|
||||
rules = LLMGenerator.generate_rule_config(
|
||||
tenant_id=current_tenant_id,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=args.no_variable,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
except QuotaExceededError:
|
||||
@@ -95,12 +103,16 @@ class RuleCodeGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleCodeGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
code_result = LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
args=args,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.code_language,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -127,12 +139,15 @@ class RuleStructuredOutputGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = RuleStructuredOutputPayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
try:
|
||||
structured_output = LLMGenerator.generate_structured_output(
|
||||
tenant_id=current_tenant_id,
|
||||
args=args,
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
user_id=account.id,
|
||||
app_id=args.app_id,
|
||||
)
|
||||
except ProviderTokenNotInitError as ex:
|
||||
raise ProviderNotInitializeError(ex.description)
|
||||
@@ -159,14 +174,14 @@ class InstructionGenerateApi(Resource):
|
||||
@account_initialization_required
|
||||
def post(self):
|
||||
args = InstructionGeneratePayload.model_validate(console_ns.payload)
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
account, current_tenant_id = current_account_with_tenant()
|
||||
app_id = args.app_id or args.flow_id
|
||||
providers: list[type[CodeNodeProvider]] = [Python3CodeProvider, JavascriptCodeProvider]
|
||||
code_provider: type[CodeNodeProvider] | None = next(
|
||||
(p for p in providers if p.is_accept_language(args.language)), None
|
||||
)
|
||||
code_template = code_provider.get_default_code() if code_provider else ""
|
||||
try:
|
||||
# Generate from nothing for a workflow node
|
||||
if (args.current in (code_template, "")) and args.node_id != "":
|
||||
app = db.session.query(App).where(App.id == args.flow_id).first()
|
||||
if not app:
|
||||
@@ -183,33 +198,33 @@ class InstructionGenerateApi(Resource):
|
||||
case "llm":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case "agent":
|
||||
return LLMGenerator.generate_rule_config(
|
||||
current_tenant_id,
|
||||
args=RuleGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
no_variable=True,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case "code":
|
||||
return LLMGenerator.generate_code(
|
||||
tenant_id=current_tenant_id,
|
||||
args=RuleCodeGeneratePayload(
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
),
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
code_language=args.language,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
case _:
|
||||
return {"error": f"invalid node type: {node_type}"}
|
||||
if args.node_id == "" and args.current != "": # For legacy app without a workflow
|
||||
if args.node_id == "" and args.current != "":
|
||||
return LLMGenerator.instruction_modify_legacy(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -217,8 +232,10 @@ class InstructionGenerateApi(Resource):
|
||||
instruction=args.instruction,
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
if args.node_id != "" and args.current != "": # For workflow node
|
||||
if args.node_id != "" and args.current != "":
|
||||
return LLMGenerator.instruction_modify_workflow(
|
||||
tenant_id=current_tenant_id,
|
||||
flow_id=args.flow_id,
|
||||
@@ -228,6 +245,8 @@ class InstructionGenerateApi(Resource):
|
||||
model_config=args.model_config_data,
|
||||
ideal_output=args.ideal_output,
|
||||
workflow_service=WorkflowService(),
|
||||
user_id=account.id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return {"error": "incompatible parameters"}, 400
|
||||
except ProviderTokenNotInitError as ex:
|
||||
|
||||
@@ -33,7 +33,7 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models.model import AppMode, Conversation, Message, MessageAnnotation, MessageFeedback
|
||||
from services.errors.conversation import ConversationNotExistsError
|
||||
from services.errors.message import MessageNotExistsError, SuggestedQuestionsAfterAnswerDisabledError
|
||||
from services.message_service import MessageService, attach_message_extra_contents
|
||||
from services.message_service import MessageService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -207,7 +207,6 @@ message_detail_model = console_ns.model(
|
||||
"created_at": TimestampField,
|
||||
"agent_thoughts": fields.List(fields.Nested(agent_thought_model)),
|
||||
"message_files": fields.List(fields.Nested(message_file_model)),
|
||||
"extra_contents": fields.List(fields.Raw),
|
||||
"metadata": fields.Raw(attribute="message_metadata_dict"),
|
||||
"status": fields.String,
|
||||
"error": fields.String,
|
||||
@@ -300,7 +299,6 @@ class ChatMessageListApi(Resource):
|
||||
has_more = False
|
||||
|
||||
history_messages = list(reversed(history_messages))
|
||||
attach_message_extra_contents(history_messages)
|
||||
|
||||
return InfiniteScrollPagination(data=history_messages, limit=args.limit, has_more=has_more)
|
||||
|
||||
@@ -483,5 +481,4 @@ class MessageApi(Resource):
|
||||
if not message:
|
||||
raise NotFound("Message Not Exists.")
|
||||
|
||||
attach_message_extra_contents([message])
|
||||
return message
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
from typing import Any
|
||||
|
||||
from flask import request
|
||||
from flask_login import current_user
|
||||
from flask_restx import Resource, fields
|
||||
from pydantic import BaseModel, Field
|
||||
from werkzeug.exceptions import BadRequest
|
||||
@@ -77,7 +78,10 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.create_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigIsExist()
|
||||
@@ -102,7 +106,10 @@ class TraceAppConfigApi(Resource):
|
||||
|
||||
try:
|
||||
result = OpsService.update_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, tracing_config=args.tracing_config
|
||||
app_id=app_id,
|
||||
tracing_provider=args.tracing_provider,
|
||||
tracing_config=args.tracing_config,
|
||||
account_id=current_user.id,
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
@@ -124,7 +131,9 @@ class TraceAppConfigApi(Resource):
|
||||
args = TraceProviderQuery.model_validate(request.args.to_dict(flat=True)) # type: ignore
|
||||
|
||||
try:
|
||||
result = OpsService.delete_tracing_app_config(app_id=app_id, tracing_provider=args.tracing_provider)
|
||||
result = OpsService.delete_tracing_app_config(
|
||||
app_id=app_id, tracing_provider=args.tracing_provider, account_id=current_user.id
|
||||
)
|
||||
if not result:
|
||||
raise TracingConfigNotExist()
|
||||
return {"result": "success"}, 204
|
||||
|
||||
@@ -507,179 +507,6 @@ class WorkflowDraftRunLoopNodeApi(Resource):
|
||||
raise InternalServerError()
|
||||
|
||||
|
||||
class HumanInputFormPreviewPayload(BaseModel):
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
class HumanInputFormSubmitPayload(BaseModel):
|
||||
form_inputs: dict[str, Any] = Field(..., description="Values the user provides for the form's own fields")
|
||||
inputs: dict[str, Any] = Field(
|
||||
...,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
action: str = Field(..., description="Selected action ID")
|
||||
|
||||
|
||||
class HumanInputDeliveryTestPayload(BaseModel):
|
||||
delivery_method_id: str = Field(..., description="Delivery method ID")
|
||||
inputs: dict[str, Any] = Field(
|
||||
default_factory=dict,
|
||||
description="Values used to fill missing upstream variables referenced in form_content",
|
||||
)
|
||||
|
||||
|
||||
reg(HumanInputFormPreviewPayload)
|
||||
reg(HumanInputFormSubmitPayload)
|
||||
reg(HumanInputDeliveryTestPayload)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class AdvancedChatDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/advanced-chat/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class AdvancedChatDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_advanced_chat_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for advanced chat workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service = WorkflowService()
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/preview")
|
||||
class WorkflowDraftHumanInputFormPreviewApi(Resource):
|
||||
@console_ns.doc("get_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Get human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormPreviewPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Preview human input form content and placeholders
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
args = HumanInputFormPreviewPayload.model_validate(console_ns.payload or {})
|
||||
inputs = args.inputs
|
||||
|
||||
workflow_service = WorkflowService()
|
||||
preview = workflow_service.get_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
inputs=inputs,
|
||||
)
|
||||
return jsonable_encoder(preview)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/form/run")
|
||||
class WorkflowDraftHumanInputFormRunApi(Resource):
|
||||
@console_ns.doc("submit_workflow_draft_human_input_form")
|
||||
@console_ns.doc(description="Submit human input form preview for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputFormSubmitPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Submit human input form preview
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputFormSubmitPayload.model_validate(console_ns.payload or {})
|
||||
result = workflow_service.submit_human_input_form_preview(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
form_inputs=args.form_inputs,
|
||||
inputs=args.inputs,
|
||||
action=args.action,
|
||||
)
|
||||
return jsonable_encoder(result)
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/human-input/nodes/<string:node_id>/delivery-test")
|
||||
class WorkflowDraftHumanInputDeliveryTestApi(Resource):
|
||||
@console_ns.doc("test_workflow_draft_human_input_delivery")
|
||||
@console_ns.doc(description="Test human input delivery for workflow")
|
||||
@console_ns.doc(params={"app_id": "Application ID", "node_id": "Node ID"})
|
||||
@console_ns.expect(console_ns.models[HumanInputDeliveryTestPayload.__name__])
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
@get_app_model(mode=[AppMode.WORKFLOW, AppMode.ADVANCED_CHAT])
|
||||
@edit_permission_required
|
||||
def post(self, app_model: App, node_id: str):
|
||||
"""
|
||||
Test human input delivery
|
||||
"""
|
||||
current_user, _ = current_account_with_tenant()
|
||||
workflow_service = WorkflowService()
|
||||
args = HumanInputDeliveryTestPayload.model_validate(console_ns.payload or {})
|
||||
workflow_service.test_human_input_delivery(
|
||||
app_model=app_model,
|
||||
account=current_user,
|
||||
node_id=node_id,
|
||||
delivery_method_id=args.delivery_method_id,
|
||||
inputs=args.inputs,
|
||||
)
|
||||
return jsonable_encoder({})
|
||||
|
||||
|
||||
@console_ns.route("/apps/<uuid:app_id>/workflows/draft/run")
|
||||
class DraftWorkflowRunApi(Resource):
|
||||
@console_ns.doc("run_draft_workflow")
|
||||
|
||||
@@ -5,15 +5,10 @@ from flask import request
|
||||
from flask_restx import Resource, fields, marshal_with
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.app.wraps import get_app_model
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import NotFoundError
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from extensions.ext_database import db
|
||||
from fields.end_user_fields import simple_end_user_fields
|
||||
from fields.member_fields import simple_account_fields
|
||||
@@ -32,21 +27,9 @@ from libs.custom_inputs import time_duration
|
||||
from libs.helper import uuid_value
|
||||
from libs.login import current_user, login_required
|
||||
from models import Account, App, AppMode, EndUser, WorkflowArchiveLog, WorkflowRunTriggeredFrom
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.retention.workflow_run.constants import ARCHIVE_BUNDLE_NAME
|
||||
from services.workflow_run_service import WorkflowRunService
|
||||
|
||||
|
||||
def _build_backstage_input_url(form_token: str | None) -> str | None:
|
||||
if not form_token:
|
||||
return None
|
||||
base_url = dify_config.APP_WEB_URL
|
||||
if not base_url:
|
||||
return None
|
||||
return f"{base_url.rstrip('/')}/form/{form_token}"
|
||||
|
||||
|
||||
# Workflow run status choices for filtering
|
||||
WORKFLOW_RUN_STATUS_CHOICES = ["running", "succeeded", "failed", "stopped", "partial-succeeded"]
|
||||
EXPORT_SIGNED_URL_EXPIRE_SECONDS = 3600
|
||||
@@ -457,68 +440,3 @@ class WorkflowRunNodeExecutionListApi(Resource):
|
||||
)
|
||||
|
||||
return {"data": node_executions}
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/pause-details")
|
||||
class ConsoleWorkflowPauseDetailsApi(Resource):
|
||||
"""Console API for getting workflow pause details."""
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow pause details.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/pause-details
|
||||
|
||||
Returns information about why and where the workflow is paused.
|
||||
"""
|
||||
|
||||
# Query WorkflowRun to determine if workflow is suspended
|
||||
session_maker = sessionmaker(bind=db.engine)
|
||||
workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker=session_maker)
|
||||
|
||||
workflow_run = db.session.get(WorkflowRun, workflow_run_id)
|
||||
if not workflow_run:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
if workflow_run.tenant_id != current_user.current_tenant_id:
|
||||
raise NotFoundError("Workflow run not found")
|
||||
|
||||
# Check if workflow is suspended
|
||||
is_paused = workflow_run.status == WorkflowExecutionStatus.PAUSED
|
||||
if not is_paused:
|
||||
return {
|
||||
"paused_at": None,
|
||||
"paused_nodes": [],
|
||||
}, 200
|
||||
|
||||
pause_entity = workflow_run_repo.get_workflow_pause(workflow_run_id)
|
||||
pause_reasons = pause_entity.get_pause_reasons() if pause_entity else []
|
||||
|
||||
# Build response
|
||||
paused_at = pause_entity.paused_at if pause_entity else None
|
||||
paused_nodes = []
|
||||
response = {
|
||||
"paused_at": paused_at.isoformat() + "Z" if paused_at else None,
|
||||
"paused_nodes": paused_nodes,
|
||||
}
|
||||
|
||||
for reason in pause_reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
paused_nodes.append(
|
||||
{
|
||||
"node_id": reason.node_id,
|
||||
"node_title": reason.node_title,
|
||||
"pause_type": {
|
||||
"type": "human_input",
|
||||
"form_id": reason.form_id,
|
||||
"backstage_input_url": _build_backstage_input_url(reason.form_token),
|
||||
},
|
||||
}
|
||||
)
|
||||
else:
|
||||
raise AssertionError("unimplemented.")
|
||||
|
||||
return response, 200
|
||||
|
||||
@@ -55,7 +55,6 @@ from libs.login import current_account_with_tenant, login_required
|
||||
from models import ApiToken, Dataset, Document, DocumentSegment, UploadFile
|
||||
from models.dataset import DatasetPermissionEnum
|
||||
from models.provider_ids import ModelProviderID
|
||||
from services.api_token_service import ApiTokenCache
|
||||
from services.dataset_service import DatasetPermissionService, DatasetService, DocumentService
|
||||
|
||||
# Register models for flask_restx to avoid dict type issues in Swagger
|
||||
@@ -821,11 +820,6 @@ class DatasetApiDeleteApi(Resource):
|
||||
if key is None:
|
||||
console_ns.abort(404, message="API key not found")
|
||||
|
||||
# Invalidate cache before deleting from database
|
||||
# Type assertion: key is guaranteed to be non-None here because abort() raises
|
||||
assert key is not None # nosec - for type checker only
|
||||
ApiTokenCache.delete(key.token, key.type)
|
||||
|
||||
db.session.query(ApiToken).where(ApiToken.id == api_key_id).delete()
|
||||
db.session.commit()
|
||||
|
||||
|
||||
@@ -1,16 +1,15 @@
|
||||
import logging
|
||||
from typing import Any, Literal, cast
|
||||
from typing import Any, cast
|
||||
|
||||
from flask import request
|
||||
from flask_restx import Resource, fields, marshal, marshal_with
|
||||
from pydantic import BaseModel
|
||||
from flask_restx import Resource, fields, marshal, marshal_with, reqparse
|
||||
from werkzeug.exceptions import Forbidden, InternalServerError, NotFound
|
||||
|
||||
import services
|
||||
from controllers.common.fields import Parameters as ParametersResponse
|
||||
from controllers.common.fields import Site as SiteResponse
|
||||
from controllers.common.schema import get_or_create_model
|
||||
from controllers.console import api, console_ns
|
||||
from controllers.console import api
|
||||
from controllers.console.app.error import (
|
||||
AppUnavailableError,
|
||||
AudioTooLargeError,
|
||||
@@ -118,56 +117,7 @@ workflow_fields_copy["rag_pipeline_variables"] = fields.List(fields.Nested(pipel
|
||||
workflow_model = get_or_create_model("TrialWorkflow", workflow_fields_copy)
|
||||
|
||||
|
||||
# Pydantic models for request validation
|
||||
DEFAULT_REF_TEMPLATE_SWAGGER_2_0 = "#/definitions/{model}"
|
||||
|
||||
|
||||
class WorkflowRunRequest(BaseModel):
|
||||
inputs: dict
|
||||
files: list | None = None
|
||||
|
||||
|
||||
class ChatRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str
|
||||
files: list | None = None
|
||||
conversation_id: str | None = None
|
||||
parent_message_id: str | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
class TextToSpeechRequest(BaseModel):
|
||||
message_id: str | None = None
|
||||
voice: str | None = None
|
||||
text: str | None = None
|
||||
streaming: bool | None = None
|
||||
|
||||
|
||||
class CompletionRequest(BaseModel):
|
||||
inputs: dict
|
||||
query: str = ""
|
||||
files: list | None = None
|
||||
response_mode: Literal["blocking", "streaming"] | None = None
|
||||
retriever_from: str = "explore_app"
|
||||
|
||||
|
||||
# Register schemas for Swagger documentation
|
||||
console_ns.schema_model(
|
||||
WorkflowRunRequest.__name__, WorkflowRunRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
ChatRequest.__name__, ChatRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
TextToSpeechRequest.__name__, TextToSpeechRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
console_ns.schema_model(
|
||||
CompletionRequest.__name__, CompletionRequest.model_json_schema(ref_template=DEFAULT_REF_TEMPLATE_SWAGGER_2_0)
|
||||
)
|
||||
|
||||
|
||||
class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[WorkflowRunRequest.__name__])
|
||||
def post(self, trial_app):
|
||||
"""
|
||||
Run workflow
|
||||
@@ -179,8 +129,10 @@ class TrialAppWorkflowRunApi(TrialAppResource):
|
||||
if app_mode != AppMode.WORKFLOW:
|
||||
raise NotWorkflowAppError()
|
||||
|
||||
request_data = WorkflowRunRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, nullable=False, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
args = parser.parse_args()
|
||||
assert current_user is not None
|
||||
try:
|
||||
app_id = app_model.id
|
||||
@@ -231,7 +183,6 @@ class TrialAppWorkflowTaskStopApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[ChatRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
@@ -239,14 +190,14 @@ class TrialChatApi(TrialAppResource):
|
||||
if app_mode not in {AppMode.CHAT, AppMode.AGENT_CHAT, AppMode.ADVANCED_CHAT}:
|
||||
raise NotChatAppError()
|
||||
|
||||
request_data = ChatRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
|
||||
# Validate UUID values if provided
|
||||
if args.get("conversation_id"):
|
||||
args["conversation_id"] = uuid_value(args["conversation_id"])
|
||||
if args.get("parent_message_id"):
|
||||
args["parent_message_id"] = uuid_value(args["parent_message_id"])
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, required=True, location="json")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("conversation_id", type=uuid_value, location="json")
|
||||
parser.add_argument("parent_message_id", type=uuid_value, required=False, location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -369,16 +320,20 @@ class TrialChatAudioApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialChatTextApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[TextToSpeechRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
try:
|
||||
request_data = TextToSpeechRequest.model_validate(console_ns.payload)
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("message_id", type=str, required=False, location="json")
|
||||
parser.add_argument("voice", type=str, location="json")
|
||||
parser.add_argument("text", type=str, location="json")
|
||||
parser.add_argument("streaming", type=bool, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
message_id = request_data.message_id
|
||||
text = request_data.text
|
||||
voice = request_data.voice
|
||||
message_id = args.get("message_id", None)
|
||||
text = args.get("text", None)
|
||||
voice = args.get("voice", None)
|
||||
if not isinstance(current_user, Account):
|
||||
raise ValueError("current_user must be an Account instance")
|
||||
|
||||
@@ -416,15 +371,19 @@ class TrialChatTextApi(TrialAppResource):
|
||||
|
||||
|
||||
class TrialCompletionApi(TrialAppResource):
|
||||
@console_ns.expect(console_ns.models[CompletionRequest.__name__])
|
||||
@trial_feature_enable
|
||||
def post(self, trial_app):
|
||||
app_model = trial_app
|
||||
if app_model.mode != "completion":
|
||||
raise NotCompletionAppError()
|
||||
|
||||
request_data = CompletionRequest.model_validate(console_ns.payload)
|
||||
args = request_data.model_dump()
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("query", type=str, location="json", default="")
|
||||
parser.add_argument("files", type=list, required=False, location="json")
|
||||
parser.add_argument("response_mode", type=str, choices=["blocking", "streaming"], location="json")
|
||||
parser.add_argument("retriever_from", type=str, required=False, default="explore_app", location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
streaming = args["response_mode"] == "streaming"
|
||||
args["auto_generate_name"] = False
|
||||
|
||||
@@ -1,217 +0,0 @@
|
||||
"""
|
||||
Console/Studio Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, jsonify, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from controllers.console import console_ns
|
||||
from controllers.console.wraps import account_initialization_required, setup_required
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from models import App
|
||||
from models.enums import CreatorUserRole
|
||||
from models.human_input import RecipientType
|
||||
from models.model import AppMode
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.human_input_service import Form, HumanInputService
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form) -> Response:
|
||||
payload = form.get_definition().model_dump()
|
||||
payload["expiration_time"] = int(form.expiration_time.timestamp())
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@console_ns.route("/form/human_input/<string:form_token>")
|
||||
class ConsoleHumanInputFormApi(Resource):
|
||||
"""Console API for getting human input form definition."""
|
||||
|
||||
@staticmethod
|
||||
def _ensure_console_access(form: Form):
|
||||
_, current_tenant_id = current_account_with_tenant()
|
||||
|
||||
if form.tenant_id != current_tenant_id:
|
||||
raise NotFoundError("App not found")
|
||||
|
||||
@setup_required
|
||||
@login_required
|
||||
@account_initialization_required
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by form token.
|
||||
|
||||
GET /console/api/form/human_input/<form_token>
|
||||
"""
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_definition_by_token_for_console(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
return _jsonify_form_definition(form)
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by form token.
|
||||
|
||||
POST /console/api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
current_user, _ = current_account_with_tenant()
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
|
||||
self._ensure_console_access(form)
|
||||
|
||||
recipient_type = form.recipient_type
|
||||
if recipient_type not in {RecipientType.CONSOLE, RecipientType.BACKSTAGE}:
|
||||
raise NotFoundError(f"form not found, token={form_token}")
|
||||
# The type checker is not smart enought to validate the following invariant.
|
||||
# So we need to assert it manually.
|
||||
assert recipient_type is not None, "recipient_type cannot be None here."
|
||||
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_user_id=current_user.id,
|
||||
)
|
||||
|
||||
return jsonify({})
|
||||
|
||||
|
||||
@console_ns.route("/workflow/<string:workflow_run_id>/events")
|
||||
class ConsoleWorkflowEventsApi(Resource):
|
||||
"""Console API for getting workflow execution events after resume."""
|
||||
|
||||
@account_initialization_required
|
||||
@login_required
|
||||
def get(self, workflow_run_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /console/api/workflow/<workflow_run_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
|
||||
user, tenant_id = current_account_with_tenant()
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.ACCOUNT:
|
||||
raise NotFoundError(f"WorkflowRun not created by account, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current account, id={workflow_run_id}")
|
||||
|
||||
with Session(expire_on_commit=False, bind=db.engine) as session:
|
||||
app = _retrieve_app_for_workflow_run(session, workflow_run)
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
# TODO(QuantumGhost): should we modify the handling for finished workflow run here?
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
|
||||
else:
|
||||
msg_generator = MessageGenerator()
|
||||
if app.mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app.mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=AppMode(app.mode),
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=workflow_run.tenant_id,
|
||||
app_id=workflow_run.app_id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(AppMode(app.mode), workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
def _retrieve_app_for_workflow_run(session: Session, workflow_run: WorkflowRun):
|
||||
query = select(App).where(
|
||||
App.id == workflow_run.app_id,
|
||||
App.tenant_id == workflow_run.tenant_id,
|
||||
)
|
||||
app = session.scalars(query).first()
|
||||
if app is None:
|
||||
raise AssertionError(
|
||||
f"App not found for WorkflowRun, workflow_run_id={workflow_run.id}, "
|
||||
f"app_id={workflow_run.app_id}, tenant_id={workflow_run.tenant_id}"
|
||||
)
|
||||
|
||||
return app
|
||||
@@ -1,7 +1,6 @@
|
||||
import urllib.parse
|
||||
|
||||
import httpx
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
import services
|
||||
@@ -11,12 +10,12 @@ from controllers.common.errors import (
|
||||
RemoteFileUploadError,
|
||||
UnsupportedFileTypeError,
|
||||
)
|
||||
from controllers.console import console_ns
|
||||
from controllers.fastopenapi import console_router
|
||||
from core.file import helpers as file_helpers
|
||||
from core.helper import ssrf_proxy
|
||||
from extensions.ext_database import db
|
||||
from fields.file_fields import FileWithSignedUrl, RemoteFileInfo
|
||||
from libs.login import current_account_with_tenant, login_required
|
||||
from libs.login import current_account_with_tenant
|
||||
from services.file_service import FileService
|
||||
|
||||
|
||||
@@ -24,73 +23,69 @@ class RemoteFileUploadPayload(BaseModel):
|
||||
url: str = Field(..., description="URL to fetch")
|
||||
|
||||
|
||||
@console_ns.route("/remote-files/<path:url>")
|
||||
class GetRemoteFileInfo(Resource):
|
||||
@login_required
|
||||
def get(self, url: str):
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
@console_router.get(
|
||||
"/remote-files/<path:url>",
|
||||
response_model=RemoteFileInfo,
|
||||
tags=["console"],
|
||||
)
|
||||
def get_remote_file_info(url: str) -> RemoteFileInfo:
|
||||
decoded_url = urllib.parse.unquote(url)
|
||||
resp = ssrf_proxy.head(decoded_url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
file_length=int(resp.headers.get("Content-Length", 0)),
|
||||
)
|
||||
|
||||
|
||||
@console_router.post(
|
||||
"/remote-files/upload",
|
||||
response_model=FileWithSignedUrl,
|
||||
tags=["console"],
|
||||
status_code=201,
|
||||
)
|
||||
def upload_remote_file(payload: RemoteFileUploadPayload) -> FileWithSignedUrl:
|
||||
url = payload.url
|
||||
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(decoded_url, timeout=3)
|
||||
resp.raise_for_status()
|
||||
return RemoteFileInfo(
|
||||
file_type=resp.headers.get("Content-Type", "application/octet-stream"),
|
||||
file_length=int(resp.headers.get("Content-Length", 0)),
|
||||
).model_dump(mode="json")
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as e:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(resp)
|
||||
|
||||
@console_ns.route("/remote-files/upload")
|
||||
class RemoteFileUpload(Resource):
|
||||
@login_required
|
||||
def post(self):
|
||||
payload = RemoteFileUploadPayload.model_validate(console_ns.payload)
|
||||
url = payload.url
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError
|
||||
|
||||
# Try to fetch remote file metadata/content first
|
||||
try:
|
||||
resp = ssrf_proxy.head(url=url)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
resp = ssrf_proxy.get(url=url, timeout=3, follow_redirects=True)
|
||||
if resp.status_code != httpx.codes.OK:
|
||||
# Normalize into a user-friendly error message expected by tests
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {resp.text}")
|
||||
except httpx.RequestError as e:
|
||||
raise RemoteFileUploadError(f"Failed to fetch file from {url}: {str(e)}")
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
file_info = helpers.guess_file_info_from_response(resp)
|
||||
|
||||
# Enforce file size limit with 400 (Bad Request) per tests' expectation
|
||||
if not FileService.is_file_size_within_limit(extension=file_info.extension, file_size=file_info.size):
|
||||
raise FileTooLargeError()
|
||||
|
||||
# Load content if needed
|
||||
content = resp.content if resp.request.method == "GET" else ssrf_proxy.get(url).content
|
||||
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
# Success: return created resource with 201 status
|
||||
return (
|
||||
FileWithSignedUrl(
|
||||
id=upload_file.id,
|
||||
name=upload_file.name,
|
||||
size=upload_file.size,
|
||||
extension=upload_file.extension,
|
||||
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
mime_type=upload_file.mime_type,
|
||||
created_by=upload_file.created_by,
|
||||
created_at=int(upload_file.created_at.timestamp()),
|
||||
).model_dump(mode="json"),
|
||||
201,
|
||||
try:
|
||||
user, _ = current_account_with_tenant()
|
||||
upload_file = FileService(db.engine).upload_file(
|
||||
filename=file_info.filename,
|
||||
content=content,
|
||||
mimetype=file_info.mimetype,
|
||||
user=user,
|
||||
source_url=url,
|
||||
)
|
||||
except services.errors.file.FileTooLargeError as file_too_large_error:
|
||||
raise FileTooLargeError(file_too_large_error.description)
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return FileWithSignedUrl(
|
||||
id=upload_file.id,
|
||||
name=upload_file.name,
|
||||
size=upload_file.size,
|
||||
extension=upload_file.extension,
|
||||
url=file_helpers.get_signed_file_url(upload_file_id=upload_file.id),
|
||||
mime_type=upload_file.mime_type,
|
||||
created_by=upload_file.created_by,
|
||||
created_at=int(upload_file.created_at.timestamp()),
|
||||
)
|
||||
|
||||
@@ -42,15 +42,7 @@ class SetupResponse(BaseModel):
|
||||
tags=["console"],
|
||||
)
|
||||
def get_setup_status_api() -> SetupStatusResponse:
|
||||
"""Get system setup status.
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design.
|
||||
|
||||
During first-time bootstrap there is no admin account yet, so frontend initialization must be
|
||||
able to query setup progress before any login flow exists.
|
||||
|
||||
Only bootstrap-safe status information should be returned by this endpoint.
|
||||
"""
|
||||
"""Get system setup status."""
|
||||
if dify_config.EDITION == "SELF_HOSTED":
|
||||
setup_status = get_setup_status()
|
||||
if setup_status and not isinstance(setup_status, bool):
|
||||
@@ -69,12 +61,7 @@ def get_setup_status_api() -> SetupStatusResponse:
|
||||
)
|
||||
@only_edition_self_hosted
|
||||
def setup_system(payload: SetupRequestPayload) -> SetupResponse:
|
||||
"""Initialize system setup with admin account.
|
||||
|
||||
NOTE: This endpoint is unauthenticated by design for first-time bootstrap.
|
||||
Access is restricted by deployment mode (`SELF_HOSTED`), one-time setup guards,
|
||||
and init-password validation rather than user session authentication.
|
||||
"""
|
||||
"""Initialize system setup with admin account."""
|
||||
if get_setup_status():
|
||||
raise AlreadySetupError()
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class TagUpdateDeleteApi(Resource):
|
||||
|
||||
TagService.delete_tag(tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@console_ns.route("/tag-bindings/create")
|
||||
|
||||
@@ -878,7 +878,11 @@ class ToolBuiltinProviderSetDefaultApi(Resource):
|
||||
current_user, current_tenant_id = current_account_with_tenant()
|
||||
payload = BuiltinProviderDefaultCredentialPayload.model_validate(console_ns.payload or {})
|
||||
return BuiltinToolManageService.set_default_provider(
|
||||
tenant_id=current_tenant_id, user_id=current_user.id, provider=provider, id=payload.id
|
||||
tenant_id=current_tenant_id,
|
||||
user_id=current_user.id,
|
||||
provider=provider,
|
||||
id=payload.id,
|
||||
account=current_user,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -34,8 +34,6 @@ from .dataset import (
|
||||
metadata,
|
||||
segment,
|
||||
)
|
||||
from .dataset.rag_pipeline import rag_pipeline_workflow
|
||||
from .end_user import end_user
|
||||
from .workspace import models
|
||||
|
||||
__all__ = [
|
||||
@@ -46,7 +44,6 @@ __all__ = [
|
||||
"conversation",
|
||||
"dataset",
|
||||
"document",
|
||||
"end_user",
|
||||
"file",
|
||||
"file_preview",
|
||||
"hit_testing",
|
||||
@@ -54,7 +51,6 @@ __all__ = [
|
||||
"message",
|
||||
"metadata",
|
||||
"models",
|
||||
"rag_pipeline_workflow",
|
||||
"segment",
|
||||
"site",
|
||||
"workflow",
|
||||
|
||||
@@ -33,9 +33,8 @@ from core.workflow.graph_engine.manager import GraphEngineManager
|
||||
from extensions.ext_database import db
|
||||
from fields.workflow_app_log_fields import build_workflow_app_log_pagination_model
|
||||
from libs import helper
|
||||
from libs.helper import OptionalTimestampField, TimestampField
|
||||
from libs.helper import TimestampField
|
||||
from models.model import App, AppMode, EndUser
|
||||
from models.workflow import WorkflowRun
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.app_generate_service import AppGenerateService
|
||||
from services.errors.app import IsDraftWorkflowError, WorkflowIdFormatError, WorkflowNotFoundError
|
||||
@@ -64,32 +63,17 @@ class WorkflowLogQuery(BaseModel):
|
||||
|
||||
register_schema_models(service_api_ns, WorkflowRunPayload, WorkflowLogQuery)
|
||||
|
||||
|
||||
class WorkflowRunStatusField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
return obj.status.value
|
||||
|
||||
|
||||
class WorkflowRunOutputsField(fields.Raw):
|
||||
def output(self, key, obj: WorkflowRun, **kwargs):
|
||||
if obj.status == WorkflowExecutionStatus.PAUSED:
|
||||
return {}
|
||||
|
||||
outputs = obj.outputs_dict
|
||||
return outputs or {}
|
||||
|
||||
|
||||
workflow_run_fields = {
|
||||
"id": fields.String,
|
||||
"workflow_id": fields.String,
|
||||
"status": WorkflowRunStatusField,
|
||||
"status": fields.String,
|
||||
"inputs": fields.Raw,
|
||||
"outputs": WorkflowRunOutputsField,
|
||||
"outputs": fields.Raw,
|
||||
"error": fields.String,
|
||||
"total_steps": fields.Integer,
|
||||
"total_tokens": fields.Integer,
|
||||
"created_at": TimestampField,
|
||||
"finished_at": OptionalTimestampField,
|
||||
"finished_at": TimestampField,
|
||||
"elapsed_time": fields.Float,
|
||||
}
|
||||
|
||||
|
||||
@@ -396,7 +396,7 @@ class DatasetApi(DatasetApiResource):
|
||||
try:
|
||||
if DatasetService.delete_dataset(dataset_id_str, current_user):
|
||||
DatasetPermissionService.clear_partial_member_list(dataset_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
else:
|
||||
raise NotFound("Dataset not found.")
|
||||
except services.errors.dataset.DatasetInUseError:
|
||||
@@ -557,7 +557,7 @@ class DatasetTagsApi(DatasetApiResource):
|
||||
payload = TagDeletePayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag(payload.tag_id)
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/binding")
|
||||
@@ -581,7 +581,7 @@ class DatasetTagBindingApi(DatasetApiResource):
|
||||
payload = TagBindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.save_tag_binding({"tag_ids": payload.tag_ids, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/tags/unbinding")
|
||||
@@ -605,7 +605,7 @@ class DatasetTagUnbindingApi(DatasetApiResource):
|
||||
payload = TagUnbindingPayload.model_validate(service_api_ns.payload or {})
|
||||
TagService.delete_tag_binding({"tag_id": payload.tag_id, "target_id": payload.target_id, "type": "knowledge"})
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/tags")
|
||||
|
||||
@@ -746,4 +746,4 @@ class DocumentApi(DatasetApiResource):
|
||||
except services.errors.document.DocumentIndexingError:
|
||||
raise DocumentIndexingError("Cannot delete document during indexing.")
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@@ -128,7 +128,7 @@ class DatasetMetadataServiceApi(DatasetApiResource):
|
||||
DatasetService.check_dataset_permission(dataset, current_user)
|
||||
|
||||
MetadataService.delete_metadata(dataset_id_str, metadata_id_str)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/metadata/built-in")
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
import string
|
||||
import uuid
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
|
||||
@@ -10,7 +12,6 @@ from controllers.common.errors import FilenameNotExistsError, NoFileUploadedErro
|
||||
from controllers.common.schema import register_schema_model
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.dataset.error import PipelineRunError
|
||||
from controllers.service_api.dataset.rag_pipeline.serializers import serialize_upload_file
|
||||
from controllers.service_api.wraps import DatasetApiResource
|
||||
from core.app.apps.pipeline.pipeline_generator import PipelineGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
@@ -40,7 +41,7 @@ register_schema_model(service_api_ns, DatasourceNodeRunPayload)
|
||||
register_schema_model(service_api_ns, PipelineRunApiEntity)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource-plugins")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource-plugins")
|
||||
class DatasourcePluginsApi(DatasetApiResource):
|
||||
"""Resource for datasource plugins."""
|
||||
|
||||
@@ -75,7 +76,7 @@ class DatasourcePluginsApi(DatasetApiResource):
|
||||
return datasource_plugins, 200
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/datasource/nodes/<string:node_id>/run")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/datasource/nodes/{string:node_id}/run")
|
||||
class DatasourceNodeRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -130,7 +131,7 @@ class DatasourceNodeRunApi(DatasetApiResource):
|
||||
)
|
||||
|
||||
|
||||
@service_api_ns.route("/datasets/<uuid:dataset_id>/pipeline/run")
|
||||
@service_api_ns.route(f"/datasets/{uuid:dataset_id}/pipeline/run")
|
||||
class PipelineRunApi(DatasetApiResource):
|
||||
"""Resource for datasource node run."""
|
||||
|
||||
@@ -231,4 +232,12 @@ class KnowledgebasePipelineFileUploadApi(DatasetApiResource):
|
||||
except services.errors.file.UnsupportedFileTypeError:
|
||||
raise UnsupportedFileTypeError()
|
||||
|
||||
return serialize_upload_file(upload_file), 201
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at,
|
||||
}, 201
|
||||
|
||||
@@ -1,22 +0,0 @@
|
||||
"""
|
||||
Serialization helpers for Service API knowledge pipeline endpoints.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from models.model import UploadFile
|
||||
|
||||
|
||||
def serialize_upload_file(upload_file: UploadFile) -> dict[str, Any]:
|
||||
return {
|
||||
"id": upload_file.id,
|
||||
"name": upload_file.name,
|
||||
"size": upload_file.size,
|
||||
"extension": upload_file.extension,
|
||||
"mime_type": upload_file.mime_type,
|
||||
"created_by": upload_file.created_by,
|
||||
"created_at": upload_file.created_at.isoformat() if upload_file.created_at else None,
|
||||
}
|
||||
@@ -233,7 +233,7 @@ class DatasetSegmentApi(DatasetApiResource):
|
||||
if not segment:
|
||||
raise NotFound("Segment not found.")
|
||||
SegmentService.delete_segment(segment, document, dataset)
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[SegmentUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_segment")
|
||||
@@ -499,7 +499,7 @@ class DatasetChildChunkApi(DatasetApiResource):
|
||||
except ChildChunkDeleteIndexServiceError as e:
|
||||
raise ChildChunkDeleteIndexError(str(e))
|
||||
|
||||
return "", 204
|
||||
return 204
|
||||
|
||||
@service_api_ns.expect(service_api_ns.models[ChildChunkUpdatePayload.__name__])
|
||||
@service_api_ns.doc("update_child_chunk")
|
||||
|
||||
@@ -1,3 +0,0 @@
|
||||
from . import end_user
|
||||
|
||||
__all__ = ["end_user"]
|
||||
@@ -1,41 +0,0 @@
|
||||
from uuid import UUID
|
||||
|
||||
from flask_restx import Resource
|
||||
|
||||
from controllers.service_api import service_api_ns
|
||||
from controllers.service_api.end_user.error import EndUserNotFoundError
|
||||
from controllers.service_api.wraps import validate_app_token
|
||||
from fields.end_user_fields import EndUserDetail
|
||||
from models.model import App
|
||||
from services.end_user_service import EndUserService
|
||||
|
||||
|
||||
@service_api_ns.route("/end-users/<uuid:end_user_id>")
|
||||
class EndUserApi(Resource):
|
||||
"""Resource for retrieving end user details by ID."""
|
||||
|
||||
@service_api_ns.doc("get_end_user")
|
||||
@service_api_ns.doc(description="Get an end user by ID")
|
||||
@service_api_ns.doc(
|
||||
params={"end_user_id": "End user ID"},
|
||||
responses={
|
||||
200: "End user retrieved successfully",
|
||||
401: "Unauthorized - invalid API token",
|
||||
404: "End user not found",
|
||||
},
|
||||
)
|
||||
@validate_app_token
|
||||
def get(self, app_model: App, end_user_id: UUID):
|
||||
"""Get end user detail.
|
||||
|
||||
This endpoint is scoped to the current app token's tenant/app to prevent
|
||||
cross-tenant/app access when an end-user ID is known.
|
||||
"""
|
||||
|
||||
end_user = EndUserService.get_end_user_by_id(
|
||||
tenant_id=app_model.tenant_id, app_id=app_model.id, end_user_id=str(end_user_id)
|
||||
)
|
||||
if end_user is None:
|
||||
raise EndUserNotFoundError()
|
||||
|
||||
return EndUserDetail.model_validate(end_user).model_dump(mode="json")
|
||||
@@ -1,7 +0,0 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class EndUserNotFoundError(BaseHTTPException):
|
||||
error_code = "end_user_not_found"
|
||||
description = "End user not found."
|
||||
code = 404
|
||||
@@ -1,24 +1,27 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Callable
|
||||
from datetime import timedelta
|
||||
from enum import StrEnum, auto
|
||||
from functools import wraps
|
||||
from typing import Concatenate, ParamSpec, TypeVar, cast
|
||||
from typing import Concatenate, ParamSpec, TypeVar
|
||||
|
||||
from flask import current_app, request
|
||||
from flask_login import user_logged_in
|
||||
from flask_restx import Resource
|
||||
from pydantic import BaseModel
|
||||
from sqlalchemy import select, update
|
||||
from sqlalchemy.orm import Session
|
||||
from werkzeug.exceptions import Forbidden, NotFound, Unauthorized
|
||||
|
||||
from enums.cloud_plan import CloudPlan
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.login import current_user
|
||||
from models import Account, Tenant, TenantAccountJoin, TenantStatus
|
||||
from models.dataset import Dataset, RateLimitLog
|
||||
from models.model import ApiToken, App
|
||||
from services.api_token_service import ApiTokenCache, fetch_token_with_single_flight, record_token_usage
|
||||
from services.end_user_service import EndUserService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
@@ -217,8 +220,6 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
def decorator(view: Callable[Concatenate[T, P], R]):
|
||||
@wraps(view)
|
||||
def decorated(*args: P.args, **kwargs: P.kwargs):
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
|
||||
# get url path dataset_id from positional args or kwargs
|
||||
# Flask passes URL path parameters as positional arguments
|
||||
dataset_id = None
|
||||
@@ -255,18 +256,12 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
# Validate dataset if dataset_id is provided
|
||||
if dataset_id:
|
||||
dataset_id = str(dataset_id)
|
||||
dataset = (
|
||||
db.session.query(Dataset)
|
||||
.where(
|
||||
Dataset.id == dataset_id,
|
||||
Dataset.tenant_id == api_token.tenant_id,
|
||||
)
|
||||
.first()
|
||||
)
|
||||
dataset = db.session.query(Dataset).where(Dataset.id == dataset_id).first()
|
||||
if not dataset:
|
||||
raise NotFound("Dataset not found.")
|
||||
if not dataset.enable_api:
|
||||
raise Forbidden("Dataset api access is not enabled.")
|
||||
api_token = validate_and_get_api_token("dataset")
|
||||
tenant_account_join = (
|
||||
db.session.query(Tenant, TenantAccountJoin)
|
||||
.where(Tenant.id == api_token.tenant_id)
|
||||
@@ -301,14 +296,7 @@ def validate_dataset_token(view: Callable[Concatenate[T, P], R] | None = None):
|
||||
|
||||
def validate_and_get_api_token(scope: str | None = None):
|
||||
"""
|
||||
Validate and get API token with Redis caching.
|
||||
|
||||
This function uses a two-tier approach:
|
||||
1. First checks Redis cache for the token
|
||||
2. If not cached, queries database and caches the result
|
||||
|
||||
The last_used_at field is updated asynchronously via Celery task
|
||||
to avoid blocking the request.
|
||||
Validate and get API token.
|
||||
"""
|
||||
auth_header = request.headers.get("Authorization")
|
||||
if auth_header is None or " " not in auth_header:
|
||||
@@ -320,18 +308,29 @@ def validate_and_get_api_token(scope: str | None = None):
|
||||
if auth_scheme != "bearer":
|
||||
raise Unauthorized("Authorization scheme must be 'Bearer'")
|
||||
|
||||
# Try to get token from cache first
|
||||
# Returns a CachedApiToken (plain Python object), not a SQLAlchemy model
|
||||
cached_token = ApiTokenCache.get(auth_token, scope)
|
||||
if cached_token is not None:
|
||||
logger.debug("Token validation served from cache for scope: %s", scope)
|
||||
# Record usage in Redis for later batch update (no Celery task per request)
|
||||
record_token_usage(auth_token, scope)
|
||||
return cast(ApiToken, cached_token)
|
||||
current_time = naive_utc_now()
|
||||
cutoff_time = current_time - timedelta(minutes=1)
|
||||
with Session(db.engine, expire_on_commit=False) as session:
|
||||
update_stmt = (
|
||||
update(ApiToken)
|
||||
.where(
|
||||
ApiToken.token == auth_token,
|
||||
(ApiToken.last_used_at.is_(None) | (ApiToken.last_used_at < cutoff_time)),
|
||||
ApiToken.type == scope,
|
||||
)
|
||||
.values(last_used_at=current_time)
|
||||
)
|
||||
stmt = select(ApiToken).where(ApiToken.token == auth_token, ApiToken.type == scope)
|
||||
result = session.execute(update_stmt)
|
||||
api_token = session.scalar(stmt)
|
||||
|
||||
# Cache miss - use Redis lock for single-flight mode
|
||||
# This ensures only one request queries DB for the same token concurrently
|
||||
return fetch_token_with_single_flight(auth_token, scope)
|
||||
if hasattr(result, "rowcount") and result.rowcount > 0:
|
||||
session.commit()
|
||||
|
||||
if not api_token:
|
||||
raise Unauthorized("Access token is invalid")
|
||||
|
||||
return api_token
|
||||
|
||||
|
||||
class DatasetApiResource(Resource):
|
||||
|
||||
@@ -23,7 +23,6 @@ from . import (
|
||||
feature,
|
||||
files,
|
||||
forgot_password,
|
||||
human_input_form,
|
||||
login,
|
||||
message,
|
||||
passport,
|
||||
@@ -31,7 +30,6 @@ from . import (
|
||||
saved_message,
|
||||
site,
|
||||
workflow,
|
||||
workflow_events,
|
||||
)
|
||||
|
||||
api.add_namespace(web_ns)
|
||||
@@ -46,7 +44,6 @@ __all__ = [
|
||||
"feature",
|
||||
"files",
|
||||
"forgot_password",
|
||||
"human_input_form",
|
||||
"login",
|
||||
"message",
|
||||
"passport",
|
||||
@@ -55,5 +52,4 @@ __all__ = [
|
||||
"site",
|
||||
"web_ns",
|
||||
"workflow",
|
||||
"workflow_events",
|
||||
]
|
||||
|
||||
@@ -117,12 +117,6 @@ class InvokeRateLimitError(BaseHTTPException):
|
||||
code = 429
|
||||
|
||||
|
||||
class WebFormRateLimitExceededError(BaseHTTPException):
|
||||
error_code = "web_form_rate_limit_exceeded"
|
||||
description = "Too many form requests. Please try again later."
|
||||
code = 429
|
||||
|
||||
|
||||
class NotFoundError(BaseHTTPException):
|
||||
error_code = "not_found"
|
||||
code = 404
|
||||
|
||||
@@ -1,161 +0,0 @@
|
||||
"""
|
||||
Web App Human Input Form APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
|
||||
from flask import Response, request
|
||||
from flask_restx import Resource, reqparse
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
from controllers.web import web_ns
|
||||
from controllers.web.error import NotFoundError, WebFormRateLimitExceededError
|
||||
from controllers.web.site import serialize_app_site_payload
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import RateLimiter, extract_remote_ip
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from services.human_input_service import Form, FormNotFoundError, HumanInputService
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
_FORM_SUBMIT_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_submit_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
_FORM_ACCESS_RATE_LIMITER = RateLimiter(
|
||||
prefix="web_form_access_rate_limit",
|
||||
max_attempts=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_MAX_ATTEMPTS,
|
||||
time_window=dify_config.WEB_FORM_SUBMIT_RATE_LIMIT_WINDOW_SECONDS,
|
||||
)
|
||||
|
||||
|
||||
def _stringify_default_values(values: dict[str, object]) -> dict[str, str]:
|
||||
result: dict[str, str] = {}
|
||||
for key, value in values.items():
|
||||
if value is None:
|
||||
result[key] = ""
|
||||
elif isinstance(value, (dict, list)):
|
||||
result[key] = json.dumps(value, ensure_ascii=False)
|
||||
else:
|
||||
result[key] = str(value)
|
||||
return result
|
||||
|
||||
|
||||
def _to_timestamp(value: datetime) -> int:
|
||||
return int(value.timestamp())
|
||||
|
||||
|
||||
def _jsonify_form_definition(form: Form, site_payload: dict | None = None) -> Response:
|
||||
"""Return the form payload (optionally with site) as a JSON response."""
|
||||
definition_payload = form.get_definition().model_dump()
|
||||
payload = {
|
||||
"form_content": definition_payload["rendered_content"],
|
||||
"inputs": definition_payload["inputs"],
|
||||
"resolved_default_values": _stringify_default_values(definition_payload["default_values"]),
|
||||
"user_actions": definition_payload["user_actions"],
|
||||
"expiration_time": _to_timestamp(form.expiration_time),
|
||||
}
|
||||
if site_payload is not None:
|
||||
payload["site"] = site_payload
|
||||
return Response(json.dumps(payload, ensure_ascii=False), mimetype="application/json")
|
||||
|
||||
|
||||
@web_ns.route("/form/human_input/<string:form_token>")
|
||||
class HumanInputFormApi(Resource):
|
||||
"""API for getting and submitting human input forms via the web app."""
|
||||
|
||||
# NOTE(QuantumGhost): this endpoint is unauthenticated on purpose for now.
|
||||
|
||||
# def get(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def get(self, form_token: str):
|
||||
"""
|
||||
Get human input form definition by token.
|
||||
|
||||
GET /api/form/human_input/<form_token>
|
||||
"""
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_ACCESS_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_ACCESS_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
# TODO(QuantumGhost): forbid submision for form tokens
|
||||
# that are only for console.
|
||||
form = service.get_form_by_token(form_token)
|
||||
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
service.ensure_form_active(form)
|
||||
app_model, site = _get_app_site_from_form(form)
|
||||
|
||||
return _jsonify_form_definition(form, site_payload=serialize_app_site_payload(app_model, site, None))
|
||||
|
||||
# def post(self, _app_model: App, _end_user: EndUser, form_token: str):
|
||||
def post(self, form_token: str):
|
||||
"""
|
||||
Submit human input form by token.
|
||||
|
||||
POST /api/form/human_input/<form_token>
|
||||
|
||||
Request body:
|
||||
{
|
||||
"inputs": {
|
||||
"content": "User input content"
|
||||
},
|
||||
"action": "Approve"
|
||||
}
|
||||
"""
|
||||
parser = reqparse.RequestParser()
|
||||
parser.add_argument("inputs", type=dict, required=True, location="json")
|
||||
parser.add_argument("action", type=str, required=True, location="json")
|
||||
args = parser.parse_args()
|
||||
|
||||
ip_address = extract_remote_ip(request)
|
||||
if _FORM_SUBMIT_RATE_LIMITER.is_rate_limited(ip_address):
|
||||
raise WebFormRateLimitExceededError()
|
||||
_FORM_SUBMIT_RATE_LIMITER.increment_rate_limit(ip_address)
|
||||
|
||||
service = HumanInputService(db.engine)
|
||||
form = service.get_form_by_token(form_token)
|
||||
if form is None:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
if (recipient_type := form.recipient_type) is None:
|
||||
logger.warning("Recipient type is None for form, form_id=%", form.id)
|
||||
raise AssertionError("Recipient type is None")
|
||||
|
||||
try:
|
||||
service.submit_form_by_token(
|
||||
recipient_type=recipient_type,
|
||||
form_token=form_token,
|
||||
selected_action_id=args["action"],
|
||||
form_data=args["inputs"],
|
||||
submission_end_user_id=None,
|
||||
# submission_end_user_id=_end_user.id,
|
||||
)
|
||||
except FormNotFoundError:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
return {}, 200
|
||||
|
||||
|
||||
def _get_app_site_from_form(form: Form) -> tuple[App, Site]:
|
||||
"""Resolve App/Site for the form's app and validate tenant status."""
|
||||
app_model = db.session.query(App).where(App.id == form.app_id).first()
|
||||
if app_model is None or app_model.tenant_id != form.tenant_id:
|
||||
raise NotFoundError("Form not found")
|
||||
|
||||
site = db.session.query(Site).where(Site.app_id == app_model.id).first()
|
||||
if site is None:
|
||||
raise Forbidden()
|
||||
|
||||
if app_model.tenant and app_model.tenant.status == TenantStatus.ARCHIVE:
|
||||
raise Forbidden()
|
||||
|
||||
return app_model, site
|
||||
@@ -1,6 +1,4 @@
|
||||
from typing import cast
|
||||
|
||||
from flask_restx import fields, marshal, marshal_with
|
||||
from flask_restx import fields, marshal_with
|
||||
from werkzeug.exceptions import Forbidden
|
||||
|
||||
from configs import dify_config
|
||||
@@ -9,7 +7,7 @@ from controllers.web.wraps import WebApiResource
|
||||
from extensions.ext_database import db
|
||||
from libs.helper import AppIconUrlField
|
||||
from models.account import TenantStatus
|
||||
from models.model import App, Site
|
||||
from models.model import Site
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
|
||||
@@ -110,14 +108,3 @@ class AppSiteInfo:
|
||||
"remove_webapp_brand": remove_webapp_brand,
|
||||
"replace_webapp_logo": replace_webapp_logo,
|
||||
}
|
||||
|
||||
|
||||
def serialize_site(site: Site) -> dict:
|
||||
"""Serialize Site model using the same schema as AppSiteApi."""
|
||||
return cast(dict, marshal(site, AppSiteApi.site_fields))
|
||||
|
||||
|
||||
def serialize_app_site_payload(app_model: App, site: Site, end_user_id: str | None) -> dict:
|
||||
can_replace_logo = FeatureService.get_features(app_model.tenant_id).can_replace_logo
|
||||
app_site_info = AppSiteInfo(app_model.tenant, app_model, site, end_user_id, can_replace_logo)
|
||||
return cast(dict, marshal(app_site_info, AppSiteApi.app_fields))
|
||||
|
||||
@@ -1,112 +0,0 @@
|
||||
"""
|
||||
Web App Workflow Resume APIs.
|
||||
"""
|
||||
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
|
||||
from flask import Response, request
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from controllers.web import api
|
||||
from controllers.web.error import InvalidArgumentError, NotFoundError
|
||||
from controllers.web.wraps import WebApiResource
|
||||
from core.app.apps.advanced_chat.app_generator import AdvancedChatAppGenerator
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.common.workflow_response_converter import WorkflowResponseConverter
|
||||
from core.app.apps.message_generator import MessageGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from extensions.ext_database import db
|
||||
from models.enums import CreatorUserRole
|
||||
from models.model import App, AppMode, EndUser
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
from services.workflow_event_snapshot_service import build_workflow_event_stream
|
||||
|
||||
|
||||
class WorkflowEventsApi(WebApiResource):
|
||||
"""API for getting workflow execution events after resume."""
|
||||
|
||||
def get(self, app_model: App, end_user: EndUser, task_id: str):
|
||||
"""
|
||||
Get workflow execution events stream after resume.
|
||||
|
||||
GET /api/workflow/<task_id>/events
|
||||
|
||||
Returns Server-Sent Events stream.
|
||||
"""
|
||||
workflow_run_id = task_id
|
||||
session_maker = sessionmaker(db.engine)
|
||||
repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
workflow_run = repo.get_workflow_run_by_id_and_tenant_id(
|
||||
tenant_id=app_model.tenant_id,
|
||||
run_id=workflow_run_id,
|
||||
)
|
||||
|
||||
if workflow_run is None:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.app_id != app_model.id:
|
||||
raise NotFoundError(f"WorkflowRun not found, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by_role != CreatorUserRole.END_USER:
|
||||
raise NotFoundError(f"WorkflowRun not created by end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.created_by != end_user.id:
|
||||
raise NotFoundError(f"WorkflowRun not created by the current end user, id={workflow_run_id}")
|
||||
|
||||
if workflow_run.finished_at is not None:
|
||||
response = WorkflowResponseConverter.workflow_run_result_to_finish_response(
|
||||
task_id=workflow_run.id,
|
||||
workflow_run=workflow_run,
|
||||
creator_user=end_user,
|
||||
)
|
||||
|
||||
payload = response.model_dump(mode="json")
|
||||
payload["event"] = response.event.value
|
||||
|
||||
def _generate_finished_events() -> Generator[str, None, None]:
|
||||
yield f"data: {json.dumps(payload)}\n\n"
|
||||
|
||||
event_generator = _generate_finished_events
|
||||
else:
|
||||
app_mode = AppMode.value_of(app_model.mode)
|
||||
msg_generator = MessageGenerator()
|
||||
generator: BaseAppGenerator
|
||||
if app_mode == AppMode.ADVANCED_CHAT:
|
||||
generator = AdvancedChatAppGenerator()
|
||||
elif app_mode == AppMode.WORKFLOW:
|
||||
generator = WorkflowAppGenerator()
|
||||
else:
|
||||
raise InvalidArgumentError(f"cannot subscribe to workflow run, workflow_run_id={workflow_run.id}")
|
||||
|
||||
include_state_snapshot = request.args.get("include_state_snapshot", "false").lower() == "true"
|
||||
|
||||
def _generate_stream_events():
|
||||
if include_state_snapshot:
|
||||
return generator.convert_to_event_stream(
|
||||
build_workflow_event_stream(
|
||||
app_mode=app_mode,
|
||||
workflow_run=workflow_run,
|
||||
tenant_id=app_model.tenant_id,
|
||||
app_id=app_model.id,
|
||||
session_maker=session_maker,
|
||||
)
|
||||
)
|
||||
return generator.convert_to_event_stream(
|
||||
msg_generator.retrieve_events(app_mode, workflow_run.id),
|
||||
)
|
||||
|
||||
event_generator = _generate_stream_events
|
||||
|
||||
return Response(
|
||||
event_generator(),
|
||||
mimetype="text/event-stream",
|
||||
headers={
|
||||
"Cache-Control": "no-cache",
|
||||
"Connection": "keep-alive",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
# Register the APIs
|
||||
api.add_resource(WorkflowEventsApi, "/workflow/<string:task_id>/events")
|
||||
@@ -79,7 +79,7 @@ class BaseAgentRunner(AppRunner):
|
||||
self.model_instance = model_instance
|
||||
|
||||
# init callback
|
||||
self.agent_callback = DifyAgentCallbackHandler()
|
||||
self.agent_callback = DifyAgentCallbackHandler(tenant_id=tenant_id)
|
||||
# init dataset tools
|
||||
hit_callback = DatasetIndexToolCallbackHandler(
|
||||
queue_manager=queue_manager,
|
||||
|
||||
@@ -4,8 +4,8 @@ import contextvars
|
||||
import logging
|
||||
import threading
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Literal, TypeVar, Union, overload
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import TYPE_CHECKING, Any, Literal, Union, overload
|
||||
|
||||
from flask import Flask, current_app
|
||||
from pydantic import ValidationError
|
||||
@@ -29,25 +29,21 @@ from core.app.apps.message_based_app_generator import MessageBasedAppGenerator
|
||||
from core.app.apps.message_based_app_queue_manager import MessageBasedAppQueueManager
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom
|
||||
from core.app.entities.task_entities import ChatbotAppBlockingResponse, ChatbotAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.get_thread_messages_length import get_thread_messages_length
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import (
|
||||
DraftVariableSaverFactory,
|
||||
)
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models import Account, App, Conversation, EndUser, Message, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.base import Base
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from services.conversation_service import ConversationService
|
||||
from services.workflow_draft_variable_service import (
|
||||
@@ -69,9 +65,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[False],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -80,11 +74,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: Literal[True],
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -93,11 +85,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]: ...
|
||||
|
||||
def generate(
|
||||
@@ -105,11 +95,9 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
args: Mapping[str, Any],
|
||||
args: Mapping,
|
||||
invoke_from: InvokeFrom,
|
||||
workflow_run_id: str,
|
||||
streaming: bool = True,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping, None, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -173,6 +161,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
# always enable retriever resource in debugger mode
|
||||
app_config.additional_features.show_retrieve_source = True # type: ignore
|
||||
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# init application generate entity
|
||||
application_generate_entity = AdvancedChatAppGenerateEntity(
|
||||
task_id=str(uuid.uuid4()),
|
||||
@@ -190,7 +179,7 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
invoke_from=invoke_from,
|
||||
extras=extras,
|
||||
trace_manager=trace_manager,
|
||||
workflow_run_id=str(workflow_run_id),
|
||||
workflow_run_id=workflow_run_id,
|
||||
)
|
||||
contexts.plugin_tool_providers.set({})
|
||||
contexts.plugin_tool_providers_lock.set(threading.Lock())
|
||||
@@ -227,38 +216,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
stream=streaming,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
conversation: Conversation,
|
||||
message: Message,
|
||||
application_generate_entity: AdvancedChatAppGenerateEntity,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
):
|
||||
"""
|
||||
Resume a paused advanced chat execution.
|
||||
"""
|
||||
return self._generate(
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
application_generate_entity=application_generate_entity,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
conversation=conversation,
|
||||
message=message,
|
||||
stream=application_generate_entity.stream,
|
||||
pause_state_config=pause_state_config,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
def single_iteration_generate(
|
||||
@@ -439,12 +396,8 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
conversation: Conversation | None = None,
|
||||
message: Message | None = None,
|
||||
stream: bool = True,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
) -> Mapping[str, Any] | Generator[str | Mapping[str, Any], Any, None]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -458,12 +411,12 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
:param conversation: conversation
|
||||
:param stream: is stream
|
||||
"""
|
||||
is_first_conversation = conversation is None
|
||||
is_first_conversation = False
|
||||
if not conversation:
|
||||
is_first_conversation = True
|
||||
|
||||
if conversation is not None and message is not None:
|
||||
pass
|
||||
else:
|
||||
conversation, message = self._init_generate_records(application_generate_entity, conversation)
|
||||
# init generate records
|
||||
(conversation, message) = self._init_generate_records(application_generate_entity, conversation)
|
||||
|
||||
if is_first_conversation:
|
||||
# update conversation features
|
||||
@@ -486,16 +439,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
message_id=message.id,
|
||||
)
|
||||
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@@ -511,25 +454,14 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
"variable_loader": variable_loader,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
},
|
||||
)
|
||||
|
||||
worker_thread.start()
|
||||
|
||||
# release database connection, because the following new thread operations may take a long time
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
workflow = _refresh_model(session, workflow)
|
||||
message = _refresh_model(session, message)
|
||||
# workflow_ = session.get(Workflow, workflow.id)
|
||||
# assert workflow_ is not None
|
||||
# workflow = workflow_
|
||||
# message_ = session.get(Message, message.id)
|
||||
# assert message_ is not None
|
||||
# message = message_
|
||||
# db.session.refresh(workflow)
|
||||
# db.session.refresh(message)
|
||||
db.session.refresh(workflow)
|
||||
db.session.refresh(message)
|
||||
# db.session.refresh(user)
|
||||
db.session.close()
|
||||
|
||||
@@ -558,8 +490,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
variable_loader: VariableLoader,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -617,8 +547,6 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
app=app,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
@@ -686,13 +614,3 @@ class AdvancedChatAppGenerator(MessageBasedAppGenerator):
|
||||
else:
|
||||
logger.exception("Failed to process generate task pipeline, conversation_id: %s", conversation.id)
|
||||
raise e
|
||||
|
||||
|
||||
_T = TypeVar("_T", bound=Base)
|
||||
|
||||
|
||||
def _refresh_model(session, model: _T) -> _T:
|
||||
with Session(bind=db.engine, expire_on_commit=False) as session:
|
||||
detach_model = session.get(type(model), model.id)
|
||||
assert detach_model is not None
|
||||
return detach_model
|
||||
|
||||
@@ -66,7 +66,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -83,7 +82,6 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
self._app = app
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@@ -112,21 +110,7 @@ class AdvancedChatAppRunner(WorkflowBasedAppRunner):
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
invoke_from=invoke_from,
|
||||
user_from=user_from,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
# Handle single iteration or single loop run
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
|
||||
@@ -24,8 +24,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueAnnotationReplyEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -44,7 +42,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@@ -66,8 +63,8 @@ from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.model_runtime.entities.llm_entities import LLMUsage
|
||||
from core.model_runtime.utils.encoders import jsonable_encoder
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.repositories.human_input_repository import HumanInputFormRepositoryImpl
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.nodes import NodeType
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
@@ -76,8 +73,7 @@ from core.workflow.system_variable import SystemVariable
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, Conversation, EndUser, Message, MessageFile
|
||||
from models.enums import CreatorUserRole, MessageStatus
|
||||
from models.execution_extra_content import HumanInputContent
|
||||
from models.enums import CreatorUserRole
|
||||
from models.workflow import Workflow
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -134,7 +130,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
self._task_state = WorkflowTaskState()
|
||||
self._seed_task_state_from_message(message)
|
||||
self._message_cycle_manager = MessageCycleManager(
|
||||
application_generate_entity=application_generate_entity, task_state=self._task_state
|
||||
)
|
||||
@@ -142,7 +137,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._application_generate_entity = application_generate_entity
|
||||
self._workflow_id = workflow.id
|
||||
self._workflow_features_dict = workflow.features_dict
|
||||
self._workflow_tenant_id = workflow.tenant_id
|
||||
self._conversation_id = conversation.id
|
||||
self._conversation_mode = conversation.mode
|
||||
self._message_id = message.id
|
||||
@@ -152,13 +146,8 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
self._workflow_run_id: str = ""
|
||||
self._draft_var_saver_factory = draft_var_saver_factory
|
||||
self._graph_runtime_state: GraphRuntimeState | None = None
|
||||
self._message_saved_on_pause = False
|
||||
self._seed_graph_runtime_state_from_queue_manager()
|
||||
|
||||
def _seed_task_state_from_message(self, message: Message) -> None:
|
||||
if message.status == MessageStatus.PAUSED and message.answer:
|
||||
self._task_state.answer = message.answer
|
||||
|
||||
def process(self) -> Union[ChatbotAppBlockingResponse, Generator[ChatbotAppStreamResponse, None, None]]:
|
||||
"""
|
||||
Process generate task pipeline.
|
||||
@@ -321,7 +310,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow_id,
|
||||
reason=event.reason,
|
||||
)
|
||||
|
||||
yield workflow_start_resp
|
||||
@@ -539,35 +527,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
self._persist_human_input_extra_content(form_id=reason.form_id, node_id=reason.node_id)
|
||||
yield from responses
|
||||
resolved_state: GraphRuntimeState | None = None
|
||||
try:
|
||||
resolved_state = self._ensure_graph_runtime_initialized()
|
||||
except ValueError:
|
||||
resolved_state = None
|
||||
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
message = self._get_message(session=session)
|
||||
if message is not None:
|
||||
message.status = MessageStatus.PAUSED
|
||||
self._message_saved_on_pause = True
|
||||
self._base_task_pipeline.queue_manager.publish(QueueAdvancedChatMessageEndEvent(), PublishFrom.TASK_PIPELINE)
|
||||
|
||||
def _handle_workflow_failed_event(
|
||||
@@ -607,7 +566,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle stop events."""
|
||||
_ = trace_manager
|
||||
resolved_state = None
|
||||
if self._workflow_run_id:
|
||||
resolved_state = self._resolve_graph_runtime_state(graph_runtime_state)
|
||||
@@ -622,8 +580,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield workflow_finish_resp
|
||||
elif event.stopped_by in (
|
||||
@@ -632,8 +589,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
):
|
||||
# When hitting input-moderation or annotation-reply, the workflow will not start
|
||||
with self._database_session() as session:
|
||||
# Save message
|
||||
self._save_message(session=session)
|
||||
self._save_message(session=session, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -642,6 +598,7 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
event: QueueAdvancedChatMessageEndEvent,
|
||||
*,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle advanced chat message end events."""
|
||||
@@ -657,10 +614,9 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
reason=QueueMessageReplaceEvent.MessageReplaceReason.OUTPUT_MODERATION,
|
||||
)
|
||||
|
||||
# Save message unless it has already been persisted on pause.
|
||||
if not self._message_saved_on_pause:
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state)
|
||||
# Save message
|
||||
with self._database_session() as session:
|
||||
self._save_message(session=session, graph_runtime_state=resolved_state, trace_manager=trace_manager)
|
||||
|
||||
yield self._message_end_to_stream_response()
|
||||
|
||||
@@ -686,65 +642,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
"""Handle message replace events."""
|
||||
yield self._message_cycle_manager.message_replace_to_stream_response(answer=event.text, reason=event.reason)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
self._persist_human_input_extra_content(node_id=event.node_id)
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _persist_human_input_extra_content(self, *, node_id: str | None = None, form_id: str | None = None) -> None:
|
||||
if not self._workflow_run_id or not self._message_id:
|
||||
return
|
||||
|
||||
if form_id is None:
|
||||
if node_id is None:
|
||||
return
|
||||
form_id = self._load_human_input_form_id(node_id=node_id)
|
||||
if form_id is None:
|
||||
logger.warning(
|
||||
"HumanInput form not found for workflow run %s node %s",
|
||||
self._workflow_run_id,
|
||||
node_id,
|
||||
)
|
||||
return
|
||||
|
||||
with self._database_session() as session:
|
||||
exists_stmt = select(HumanInputContent).where(
|
||||
HumanInputContent.workflow_run_id == self._workflow_run_id,
|
||||
HumanInputContent.message_id == self._message_id,
|
||||
HumanInputContent.form_id == form_id,
|
||||
)
|
||||
if session.scalar(exists_stmt) is not None:
|
||||
return
|
||||
|
||||
content = HumanInputContent(
|
||||
workflow_run_id=self._workflow_run_id,
|
||||
message_id=self._message_id,
|
||||
form_id=form_id,
|
||||
)
|
||||
session.add(content)
|
||||
|
||||
def _load_human_input_form_id(self, *, node_id: str) -> str | None:
|
||||
form_repository = HumanInputFormRepositoryImpl(
|
||||
session_factory=db.engine,
|
||||
tenant_id=self._workflow_tenant_id,
|
||||
)
|
||||
form = form_repository.get_form(self._workflow_run_id, node_id)
|
||||
if form is None:
|
||||
return None
|
||||
return form.id
|
||||
|
||||
def _handle_agent_log_event(self, event: QueueAgentLogEvent, **kwargs) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle agent log events."""
|
||||
yield self._workflow_response_converter.handle_agent_log(
|
||||
@@ -762,7 +659,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
QueueWorkflowFailedEvent: self._handle_workflow_failed_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
@@ -784,8 +680,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueMessageReplaceEvent: self._handle_message_replace_event,
|
||||
QueueAdvancedChatMessageEndEvent: self._handle_advanced_chat_message_end_event,
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
@@ -853,9 +747,6 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_event(event, trace_manager=trace_manager)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_stop_event(event, graph_runtime_state=None, trace_manager=trace_manager)
|
||||
@@ -879,13 +770,14 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
if self._conversation_name_generate_thread:
|
||||
logger.debug("Conversation name generation running as daemon thread")
|
||||
|
||||
def _save_message(self, *, session: Session, graph_runtime_state: GraphRuntimeState | None = None):
|
||||
def _save_message(
|
||||
self,
|
||||
*,
|
||||
session: Session,
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
trace_manager: TraceQueueManager | None = None,
|
||||
):
|
||||
message = self._get_message(session=session)
|
||||
if message is None:
|
||||
return
|
||||
|
||||
if message.status == MessageStatus.PAUSED:
|
||||
message.status = MessageStatus.NORMAL
|
||||
|
||||
# If there are assistant files, remove markdown image links from answer
|
||||
answer_text = self._task_state.answer
|
||||
@@ -940,6 +832,22 @@ class AdvancedChatAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
]
|
||||
session.add_all(message_files)
|
||||
|
||||
if trace_manager:
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": str(message.conversation_id),
|
||||
"message_id": str(message.id),
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def _seed_graph_runtime_state_from_queue_manager(self) -> None:
|
||||
"""Bootstrap the cached runtime state from the queue manager when present."""
|
||||
candidate = self._base_task_pipeline.queue_manager.graph_runtime_state
|
||||
|
||||
@@ -5,14 +5,9 @@ from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from typing import Any, NewType, Union
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.queue_entities import (
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -24,13 +19,9 @@ from core.app.entities.queue_entities import (
|
||||
QueueNodeRetryEvent,
|
||||
QueueNodeStartedEvent,
|
||||
QueueNodeSucceededEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
AgentLogStreamResponse,
|
||||
HumanInputFormFilledResponse,
|
||||
HumanInputFormTimeoutResponse,
|
||||
HumanInputRequiredResponse,
|
||||
IterationNodeCompletedStreamResponse,
|
||||
IterationNodeNextStreamResponse,
|
||||
IterationNodeStartStreamResponse,
|
||||
@@ -40,9 +31,7 @@ from core.app.entities.task_entities import (
|
||||
NodeFinishStreamResponse,
|
||||
NodeRetryStreamResponse,
|
||||
NodeStartStreamResponse,
|
||||
StreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.file import FILE_MODEL_IDENTITY, File
|
||||
@@ -51,8 +40,6 @@ from core.tools.entities.tool_entities import ToolProviderType
|
||||
from core.tools.tool_manager import ToolManager
|
||||
from core.trigger.trigger_manager import TriggerManager
|
||||
from core.variables.segments import ArrayFileSegment, FileSegment, Segment
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import (
|
||||
NodeType,
|
||||
SystemVariableKey,
|
||||
@@ -64,11 +51,8 @@ from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.system_variable import SystemVariable
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from core.workflow.workflow_type_encoder import WorkflowRuntimeTypeConverter
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account, EndUser
|
||||
from models.human_input import HumanInputForm
|
||||
from models.workflow import WorkflowRun
|
||||
from services.variable_truncator import BaseTruncator, DummyVariableTruncator, VariableTruncator
|
||||
|
||||
NodeExecutionId = NewType("NodeExecutionId", str)
|
||||
@@ -207,7 +191,6 @@ class WorkflowResponseConverter:
|
||||
task_id: str,
|
||||
workflow_run_id: str,
|
||||
workflow_id: str,
|
||||
reason: WorkflowStartReason,
|
||||
) -> WorkflowStartStreamResponse:
|
||||
run_id = self._ensure_workflow_run_id(workflow_run_id)
|
||||
started_at = naive_utc_now()
|
||||
@@ -221,7 +204,6 @@ class WorkflowResponseConverter:
|
||||
workflow_id=workflow_id,
|
||||
inputs=self._workflow_inputs,
|
||||
created_at=int(started_at.timestamp()),
|
||||
reason=reason,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -282,160 +264,6 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_pause_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
task_id: str,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
) -> list[StreamResponse]:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
started_at = self._workflow_started_at
|
||||
if started_at is None:
|
||||
raise ValueError(
|
||||
"workflow_pause_to_stream_response called before workflow_start_to_stream_response",
|
||||
)
|
||||
paused_at = naive_utc_now()
|
||||
elapsed_time = (paused_at - started_at).total_seconds()
|
||||
encoded_outputs = self._encode_outputs(event.outputs) or {}
|
||||
if self._application_generate_entity.invoke_from == InvokeFrom.SERVICE_API:
|
||||
encoded_outputs = {}
|
||||
pause_reasons = [reason.model_dump(mode="json") for reason in event.reasons]
|
||||
human_input_form_ids = [reason.form_id for reason in event.reasons if isinstance(reason, HumanInputRequired)]
|
||||
expiration_times_by_form_id: dict[str, datetime] = {}
|
||||
if human_input_form_ids:
|
||||
stmt = select(HumanInputForm.id, HumanInputForm.expiration_time).where(
|
||||
HumanInputForm.id.in_(human_input_form_ids)
|
||||
)
|
||||
with Session(bind=db.engine) as session:
|
||||
for form_id, expiration_time in session.execute(stmt):
|
||||
expiration_times_by_form_id[str(form_id)] = expiration_time
|
||||
|
||||
responses: list[StreamResponse] = []
|
||||
|
||||
for reason in event.reasons:
|
||||
if isinstance(reason, HumanInputRequired):
|
||||
expiration_time = expiration_times_by_form_id.get(reason.form_id)
|
||||
if expiration_time is None:
|
||||
raise ValueError(f"HumanInputForm not found for pause reason, form_id={reason.form_id}")
|
||||
responses.append(
|
||||
HumanInputRequiredResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputRequiredResponse.Data(
|
||||
form_id=reason.form_id,
|
||||
node_id=reason.node_id,
|
||||
node_title=reason.node_title,
|
||||
form_content=reason.form_content,
|
||||
inputs=reason.inputs,
|
||||
actions=reason.actions,
|
||||
display_in_ui=reason.display_in_ui,
|
||||
form_token=reason.form_token,
|
||||
resolved_default_values=reason.resolved_default_values,
|
||||
expiration_time=int(expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
responses.append(
|
||||
WorkflowPauseStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowPauseStreamResponse.Data(
|
||||
workflow_run_id=run_id,
|
||||
paused_nodes=list(event.paused_nodes),
|
||||
outputs=encoded_outputs,
|
||||
reasons=pause_reasons,
|
||||
status=WorkflowExecutionStatus.PAUSED,
|
||||
created_at=int(started_at.timestamp()),
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=graph_runtime_state.total_tokens,
|
||||
total_steps=graph_runtime_state.node_run_steps,
|
||||
),
|
||||
)
|
||||
)
|
||||
|
||||
return responses
|
||||
|
||||
def human_input_form_filled_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormFilledEvent, task_id: str
|
||||
) -> HumanInputFormFilledResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormFilledResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormFilledResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
),
|
||||
)
|
||||
|
||||
def human_input_form_timeout_to_stream_response(
|
||||
self, *, event: QueueHumanInputFormTimeoutEvent, task_id: str
|
||||
) -> HumanInputFormTimeoutResponse:
|
||||
run_id = self._ensure_workflow_run_id()
|
||||
return HumanInputFormTimeoutResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=HumanInputFormTimeoutResponse.Data(
|
||||
node_id=event.node_id,
|
||||
node_title=event.node_title,
|
||||
expiration_time=int(event.expiration_time.timestamp()),
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def workflow_run_result_to_finish_response(
|
||||
cls,
|
||||
*,
|
||||
task_id: str,
|
||||
workflow_run: WorkflowRun,
|
||||
creator_user: Account | EndUser,
|
||||
) -> WorkflowFinishStreamResponse:
|
||||
run_id = workflow_run.id
|
||||
elapsed_time = workflow_run.elapsed_time
|
||||
|
||||
encoded_outputs = workflow_run.outputs_dict
|
||||
finished_at = workflow_run.finished_at
|
||||
assert finished_at is not None
|
||||
|
||||
created_by: Mapping[str, object]
|
||||
user = creator_user
|
||||
if isinstance(user, Account):
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"name": user.name,
|
||||
"email": user.email,
|
||||
}
|
||||
else:
|
||||
created_by = {
|
||||
"id": user.id,
|
||||
"user": user.session_id,
|
||||
}
|
||||
|
||||
return WorkflowFinishStreamResponse(
|
||||
task_id=task_id,
|
||||
workflow_run_id=run_id,
|
||||
data=WorkflowFinishStreamResponse.Data(
|
||||
id=run_id,
|
||||
workflow_id=workflow_run.workflow_id,
|
||||
status=workflow_run.status,
|
||||
outputs=encoded_outputs,
|
||||
error=workflow_run.error,
|
||||
elapsed_time=elapsed_time,
|
||||
total_tokens=workflow_run.total_tokens,
|
||||
total_steps=workflow_run.total_steps,
|
||||
created_by=created_by,
|
||||
created_at=int(workflow_run.created_at.timestamp()),
|
||||
finished_at=int(finished_at.timestamp()),
|
||||
files=cls.fetch_files_from_node_outputs(encoded_outputs),
|
||||
exceptions_count=workflow_run.exceptions_count,
|
||||
),
|
||||
)
|
||||
|
||||
def workflow_node_start_to_stream_response(
|
||||
self,
|
||||
*,
|
||||
@@ -764,8 +592,7 @@ class WorkflowResponseConverter:
|
||||
),
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def fetch_files_from_node_outputs(cls, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
def fetch_files_from_node_outputs(self, outputs_dict: Mapping[str, Any] | None) -> Sequence[Mapping[str, Any]]:
|
||||
"""
|
||||
Fetch files from node outputs
|
||||
:param outputs_dict: node outputs dict
|
||||
@@ -774,7 +601,7 @@ class WorkflowResponseConverter:
|
||||
if not outputs_dict:
|
||||
return []
|
||||
|
||||
files = [cls._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
files = [self._fetch_files_from_variable_value(output_value) for output_value in outputs_dict.values()]
|
||||
# Remove None
|
||||
files = [file for file in files if file]
|
||||
# Flatten list
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
from collections.abc import Generator
|
||||
from typing import Union, cast
|
||||
|
||||
from sqlalchemy import select
|
||||
@@ -10,14 +10,12 @@ from core.app.app_config.entities import EasyUIBasedAppConfig, EasyUIBasedAppMod
|
||||
from core.app.apps.base_app_generator import BaseAppGenerator
|
||||
from core.app.apps.base_app_queue_manager import AppQueueManager
|
||||
from core.app.apps.exc import GenerateTaskStoppedError
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from core.app.entities.app_invoke_entities import (
|
||||
AdvancedChatAppGenerateEntity,
|
||||
AgentChatAppGenerateEntity,
|
||||
AppGenerateEntity,
|
||||
ChatAppGenerateEntity,
|
||||
CompletionAppGenerateEntity,
|
||||
ConversationAppGenerateEntity,
|
||||
InvokeFrom,
|
||||
)
|
||||
from core.app.entities.task_entities import (
|
||||
@@ -29,8 +27,6 @@ from core.app.entities.task_entities import (
|
||||
from core.app.task_pipeline.easy_ui_based_generate_task_pipeline import EasyUIBasedGenerateTaskPipeline
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models import Account
|
||||
from models.enums import CreatorUserRole
|
||||
@@ -160,7 +156,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
query = application_generate_entity.query or "New conversation"
|
||||
conversation_name = (query[:20] + "…") if len(query) > 20 else query
|
||||
|
||||
created_new_conversation = conversation is None
|
||||
try:
|
||||
if not conversation:
|
||||
conversation = Conversation(
|
||||
@@ -237,10 +232,6 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
db.session.add_all(message_files)
|
||||
|
||||
db.session.commit()
|
||||
|
||||
if isinstance(application_generate_entity, ConversationAppGenerateEntity):
|
||||
application_generate_entity.conversation_id = conversation.id
|
||||
application_generate_entity.is_new_conversation = created_new_conversation
|
||||
return conversation, message
|
||||
except Exception:
|
||||
db.session.rollback()
|
||||
@@ -293,29 +284,3 @@ class MessageBasedAppGenerator(BaseAppGenerator):
|
||||
raise MessageNotExistsError("Message not exists")
|
||||
|
||||
return message
|
||||
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{workflow_run_id}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
|
||||
@@ -1,36 +0,0 @@
|
||||
from collections.abc import Callable, Generator, Mapping
|
||||
|
||||
from core.app.apps.streaming_utils import stream_topic_events
|
||||
from extensions.ext_redis import get_pubsub_broadcast_channel
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from models.model import AppMode
|
||||
|
||||
|
||||
class MessageGenerator:
|
||||
@staticmethod
|
||||
def _make_channel_key(app_mode: AppMode, workflow_run_id: str):
|
||||
return f"channel:{app_mode}:{str(workflow_run_id)}"
|
||||
|
||||
@classmethod
|
||||
def get_response_topic(cls, app_mode: AppMode, workflow_run_id: str) -> Topic:
|
||||
key = cls._make_channel_key(app_mode, workflow_run_id)
|
||||
channel = get_pubsub_broadcast_channel()
|
||||
topic = channel.topic(key)
|
||||
return topic
|
||||
|
||||
@classmethod
|
||||
def retrieve_events(
|
||||
cls,
|
||||
app_mode: AppMode,
|
||||
workflow_run_id: str,
|
||||
idle_timeout=300,
|
||||
ping_interval: float = 10.0,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
) -> Generator[Mapping | str, None, None]:
|
||||
topic = cls.get_response_topic(app_mode, workflow_run_id)
|
||||
return stream_topic_events(
|
||||
topic=topic,
|
||||
idle_timeout=idle_timeout,
|
||||
ping_interval=ping_interval,
|
||||
on_subscribe=on_subscribe,
|
||||
)
|
||||
@@ -1,70 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import time
|
||||
from collections.abc import Callable, Generator, Iterable, Mapping
|
||||
from typing import Any
|
||||
|
||||
from core.app.entities.task_entities import StreamEvent
|
||||
from libs.broadcast_channel.channel import Topic
|
||||
from libs.broadcast_channel.exc import SubscriptionClosedError
|
||||
|
||||
|
||||
def stream_topic_events(
|
||||
*,
|
||||
topic: Topic,
|
||||
idle_timeout: float,
|
||||
ping_interval: float | None = None,
|
||||
on_subscribe: Callable[[], None] | None = None,
|
||||
terminal_events: Iterable[str | StreamEvent] | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]:
|
||||
# send a PING event immediately to prevent the connection staying in pending state for a long time.
|
||||
#
|
||||
# This simplify the debugging process as the DevTools in Chrome does not
|
||||
# provide complete curl command for pending connections.
|
||||
yield StreamEvent.PING.value
|
||||
|
||||
terminal_values = _normalize_terminal_events(terminal_events)
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
with topic.subscribe() as sub:
|
||||
# on_subscribe fires only after the Redis subscription is active.
|
||||
# This is used to gate task start and reduce pub/sub race for the first event.
|
||||
if on_subscribe is not None:
|
||||
on_subscribe()
|
||||
while True:
|
||||
try:
|
||||
msg = sub.receive(timeout=1)
|
||||
except SubscriptionClosedError:
|
||||
return
|
||||
if msg is None:
|
||||
current_time = time.time()
|
||||
if current_time - last_msg_time > idle_timeout:
|
||||
return
|
||||
if ping_interval is not None and current_time - last_ping_time >= ping_interval:
|
||||
yield StreamEvent.PING.value
|
||||
last_ping_time = current_time
|
||||
continue
|
||||
|
||||
last_msg_time = time.time()
|
||||
last_ping_time = last_msg_time
|
||||
event = json.loads(msg)
|
||||
yield event
|
||||
if not isinstance(event, dict):
|
||||
continue
|
||||
|
||||
event_type = event.get("event")
|
||||
if event_type in terminal_values:
|
||||
return
|
||||
|
||||
|
||||
def _normalize_terminal_events(terminal_events: Iterable[str | StreamEvent] | None) -> set[str]:
|
||||
if not terminal_events:
|
||||
return {StreamEvent.WORKFLOW_FINISHED.value, StreamEvent.WORKFLOW_PAUSED.value}
|
||||
values: set[str] = set()
|
||||
for item in terminal_events:
|
||||
if isinstance(item, StreamEvent):
|
||||
values.add(item.value)
|
||||
else:
|
||||
values.add(str(item))
|
||||
return values
|
||||
@@ -25,7 +25,6 @@ from core.app.apps.workflow.generate_response_converter import WorkflowAppGenera
|
||||
from core.app.apps.workflow.generate_task_pipeline import WorkflowAppGenerateTaskPipeline
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, WorkflowAppGenerateEntity
|
||||
from core.app.entities.task_entities import WorkflowAppBlockingResponse, WorkflowAppStreamResponse
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig, PauseStatePersistenceLayer
|
||||
from core.db.session_factory import session_factory
|
||||
from core.helper.trace_id_helper import extract_external_trace_id_from_args
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError
|
||||
@@ -35,15 +34,12 @@ from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.repositories.workflow_execution_repository import WorkflowExecutionRepository
|
||||
from core.workflow.repositories.workflow_node_execution_repository import WorkflowNodeExecutionRepository
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader
|
||||
from extensions.ext_database import db
|
||||
from factories import file_factory
|
||||
from libs.flask_utils import preserve_flask_contexts
|
||||
from models.account import Account
|
||||
from models import Account, App, EndUser, Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import WorkflowRunTriggeredFrom
|
||||
from models.model import App, EndUser
|
||||
from models.workflow import Workflow, WorkflowNodeExecutionTriggeredFrom
|
||||
from services.workflow_draft_variable_service import DraftVarLoader, WorkflowDraftVariableService
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -70,11 +66,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[True],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Generator[Mapping[str, Any] | str, None, None]: ...
|
||||
|
||||
@overload
|
||||
@@ -88,11 +82,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: Literal[False],
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Mapping[str, Any]: ...
|
||||
|
||||
@overload
|
||||
@@ -106,11 +98,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool,
|
||||
call_depth: int,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]: ...
|
||||
|
||||
def generate(
|
||||
@@ -123,11 +113,9 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
invoke_from: InvokeFrom,
|
||||
streaming: bool = True,
|
||||
call_depth: int = 0,
|
||||
workflow_run_id: str | uuid.UUID | None = None,
|
||||
triggered_from: WorkflowRunTriggeredFrom | None = None,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[Mapping[str, Any] | str, None, None]]:
|
||||
files: Sequence[Mapping[str, Any]] = args.get("files") or []
|
||||
|
||||
@@ -159,10 +147,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
|
||||
inputs: Mapping[str, Any] = args["inputs"]
|
||||
|
||||
extras = {
|
||||
extras: dict[str, Any] = {
|
||||
**extract_external_trace_id_from_args(args),
|
||||
}
|
||||
workflow_run_id = str(workflow_run_id or uuid.uuid4())
|
||||
parent_trace_context = args.get("_parent_trace_context")
|
||||
if parent_trace_context:
|
||||
extras["parent_trace_context"] = parent_trace_context
|
||||
workflow_run_id = str(uuid.uuid4())
|
||||
# FIXME (Yeuoly): we need to remove the SKIP_PREPARE_USER_INPUTS_KEY from the args
|
||||
# trigger shouldn't prepare user inputs
|
||||
if self._should_prepare_user_inputs(args):
|
||||
@@ -228,40 +219,13 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
streaming=streaming,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
|
||||
def resume(
|
||||
self,
|
||||
*,
|
||||
app_model: App,
|
||||
workflow: Workflow,
|
||||
user: Union[Account, EndUser],
|
||||
application_generate_entity: WorkflowAppGenerateEntity,
|
||||
graph_runtime_state: GraphRuntimeState,
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
def resume(self, *, workflow_run_id: str) -> None:
|
||||
"""
|
||||
Resume a paused workflow execution using the persisted runtime state.
|
||||
@TBD
|
||||
"""
|
||||
return self._generate(
|
||||
app_model=app_model,
|
||||
workflow=workflow,
|
||||
user=user,
|
||||
application_generate_entity=application_generate_entity,
|
||||
invoke_from=application_generate_entity.invoke_from,
|
||||
workflow_execution_repository=workflow_execution_repository,
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=application_generate_entity.stream,
|
||||
variable_loader=variable_loader,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
pause_state_config=pause_state_config,
|
||||
)
|
||||
pass
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
@@ -277,8 +241,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
variable_loader: VariableLoader = DUMMY_VARIABLE_LOADER,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
pause_state_config: PauseStateLayerConfig | None = None,
|
||||
) -> Union[Mapping[str, Any], Generator[str | Mapping[str, Any], None, None]]:
|
||||
"""
|
||||
Generate App response.
|
||||
@@ -292,8 +254,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
:param workflow_node_execution_repository: repository for workflow node execution
|
||||
:param streaming: is stream
|
||||
"""
|
||||
graph_layers: list[GraphEngineLayer] = list(graph_engine_layers)
|
||||
|
||||
# init queue manager
|
||||
queue_manager = WorkflowAppQueueManager(
|
||||
task_id=application_generate_entity.task_id,
|
||||
@@ -302,15 +262,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
app_mode=app_model.mode,
|
||||
)
|
||||
|
||||
if pause_state_config is not None:
|
||||
graph_layers.append(
|
||||
PauseStatePersistenceLayer(
|
||||
session_factory=pause_state_config.session_factory,
|
||||
generate_entity=application_generate_entity,
|
||||
state_owner_user_id=pause_state_config.state_owner_user_id,
|
||||
)
|
||||
)
|
||||
|
||||
# new thread with request context and contextvars
|
||||
context = contextvars.copy_context()
|
||||
|
||||
@@ -328,8 +279,7 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
"root_node_id": root_node_id,
|
||||
"workflow_execution_repository": workflow_execution_repository,
|
||||
"workflow_node_execution_repository": workflow_node_execution_repository,
|
||||
"graph_engine_layers": tuple(graph_layers),
|
||||
"graph_runtime_state": graph_runtime_state,
|
||||
"graph_engine_layers": graph_engine_layers,
|
||||
},
|
||||
)
|
||||
|
||||
@@ -431,7 +381,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def single_loop_generate(
|
||||
@@ -513,7 +462,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
streaming=streaming,
|
||||
variable_loader=var_loader,
|
||||
pause_state_config=None,
|
||||
)
|
||||
|
||||
def _generate_worker(
|
||||
@@ -527,7 +475,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
root_node_id: str | None = None,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
) -> None:
|
||||
"""
|
||||
Generate worker in a new thread.
|
||||
@@ -573,7 +520,6 @@ class WorkflowAppGenerator(BaseAppGenerator):
|
||||
workflow_node_execution_repository=workflow_node_execution_repository,
|
||||
root_node_id=root_node_id,
|
||||
graph_engine_layers=graph_engine_layers,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
)
|
||||
|
||||
try:
|
||||
|
||||
@@ -42,7 +42,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
workflow_execution_repository: WorkflowExecutionRepository,
|
||||
workflow_node_execution_repository: WorkflowNodeExecutionRepository,
|
||||
graph_engine_layers: Sequence[GraphEngineLayer] = (),
|
||||
graph_runtime_state: GraphRuntimeState | None = None,
|
||||
):
|
||||
super().__init__(
|
||||
queue_manager=queue_manager,
|
||||
@@ -56,7 +55,6 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
self._root_node_id = root_node_id
|
||||
self._workflow_execution_repository = workflow_execution_repository
|
||||
self._workflow_node_execution_repository = workflow_node_execution_repository
|
||||
self._resume_graph_runtime_state = graph_runtime_state
|
||||
|
||||
@trace_span(WorkflowAppRunnerHandler)
|
||||
def run(self):
|
||||
@@ -65,28 +63,23 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
"""
|
||||
app_config = self.application_generate_entity.app_config
|
||||
app_config = cast(WorkflowAppConfig, app_config)
|
||||
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
invoke_from = self.application_generate_entity.invoke_from
|
||||
# if only single iteration or single loop run is requested
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
invoke_from = InvokeFrom.DEBUGGER
|
||||
user_from = self._resolve_user_from(invoke_from)
|
||||
|
||||
resume_state = self._resume_graph_runtime_state
|
||||
|
||||
if resume_state is not None:
|
||||
graph_runtime_state = resume_state
|
||||
variable_pool = graph_runtime_state.variable_pool
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
workflow_id=self._workflow.id,
|
||||
tenant_id=self._workflow.tenant_id,
|
||||
user_id=self.application_generate_entity.user_id,
|
||||
user_from=user_from,
|
||||
invoke_from=invoke_from,
|
||||
root_node_id=self._root_node_id,
|
||||
)
|
||||
elif self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
if self.application_generate_entity.single_iteration_run or self.application_generate_entity.single_loop_run:
|
||||
graph, variable_pool, graph_runtime_state = self._prepare_single_node_execution(
|
||||
workflow=self._workflow,
|
||||
single_iteration_run=self.application_generate_entity.single_iteration_run,
|
||||
@@ -96,14 +89,7 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
inputs = self.application_generate_entity.inputs
|
||||
|
||||
# Create a variable pool.
|
||||
system_inputs = SystemVariable(
|
||||
files=self.application_generate_entity.files,
|
||||
user_id=self._sys_user_id,
|
||||
app_id=app_config.app_id,
|
||||
timestamp=int(naive_utc_now().timestamp()),
|
||||
workflow_id=app_config.workflow_id,
|
||||
workflow_execution_id=self.application_generate_entity.workflow_execution_id,
|
||||
)
|
||||
|
||||
variable_pool = VariablePool(
|
||||
system_variables=system_inputs,
|
||||
user_inputs=inputs,
|
||||
@@ -112,6 +98,8 @@ class WorkflowAppRunner(WorkflowBasedAppRunner):
|
||||
)
|
||||
|
||||
graph_runtime_state = GraphRuntimeState(variable_pool=variable_pool, start_at=time.perf_counter())
|
||||
|
||||
# init graph
|
||||
graph = self._init_graph(
|
||||
graph_config=self._workflow.graph_dict,
|
||||
graph_runtime_state=graph_runtime_state,
|
||||
|
||||
@@ -1,7 +0,0 @@
|
||||
from libs.exception import BaseHTTPException
|
||||
|
||||
|
||||
class WorkflowPausedInBlockingModeError(BaseHTTPException):
|
||||
error_code = "workflow_paused_in_blocking_mode"
|
||||
description = "Workflow execution paused for human input; blocking response mode is not supported."
|
||||
code = 400
|
||||
@@ -16,8 +16,6 @@ from core.app.entities.queue_entities import (
|
||||
MessageQueueMessage,
|
||||
QueueAgentLogEvent,
|
||||
QueueErrorEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -34,7 +32,6 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
WorkflowQueueMessage,
|
||||
@@ -49,13 +46,11 @@ from core.app.entities.task_entities import (
|
||||
WorkflowAppBlockingResponse,
|
||||
WorkflowAppStreamResponse,
|
||||
WorkflowFinishStreamResponse,
|
||||
WorkflowPauseStreamResponse,
|
||||
WorkflowStartStreamResponse,
|
||||
)
|
||||
from core.app.task_pipeline.based_generate_task_pipeline import BasedGenerateTaskPipeline
|
||||
from core.base.tts import AppGeneratorTTSPublisher, AudioTrunk
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus
|
||||
from core.workflow.repositories.draft_variable_repository import DraftVariableSaverFactory
|
||||
from core.workflow.runtime import GraphRuntimeState
|
||||
@@ -137,25 +132,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
for stream_response in generator:
|
||||
if isinstance(stream_response, ErrorStreamResponse):
|
||||
raise stream_response.err
|
||||
elif isinstance(stream_response, WorkflowPauseStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=stream_response.data.workflow_run_id,
|
||||
data=WorkflowAppBlockingResponse.Data(
|
||||
id=stream_response.data.workflow_run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
status=stream_response.data.status,
|
||||
outputs=stream_response.data.outputs or {},
|
||||
error=None,
|
||||
elapsed_time=stream_response.data.elapsed_time,
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=stream_response.data.created_at,
|
||||
finished_at=None,
|
||||
),
|
||||
)
|
||||
|
||||
return response
|
||||
elif isinstance(stream_response, WorkflowFinishStreamResponse):
|
||||
response = WorkflowAppBlockingResponse(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
@@ -170,7 +146,7 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
total_tokens=stream_response.data.total_tokens,
|
||||
total_steps=stream_response.data.total_steps,
|
||||
created_at=int(stream_response.data.created_at),
|
||||
finished_at=int(stream_response.data.finished_at) if stream_response.data.finished_at else None,
|
||||
finished_at=int(stream_response.data.finished_at),
|
||||
),
|
||||
)
|
||||
|
||||
@@ -283,15 +259,13 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
run_id = self._extract_workflow_run_id(runtime_state)
|
||||
self._workflow_execution_id = run_id
|
||||
|
||||
if event.reason == WorkflowStartReason.INITIAL:
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
with self._database_session() as session:
|
||||
self._save_workflow_app_log(session=session, workflow_run_id=self._workflow_execution_id)
|
||||
|
||||
start_resp = self._workflow_response_converter.workflow_start_to_stream_response(
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
workflow_run_id=run_id,
|
||||
workflow_id=self._workflow.id,
|
||||
reason=event.reason,
|
||||
)
|
||||
yield start_resp
|
||||
|
||||
@@ -466,21 +440,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
)
|
||||
yield workflow_finish_resp
|
||||
|
||||
def _handle_workflow_paused_event(
|
||||
self,
|
||||
event: QueueWorkflowPausedEvent,
|
||||
**kwargs,
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle workflow paused events."""
|
||||
self._ensure_workflow_initialized()
|
||||
validated_state = self._ensure_graph_runtime_initialized()
|
||||
responses = self._workflow_response_converter.workflow_pause_to_stream_response(
|
||||
event=event,
|
||||
task_id=self._application_generate_entity.task_id,
|
||||
graph_runtime_state=validated_state,
|
||||
)
|
||||
yield from responses
|
||||
|
||||
def _handle_workflow_failed_and_stop_events(
|
||||
self,
|
||||
event: Union[QueueWorkflowFailedEvent, QueueStopEvent],
|
||||
@@ -536,22 +495,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
task_id=self._application_generate_entity.task_id, event=event
|
||||
)
|
||||
|
||||
def _handle_human_input_form_filled_event(
|
||||
self, event: QueueHumanInputFormFilledEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form filled events."""
|
||||
yield self._workflow_response_converter.human_input_form_filled_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _handle_human_input_form_timeout_event(
|
||||
self, event: QueueHumanInputFormTimeoutEvent, **kwargs
|
||||
) -> Generator[StreamResponse, None, None]:
|
||||
"""Handle human input form timeout events."""
|
||||
yield self._workflow_response_converter.human_input_form_timeout_to_stream_response(
|
||||
event=event, task_id=self._application_generate_entity.task_id
|
||||
)
|
||||
|
||||
def _get_event_handlers(self) -> dict[type, Callable]:
|
||||
"""Get mapping of event types to their handlers using fluent pattern."""
|
||||
return {
|
||||
@@ -563,7 +506,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueWorkflowStartedEvent: self._handle_workflow_started_event,
|
||||
QueueWorkflowSucceededEvent: self._handle_workflow_succeeded_event,
|
||||
QueueWorkflowPartialSuccessEvent: self._handle_workflow_partial_success_event,
|
||||
QueueWorkflowPausedEvent: self._handle_workflow_paused_event,
|
||||
# Node events
|
||||
QueueNodeRetryEvent: self._handle_node_retry_event,
|
||||
QueueNodeStartedEvent: self._handle_node_started_event,
|
||||
@@ -578,8 +520,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
QueueLoopCompletedEvent: self._handle_loop_completed_event,
|
||||
# Agent events
|
||||
QueueAgentLogEvent: self._handle_agent_log_event,
|
||||
QueueHumanInputFormFilledEvent: self._handle_human_input_form_filled_event,
|
||||
QueueHumanInputFormTimeoutEvent: self._handle_human_input_form_timeout_event,
|
||||
}
|
||||
|
||||
def _dispatch_event(
|
||||
@@ -662,9 +602,6 @@ class WorkflowAppGenerateTaskPipeline(GraphRuntimeStateSupport):
|
||||
case QueueWorkflowFailedEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
break
|
||||
case QueueWorkflowPausedEvent():
|
||||
yield from self._handle_workflow_paused_event(event)
|
||||
break
|
||||
|
||||
case QueueStopEvent():
|
||||
yield from self._handle_workflow_failed_and_stop_events(event)
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, cast
|
||||
@@ -8,8 +7,6 @@ from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.entities.queue_entities import (
|
||||
AppQueueEvent,
|
||||
QueueAgentLogEvent,
|
||||
QueueHumanInputFormFilledEvent,
|
||||
QueueHumanInputFormTimeoutEvent,
|
||||
QueueIterationCompletedEvent,
|
||||
QueueIterationNextEvent,
|
||||
QueueIterationStartEvent,
|
||||
@@ -25,27 +22,22 @@ from core.app.entities.queue_entities import (
|
||||
QueueTextChunkEvent,
|
||||
QueueWorkflowFailedEvent,
|
||||
QueueWorkflowPartialSuccessEvent,
|
||||
QueueWorkflowPausedEvent,
|
||||
QueueWorkflowStartedEvent,
|
||||
QueueWorkflowSucceededEvent,
|
||||
)
|
||||
from core.app.workflow.node_factory import DifyNodeFactory
|
||||
from core.workflow.entities import GraphInitParams
|
||||
from core.workflow.entities.pause_reason import HumanInputRequired
|
||||
from core.workflow.graph import Graph
|
||||
from core.workflow.graph_engine.layers.base import GraphEngineLayer
|
||||
from core.workflow.graph_events import (
|
||||
GraphEngineEvent,
|
||||
GraphRunFailedEvent,
|
||||
GraphRunPartialSucceededEvent,
|
||||
GraphRunPausedEvent,
|
||||
GraphRunStartedEvent,
|
||||
GraphRunSucceededEvent,
|
||||
NodeRunAgentLogEvent,
|
||||
NodeRunExceptionEvent,
|
||||
NodeRunFailedEvent,
|
||||
NodeRunHumanInputFormFilledEvent,
|
||||
NodeRunHumanInputFormTimeoutEvent,
|
||||
NodeRunIterationFailedEvent,
|
||||
NodeRunIterationNextEvent,
|
||||
NodeRunIterationStartedEvent,
|
||||
@@ -69,9 +61,6 @@ from core.workflow.variable_loader import DUMMY_VARIABLE_LOADER, VariableLoader,
|
||||
from core.workflow.workflow_entry import WorkflowEntry
|
||||
from models.enums import UserFrom
|
||||
from models.workflow import Workflow
|
||||
from tasks.mail_human_input_delivery_task import dispatch_human_input_email_task
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class WorkflowBasedAppRunner:
|
||||
@@ -338,7 +327,7 @@ class WorkflowBasedAppRunner:
|
||||
:param event: event
|
||||
"""
|
||||
if isinstance(event, GraphRunStartedEvent):
|
||||
self._publish_event(QueueWorkflowStartedEvent(reason=event.reason))
|
||||
self._publish_event(QueueWorkflowStartedEvent())
|
||||
elif isinstance(event, GraphRunSucceededEvent):
|
||||
self._publish_event(QueueWorkflowSucceededEvent(outputs=event.outputs))
|
||||
elif isinstance(event, GraphRunPartialSucceededEvent):
|
||||
@@ -349,38 +338,6 @@ class WorkflowBasedAppRunner:
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.error, exceptions_count=event.exceptions_count))
|
||||
elif isinstance(event, GraphRunAbortedEvent):
|
||||
self._publish_event(QueueWorkflowFailedEvent(error=event.reason or "Unknown error", exceptions_count=0))
|
||||
elif isinstance(event, GraphRunPausedEvent):
|
||||
runtime_state = workflow_entry.graph_engine.graph_runtime_state
|
||||
paused_nodes = runtime_state.get_paused_nodes()
|
||||
self._enqueue_human_input_notifications(event.reasons)
|
||||
self._publish_event(
|
||||
QueueWorkflowPausedEvent(
|
||||
reasons=event.reasons,
|
||||
outputs=event.outputs,
|
||||
paused_nodes=paused_nodes,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormFilledEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormFilledEvent(
|
||||
node_execution_id=event.id,
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
rendered_content=event.rendered_content,
|
||||
action_id=event.action_id,
|
||||
action_text=event.action_text,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunHumanInputFormTimeoutEvent):
|
||||
self._publish_event(
|
||||
QueueHumanInputFormTimeoutEvent(
|
||||
node_id=event.node_id,
|
||||
node_type=event.node_type,
|
||||
node_title=event.node_title,
|
||||
expiration_time=event.expiration_time,
|
||||
)
|
||||
)
|
||||
elif isinstance(event, NodeRunRetryEvent):
|
||||
node_run_result = event.node_run_result
|
||||
inputs = node_run_result.inputs
|
||||
@@ -587,19 +544,5 @@ class WorkflowBasedAppRunner:
|
||||
)
|
||||
)
|
||||
|
||||
def _enqueue_human_input_notifications(self, reasons: Sequence[object]) -> None:
|
||||
for reason in reasons:
|
||||
if not isinstance(reason, HumanInputRequired):
|
||||
continue
|
||||
if not reason.form_id:
|
||||
continue
|
||||
try:
|
||||
dispatch_human_input_email_task.apply_async(
|
||||
kwargs={"form_id": reason.form_id, "node_title": reason.node_title},
|
||||
queue="mail",
|
||||
)
|
||||
except Exception: # pragma: no cover - defensive logging
|
||||
logger.exception("Failed to enqueue human input email task for form %s", reason.form_id)
|
||||
|
||||
def _publish_event(self, event: AppQueueEvent):
|
||||
self._queue_manager.publish(event, PublishFrom.APPLICATION_MANAGER)
|
||||
|
||||
@@ -132,7 +132,7 @@ class AppGenerateEntity(BaseModel):
|
||||
extras: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# tracing instance
|
||||
trace_manager: Optional["TraceQueueManager"] = Field(default=None, exclude=True, repr=False)
|
||||
trace_manager: Optional["TraceQueueManager"] = None
|
||||
|
||||
|
||||
class EasyUIBasedAppGenerateEntity(AppGenerateEntity):
|
||||
@@ -156,7 +156,6 @@ class ConversationAppGenerateEntity(AppGenerateEntity):
|
||||
"""
|
||||
|
||||
conversation_id: str | None = None
|
||||
is_new_conversation: bool = False
|
||||
parent_message_id: str | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
|
||||
@@ -8,8 +8,6 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMResultChunk
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.pause_reason import PauseReason
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from core.workflow.nodes import NodeType
|
||||
|
||||
@@ -48,9 +46,6 @@ class QueueEvent(StrEnum):
|
||||
PING = "ping"
|
||||
STOP = "stop"
|
||||
RETRY = "retry"
|
||||
PAUSE = "pause"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class AppQueueEvent(BaseModel):
|
||||
@@ -266,8 +261,6 @@ class QueueWorkflowStartedEvent(AppQueueEvent):
|
||||
"""QueueWorkflowStartedEvent entity."""
|
||||
|
||||
event: QueueEvent = QueueEvent.WORKFLOW_STARTED
|
||||
# Always present; mirrors GraphRunStartedEvent.reason for downstream consumers.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
|
||||
class QueueWorkflowSucceededEvent(AppQueueEvent):
|
||||
@@ -491,35 +484,6 @@ class QueueStopEvent(AppQueueEvent):
|
||||
return reason_mapping.get(self.stopped_by, "Stopped by unknown reason.")
|
||||
|
||||
|
||||
class QueueHumanInputFormFilledEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormFilledEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_FILLED
|
||||
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class QueueHumanInputFormTimeoutEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueHumanInputFormTimeoutEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
|
||||
node_id: str
|
||||
node_type: NodeType
|
||||
node_title: str
|
||||
expiration_time: datetime
|
||||
|
||||
|
||||
class QueueMessage(BaseModel):
|
||||
"""
|
||||
QueueMessage abstract entity
|
||||
@@ -545,14 +509,3 @@ class WorkflowQueueMessage(QueueMessage):
|
||||
"""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class QueueWorkflowPausedEvent(AppQueueEvent):
|
||||
"""
|
||||
QueueWorkflowPausedEvent entity
|
||||
"""
|
||||
|
||||
event: QueueEvent = QueueEvent.PAUSE
|
||||
reasons: Sequence[PauseReason] = Field(default_factory=list)
|
||||
outputs: Mapping[str, object] = Field(default_factory=dict)
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
|
||||
@@ -7,9 +7,7 @@ from pydantic import BaseModel, ConfigDict, Field
|
||||
from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.rag.entities.citation_metadata import RetrievalSourceMetadata
|
||||
from core.workflow.entities import AgentNodeStrategyInit
|
||||
from core.workflow.entities.workflow_start_reason import WorkflowStartReason
|
||||
from core.workflow.enums import WorkflowExecutionStatus, WorkflowNodeExecutionMetadataKey, WorkflowNodeExecutionStatus
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
|
||||
|
||||
class AnnotationReplyAccount(BaseModel):
|
||||
@@ -71,7 +69,6 @@ class StreamEvent(StrEnum):
|
||||
AGENT_THOUGHT = "agent_thought"
|
||||
AGENT_MESSAGE = "agent_message"
|
||||
WORKFLOW_STARTED = "workflow_started"
|
||||
WORKFLOW_PAUSED = "workflow_paused"
|
||||
WORKFLOW_FINISHED = "workflow_finished"
|
||||
NODE_STARTED = "node_started"
|
||||
NODE_FINISHED = "node_finished"
|
||||
@@ -85,9 +82,6 @@ class StreamEvent(StrEnum):
|
||||
TEXT_CHUNK = "text_chunk"
|
||||
TEXT_REPLACE = "text_replace"
|
||||
AGENT_LOG = "agent_log"
|
||||
HUMAN_INPUT_REQUIRED = "human_input_required"
|
||||
HUMAN_INPUT_FORM_FILLED = "human_input_form_filled"
|
||||
HUMAN_INPUT_FORM_TIMEOUT = "human_input_form_timeout"
|
||||
|
||||
|
||||
class StreamResponse(BaseModel):
|
||||
@@ -211,8 +205,6 @@ class WorkflowStartStreamResponse(StreamResponse):
|
||||
workflow_id: str
|
||||
inputs: Mapping[str, Any]
|
||||
created_at: int
|
||||
# Always present; mirrors QueueWorkflowStartedEvent.reason for SSE clients.
|
||||
reason: WorkflowStartReason = WorkflowStartReason.INITIAL
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_STARTED
|
||||
workflow_run_id: str
|
||||
@@ -239,7 +231,7 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
total_steps: int
|
||||
created_by: Mapping[str, object] = Field(default_factory=dict)
|
||||
created_at: int
|
||||
finished_at: int | None
|
||||
finished_at: int
|
||||
exceptions_count: int | None = 0
|
||||
files: Sequence[Mapping[str, Any]] | None = []
|
||||
|
||||
@@ -248,85 +240,6 @@ class WorkflowFinishStreamResponse(StreamResponse):
|
||||
data: Data
|
||||
|
||||
|
||||
class WorkflowPauseStreamResponse(StreamResponse):
|
||||
"""
|
||||
WorkflowPauseStreamResponse entity
|
||||
"""
|
||||
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
workflow_run_id: str
|
||||
paused_nodes: Sequence[str] = Field(default_factory=list)
|
||||
outputs: Mapping[str, Any] = Field(default_factory=dict)
|
||||
reasons: Sequence[Mapping[str, Any]] = Field(default_factory=list)
|
||||
status: WorkflowExecutionStatus
|
||||
created_at: int
|
||||
elapsed_time: float
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
|
||||
event: StreamEvent = StreamEvent.WORKFLOW_PAUSED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputRequiredResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int = Field(..., description="Unix timestamp in seconds")
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_REQUIRED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormFilledResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_FILLED
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class HumanInputFormTimeoutResponse(StreamResponse):
|
||||
class Data(BaseModel):
|
||||
"""
|
||||
Data entity
|
||||
"""
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
expiration_time: int
|
||||
|
||||
event: StreamEvent = StreamEvent.HUMAN_INPUT_FORM_TIMEOUT
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
|
||||
class NodeStartStreamResponse(StreamResponse):
|
||||
"""
|
||||
NodeStartStreamResponse entity
|
||||
@@ -813,7 +726,7 @@ class WorkflowAppBlockingResponse(AppBlockingResponse):
|
||||
total_tokens: int
|
||||
total_steps: int
|
||||
created_at: int
|
||||
finished_at: int | None
|
||||
finished_at: int
|
||||
|
||||
workflow_run_id: str
|
||||
data: Data
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import contextlib
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
@@ -104,14 +103,6 @@ class RateLimit:
|
||||
)
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def rate_limit_context(rate_limit: RateLimit, request_id: str | None):
|
||||
request_id = rate_limit.enter(request_id)
|
||||
yield
|
||||
if request_id is not None:
|
||||
rate_limit.exit(request_id)
|
||||
|
||||
|
||||
class RateLimitGenerator:
|
||||
def __init__(self, rate_limit: RateLimit, generator: Generator[str, None, None], request_id: str):
|
||||
self.rate_limit = rate_limit
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Annotated, Literal, Self, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
@@ -53,14 +52,6 @@ class WorkflowResumptionContext(BaseModel):
|
||||
return self.generate_entity.entity
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class PauseStateLayerConfig:
|
||||
"""Configuration container for instantiating pause persistence layers."""
|
||||
|
||||
session_factory: Engine | sessionmaker[Session]
|
||||
state_owner_user_id: str
|
||||
|
||||
|
||||
class PauseStatePersistenceLayer(GraphEngineLayer):
|
||||
def __init__(
|
||||
self,
|
||||
|
||||
@@ -54,10 +54,11 @@ from core.model_runtime.entities.message_entities import (
|
||||
TextPromptMessageContent,
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.prompt.utils.prompt_message_util import PromptMessageUtil
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.signature import sign_tool_file
|
||||
from events.message_event import message_was_created
|
||||
from extensions.ext_database import db
|
||||
@@ -412,10 +413,19 @@ class EasyUIBasedGenerateTaskPipeline(BasedGenerateTaskPipeline):
|
||||
message.message_metadata = self._task_state.metadata.model_dump_json()
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MESSAGE_TRACE, conversation_id=self._conversation_id, message_id=self._message_id
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"conversation_id": self._conversation_id,
|
||||
"message_id": self._message_id,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
message_was_created.send(
|
||||
|
||||
@@ -88,11 +88,10 @@ class MessageCycleManager:
|
||||
if isinstance(self._application_generate_entity, CompletionAppGenerateEntity):
|
||||
return None
|
||||
|
||||
is_first_message = self._application_generate_entity.is_new_conversation
|
||||
is_first_message = self._application_generate_entity.conversation_id is None
|
||||
extras = self._application_generate_entity.extras
|
||||
auto_generate_conversation_name = extras.get("auto_generate_conversation_name", True)
|
||||
|
||||
thread: Thread | None = None
|
||||
if auto_generate_conversation_name and is_first_message:
|
||||
# start generate thread
|
||||
# time.sleep not block other logic
|
||||
@@ -108,10 +107,9 @@ class MessageCycleManager:
|
||||
thread.daemon = True
|
||||
thread.start()
|
||||
|
||||
if is_first_message:
|
||||
self._application_generate_entity.is_new_conversation = False
|
||||
return thread
|
||||
|
||||
return thread
|
||||
return None
|
||||
|
||||
def _generate_conversation_name_worker(self, flask_app: Flask, conversation_id: str, query: str):
|
||||
with flask_app.app_context():
|
||||
|
||||
@@ -15,8 +15,7 @@ from datetime import datetime
|
||||
from typing import Any, Union
|
||||
|
||||
from core.app.entities.app_invoke_entities import AdvancedChatAppGenerateEntity, WorkflowAppGenerateEntity
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.workflow.constants import SYSTEM_VARIABLE_NODE_ID
|
||||
from core.workflow.entities import WorkflowExecution, WorkflowNodeExecution
|
||||
from core.workflow.enums import (
|
||||
@@ -373,6 +372,7 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
self._workflow_node_execution_repository.save(domain_execution)
|
||||
self._workflow_node_execution_repository.save_execution_data(domain_execution)
|
||||
self._enqueue_node_trace_task(domain_execution)
|
||||
|
||||
def _fail_running_node_executions(self, *, error_message: str) -> None:
|
||||
now = naive_utc_now()
|
||||
@@ -390,17 +390,138 @@ class WorkflowPersistenceLayer(GraphEngineLayer):
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
external_trace_id = None
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
external_trace_id = self._application_generate_entity.extras.get("external_trace_id")
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
trace_task = TraceTask(
|
||||
TraceTaskName.WORKFLOW_TRACE,
|
||||
workflow_execution=execution,
|
||||
conversation_id=conversation_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
external_trace_id=external_trace_id,
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self._application_generate_entity.app_config.tenant_id,
|
||||
user_id=self._trace_manager.user_id,
|
||||
app_id=self._application_generate_entity.app_config.app_id,
|
||||
),
|
||||
payload={
|
||||
"workflow_execution": execution,
|
||||
"conversation_id": conversation_id,
|
||||
"user_id": self._trace_manager.user_id,
|
||||
"external_trace_id": external_trace_id,
|
||||
"parent_trace_context": parent_trace_context,
|
||||
},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
)
|
||||
|
||||
def _enqueue_node_trace_task(self, domain_execution: WorkflowNodeExecution) -> None:
|
||||
if not self._trace_manager:
|
||||
return
|
||||
|
||||
execution = self._get_workflow_execution()
|
||||
meta = domain_execution.metadata or {}
|
||||
|
||||
parent_trace_context = None
|
||||
if isinstance(self._application_generate_entity, (WorkflowAppGenerateEntity, AdvancedChatAppGenerateEntity)):
|
||||
parent_trace_context = self._application_generate_entity.extras.get("parent_trace_context")
|
||||
|
||||
node_data: dict[str, Any] = {
|
||||
"workflow_id": domain_execution.workflow_id,
|
||||
"workflow_execution_id": execution.id_,
|
||||
"tenant_id": self._application_generate_entity.app_config.tenant_id,
|
||||
"app_id": self._application_generate_entity.app_config.app_id,
|
||||
"node_execution_id": domain_execution.id,
|
||||
"node_id": domain_execution.node_id,
|
||||
"node_type": str(domain_execution.node_type.value),
|
||||
"title": domain_execution.title,
|
||||
"status": str(domain_execution.status.value),
|
||||
"error": domain_execution.error,
|
||||
"elapsed_time": domain_execution.elapsed_time,
|
||||
"index": domain_execution.index,
|
||||
"predecessor_node_id": domain_execution.predecessor_node_id,
|
||||
"created_at": domain_execution.created_at,
|
||||
"finished_at": domain_execution.finished_at,
|
||||
"total_tokens": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_TOKENS, 0),
|
||||
"prompt_tokens": meta.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS),
|
||||
"completion_tokens": meta.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS),
|
||||
"total_price": meta.get(WorkflowNodeExecutionMetadataKey.TOTAL_PRICE, 0.0),
|
||||
"currency": meta.get(WorkflowNodeExecutionMetadataKey.CURRENCY),
|
||||
"tool_name": (meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO) or {}).get("tool_name")
|
||||
if isinstance(meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO), dict)
|
||||
else None,
|
||||
"iteration_id": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_ID),
|
||||
"iteration_index": meta.get(WorkflowNodeExecutionMetadataKey.ITERATION_INDEX),
|
||||
"loop_id": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_ID),
|
||||
"loop_index": meta.get(WorkflowNodeExecutionMetadataKey.LOOP_INDEX),
|
||||
"parallel_id": meta.get(WorkflowNodeExecutionMetadataKey.PARALLEL_ID),
|
||||
"node_inputs": dict(domain_execution.inputs) if domain_execution.inputs else None,
|
||||
"node_outputs": dict(domain_execution.outputs) if domain_execution.outputs else None,
|
||||
"process_data": dict(domain_execution.process_data) if domain_execution.process_data else None,
|
||||
}
|
||||
node_data["invoke_from"] = self._application_generate_entity.invoke_from.value
|
||||
node_data["user_id"] = self._system_variables().get(SystemVariableKey.USER_ID.value)
|
||||
|
||||
# Extract model info from process_data — LLM nodes store provider/model there,
|
||||
if domain_execution.process_data:
|
||||
if mp := domain_execution.process_data.get("model_provider"):
|
||||
node_data["model_provider"] = mp
|
||||
if mn := domain_execution.process_data.get("model_name"):
|
||||
node_data["model_name"] = mn
|
||||
|
||||
if domain_execution.node_type.value == "knowledge-retrieval" and domain_execution.outputs:
|
||||
results = domain_execution.outputs.get("result") or []
|
||||
dataset_ids: list[str] = []
|
||||
dataset_names: list[str] = []
|
||||
for doc in results:
|
||||
if not isinstance(doc, dict):
|
||||
continue
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
dname = doc_meta.get("dataset_name")
|
||||
if did and did not in dataset_ids:
|
||||
dataset_ids.append(did)
|
||||
if dname and dname not in dataset_names:
|
||||
dataset_names.append(dname)
|
||||
if dataset_ids:
|
||||
node_data["dataset_ids"] = dataset_ids
|
||||
if dataset_names:
|
||||
node_data["dataset_names"] = dataset_names
|
||||
|
||||
tool_info = meta.get(WorkflowNodeExecutionMetadataKey.TOOL_INFO)
|
||||
if isinstance(tool_info, dict):
|
||||
plugin_id = tool_info.get("plugin_unique_identifier")
|
||||
if plugin_id:
|
||||
node_data["plugin_name"] = plugin_id
|
||||
credential_id = tool_info.get("credential_id")
|
||||
if credential_id:
|
||||
node_data["credential_id"] = credential_id
|
||||
node_data["credential_provider_type"] = tool_info.get("provider_type")
|
||||
|
||||
conversation_id = self._system_variables().get(SystemVariableKey.CONVERSATION_ID.value)
|
||||
if conversation_id:
|
||||
node_data["conversation_id"] = conversation_id
|
||||
|
||||
if parent_trace_context:
|
||||
node_data["parent_trace_context"] = parent_trace_context
|
||||
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.NODE_EXECUTION_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=node_data.get("tenant_id"),
|
||||
user_id=node_data.get("user_id"),
|
||||
app_id=node_data.get("app_id"),
|
||||
),
|
||||
payload={"node_execution_data": node_data},
|
||||
),
|
||||
trace_manager=self._trace_manager,
|
||||
)
|
||||
self._trace_manager.add_trace_task(trace_task)
|
||||
|
||||
def _system_variables(self) -> Mapping[str, Any]:
|
||||
runtime_state = self.graph_runtime_state
|
||||
|
||||
@@ -8,7 +8,6 @@ from core.file.file_manager import file_manager
|
||||
from core.helper.code_executor.code_executor import CodeExecutor
|
||||
from core.helper.code_executor.code_node_provider import CodeNodeProvider
|
||||
from core.helper.ssrf_proxy import ssrf_proxy
|
||||
from core.rag.retrieval.dataset_retrieval import DatasetRetrieval
|
||||
from core.tools.tool_file_manager import ToolFileManager
|
||||
from core.workflow.entities.graph_config import NodeConfigDict
|
||||
from core.workflow.enums import NodeType
|
||||
@@ -17,7 +16,6 @@ from core.workflow.nodes.base.node import Node
|
||||
from core.workflow.nodes.code.code_node import CodeNode
|
||||
from core.workflow.nodes.code.limits import CodeNodeLimits
|
||||
from core.workflow.nodes.http_request.node import HttpRequestNode
|
||||
from core.workflow.nodes.knowledge_retrieval.knowledge_retrieval_node import KnowledgeRetrievalNode
|
||||
from core.workflow.nodes.node_mapping import LATEST_VERSION, NODE_TYPE_CLASSES_MAPPING
|
||||
from core.workflow.nodes.protocols import FileManagerProtocol, HttpClientProtocol
|
||||
from core.workflow.nodes.template_transform.template_renderer import (
|
||||
@@ -49,7 +47,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
code_providers: Sequence[type[CodeNodeProvider]] | None = None,
|
||||
code_limits: CodeNodeLimits | None = None,
|
||||
template_renderer: Jinja2TemplateRenderer | None = None,
|
||||
template_transform_max_output_length: int | None = None,
|
||||
http_request_http_client: HttpClientProtocol | None = None,
|
||||
http_request_tool_file_manager_factory: Callable[[], ToolFileManager] = ToolFileManager,
|
||||
http_request_file_manager: FileManagerProtocol | None = None,
|
||||
@@ -71,13 +68,9 @@ class DifyNodeFactory(NodeFactory):
|
||||
max_object_array_length=dify_config.CODE_MAX_OBJECT_ARRAY_LENGTH,
|
||||
)
|
||||
self._template_renderer = template_renderer or CodeExecutorJinja2TemplateRenderer()
|
||||
self._template_transform_max_output_length = (
|
||||
template_transform_max_output_length or dify_config.TEMPLATE_TRANSFORM_MAX_LENGTH
|
||||
)
|
||||
self._http_request_http_client = http_request_http_client or ssrf_proxy
|
||||
self._http_request_tool_file_manager_factory = http_request_tool_file_manager_factory
|
||||
self._http_request_file_manager = http_request_file_manager or file_manager
|
||||
self._rag_retrieval = DatasetRetrieval()
|
||||
|
||||
@override
|
||||
def create_node(self, node_config: NodeConfigDict) -> Node:
|
||||
@@ -129,7 +122,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
template_renderer=self._template_renderer,
|
||||
max_output_length=self._template_transform_max_output_length,
|
||||
)
|
||||
|
||||
if node_type == NodeType.HTTP_REQUEST:
|
||||
@@ -143,15 +135,6 @@ class DifyNodeFactory(NodeFactory):
|
||||
file_manager=self._http_request_file_manager,
|
||||
)
|
||||
|
||||
if node_type == NodeType.KNOWLEDGE_RETRIEVAL:
|
||||
return KnowledgeRetrievalNode(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
graph_init_params=self.graph_init_params,
|
||||
graph_runtime_state=self.graph_runtime_state,
|
||||
rag_retrieval=self._rag_retrieval,
|
||||
)
|
||||
|
||||
return node_class(
|
||||
id=node_id,
|
||||
config=node_config,
|
||||
|
||||
@@ -4,8 +4,9 @@ from typing import Any, TextIO, Union
|
||||
from pydantic import BaseModel
|
||||
|
||||
from configs import dify_config
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.entities.tool_entities import ToolInvokeMessage
|
||||
|
||||
_TEXT_COLOR_MAPPING = {
|
||||
@@ -36,13 +37,15 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
|
||||
color: str | None = ""
|
||||
current_loop: int = 1
|
||||
tenant_id: str | None = None
|
||||
|
||||
def __init__(self, color: str | None = None):
|
||||
def __init__(self, color: str | None = None, tenant_id: str | None = None):
|
||||
super().__init__()
|
||||
"""Initialize callback handler."""
|
||||
# use a specific color is not specified
|
||||
self.color = color or "green"
|
||||
self.current_loop = 1
|
||||
self.tenant_id = tenant_id
|
||||
|
||||
def on_tool_start(
|
||||
self,
|
||||
@@ -71,15 +74,23 @@ class DifyAgentCallbackHandler(BaseModel):
|
||||
print_text("\n")
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.TOOL_TRACE,
|
||||
message_id=message_id,
|
||||
tool_name=tool_name,
|
||||
tool_inputs=tool_inputs,
|
||||
tool_outputs=tool_outputs,
|
||||
timer=timer,
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.TOOL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=self.tenant_id,
|
||||
app_id=trace_manager.app_id,
|
||||
user_id=trace_manager.user_id,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"tool_name": tool_name,
|
||||
"tool_inputs": tool_inputs,
|
||||
"tool_outputs": tool_outputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def on_tool_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any):
|
||||
|
||||
@@ -1,54 +0,0 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Mapping, Sequence
|
||||
from typing import Any, TypeAlias
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from core.workflow.nodes.human_input.entities import FormInput, UserAction
|
||||
from models.execution_extra_content import ExecutionContentType
|
||||
|
||||
|
||||
class HumanInputFormDefinition(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
form_id: str
|
||||
node_id: str
|
||||
node_title: str
|
||||
form_content: str
|
||||
inputs: Sequence[FormInput] = Field(default_factory=list)
|
||||
actions: Sequence[UserAction] = Field(default_factory=list)
|
||||
display_in_ui: bool = False
|
||||
form_token: str | None = None
|
||||
resolved_default_values: Mapping[str, Any] = Field(default_factory=dict)
|
||||
expiration_time: int
|
||||
|
||||
|
||||
class HumanInputFormSubmissionData(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
node_id: str
|
||||
node_title: str
|
||||
rendered_content: str
|
||||
action_id: str
|
||||
action_text: str
|
||||
|
||||
|
||||
class HumanInputContent(BaseModel):
|
||||
model_config = ConfigDict(frozen=True)
|
||||
|
||||
workflow_run_id: str
|
||||
submitted: bool
|
||||
form_definition: HumanInputFormDefinition | None = None
|
||||
form_submission_data: HumanInputFormSubmissionData | None = None
|
||||
type: ExecutionContentType = Field(default=ExecutionContentType.HUMAN_INPUT)
|
||||
|
||||
|
||||
ExecutionExtraContentDomainModel: TypeAlias = HumanInputContent
|
||||
|
||||
__all__ = [
|
||||
"ExecutionExtraContentDomainModel",
|
||||
"HumanInputContent",
|
||||
"HumanInputFormDefinition",
|
||||
"HumanInputFormSubmissionData",
|
||||
]
|
||||
@@ -28,8 +28,8 @@ from core.model_runtime.entities.provider_entities import (
|
||||
)
|
||||
from core.model_runtime.model_providers.__base.ai_model import AIModel
|
||||
from core.model_runtime.model_providers.model_provider_factory import ModelProviderFactory
|
||||
from extensions.ext_database import db
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from models.engine import db
|
||||
from models.provider import (
|
||||
LoadBalancingModelConfig,
|
||||
Provider,
|
||||
|
||||
@@ -6,8 +6,7 @@ from yarl import URL
|
||||
|
||||
from configs import dify_config
|
||||
from core.helper.download import download_with_size_limit
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration, MarketplacePluginSnapshot
|
||||
from extensions.ext_redis import redis_client
|
||||
from core.plugin.entities.marketplace import MarketplacePluginDeclaration
|
||||
|
||||
marketplace_api_url = URL(str(dify_config.MARKETPLACE_API_URL))
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -44,37 +43,28 @@ def batch_fetch_plugin_by_ids(plugin_ids: list[str]) -> list[dict]:
|
||||
return data.get("data", {}).get("plugins", [])
|
||||
|
||||
|
||||
def batch_fetch_plugin_manifests_ignore_deserialization_error(
|
||||
plugin_ids: list[str],
|
||||
) -> Sequence[MarketplacePluginDeclaration]:
|
||||
if len(plugin_ids) == 0:
|
||||
return []
|
||||
|
||||
url = str(marketplace_api_url / "api/v1/plugins/batch")
|
||||
response = httpx.post(url, json={"plugin_ids": plugin_ids}, headers={"X-Dify-Version": dify_config.project.version})
|
||||
response.raise_for_status()
|
||||
result: list[MarketplacePluginDeclaration] = []
|
||||
for plugin in response.json()["data"]["plugins"]:
|
||||
try:
|
||||
result.append(MarketplacePluginDeclaration.model_validate(plugin))
|
||||
except Exception:
|
||||
logger.exception(
|
||||
"Failed to deserialize marketplace plugin manifest for %s", plugin.get("plugin_id", "unknown")
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def record_install_plugin_event(plugin_unique_identifier: str):
|
||||
url = str(marketplace_api_url / "api/v1/stats/plugins/install_count")
|
||||
response = httpx.post(url, json={"unique_identifier": plugin_unique_identifier})
|
||||
response.raise_for_status()
|
||||
|
||||
|
||||
def fetch_global_plugin_manifest(cache_key_prefix: str, cache_ttl: int) -> None:
|
||||
"""
|
||||
Fetch all plugin manifests from marketplace and cache them in Redis.
|
||||
This should be called once per check cycle to populate the instance-level cache.
|
||||
|
||||
Args:
|
||||
cache_key_prefix: Redis key prefix for caching plugin manifests
|
||||
cache_ttl: Cache TTL in seconds
|
||||
|
||||
Raises:
|
||||
httpx.HTTPError: If the HTTP request fails
|
||||
Exception: If any other error occurs during fetching or caching
|
||||
"""
|
||||
url = str(marketplace_api_url / "api/v1/dist/plugins/manifest.json")
|
||||
response = httpx.get(url, headers={"X-Dify-Version": dify_config.project.version}, timeout=30)
|
||||
response.raise_for_status()
|
||||
|
||||
raw_json = response.json()
|
||||
plugins_data = raw_json.get("plugins", [])
|
||||
|
||||
# Parse and cache all plugin snapshots
|
||||
for plugin_data in plugins_data:
|
||||
plugin_snapshot = MarketplacePluginSnapshot.model_validate(plugin_data)
|
||||
redis_client.setex(
|
||||
name=f"{cache_key_prefix}{plugin_snapshot.plugin_id}",
|
||||
time=cache_ttl,
|
||||
value=plugin_snapshot.model_dump_json(),
|
||||
)
|
||||
|
||||
@@ -9,6 +9,7 @@ class RuleGeneratePayload(BaseModel):
|
||||
instruction: str = Field(..., description="Rule generation instruction")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
no_variable: bool = Field(default=False, description="Whether to exclude variables")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
|
||||
class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
@@ -18,3 +19,4 @@ class RuleCodeGeneratePayload(RuleGeneratePayload):
|
||||
class RuleStructuredOutputPayload(BaseModel):
|
||||
instruction: str = Field(..., description="Structured output generation instruction")
|
||||
model_config_data: ModelConfig = Field(..., alias="model_config", description="Model configuration")
|
||||
app_id: str | None = Field(default=None, description="App ID for prompt generation tracing")
|
||||
|
||||
@@ -7,7 +7,6 @@ from typing import Protocol, cast
|
||||
import json_repair
|
||||
|
||||
from core.app.app_config.entities import ModelConfig
|
||||
from core.llm_generator.entities import RuleCodeGeneratePayload, RuleGeneratePayload, RuleStructuredOutputPayload
|
||||
from core.llm_generator.output_parser.rule_config_generator import RuleConfigGeneratorOutputParser
|
||||
from core.llm_generator.output_parser.suggested_questions_after_answer import SuggestedQuestionsAfterAnswerOutputParser
|
||||
from core.llm_generator.prompts import (
|
||||
@@ -27,10 +26,11 @@ from core.model_runtime.entities.llm_entities import LLMResult
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, SystemPromptMessage, UserPromptMessage
|
||||
from core.model_runtime.entities.model_entities import ModelType
|
||||
from core.model_runtime.errors.invoke import InvokeAuthorizationError, InvokeError
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.entities.trace_entity import OperationType
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.utils.prompt_template_parser import PromptTemplateParser
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.workflow.entities.workflow_node_execution import WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
@@ -74,7 +74,7 @@ class LLMGenerator:
|
||||
prompt_messages=list(prompts), model_parameters={"max_tokens": 500, "temperature": 1}, stream=False
|
||||
)
|
||||
answer = response.message.get_text_content()
|
||||
if answer == "":
|
||||
if answer is None:
|
||||
return ""
|
||||
try:
|
||||
result_dict = json.loads(answer)
|
||||
@@ -96,15 +96,17 @@ class LLMGenerator:
|
||||
name = name[:75] + "..."
|
||||
|
||||
# get tracing instance
|
||||
trace_manager = TraceQueueManager(app_id=app_id)
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.GENERATE_NAME_TRACE,
|
||||
conversation_id=conversation_id,
|
||||
generate_conversation_name=name,
|
||||
inputs=prompt,
|
||||
timer=timer,
|
||||
tenant_id=tenant_id,
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.GENERATE_NAME_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"conversation_id": conversation_id,
|
||||
"generate_conversation_name": name,
|
||||
"inputs": prompt,
|
||||
"timer": timer,
|
||||
"tenant_id": tenant_id,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -153,19 +155,27 @@ class LLMGenerator:
|
||||
return questions
|
||||
|
||||
@classmethod
|
||||
def generate_rule_config(cls, tenant_id: str, args: RuleGeneratePayload):
|
||||
def generate_rule_config(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
no_variable: bool,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
output_parser = RuleConfigGeneratorOutputParser()
|
||||
|
||||
error = ""
|
||||
error_step = ""
|
||||
rule_config = {"prompt": "", "variables": [], "opening_statement": "", "error": ""}
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
if args.no_variable:
|
||||
model_parameters = model_config.completion_params
|
||||
if no_variable:
|
||||
prompt_template = PromptTemplateParser(WORKFLOW_RULE_CONFIG_PROMPT_GENERATE_TEMPLATE)
|
||||
|
||||
prompt_generate = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -177,26 +187,45 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
rule_config["prompt"] = response.message.get_text_content()
|
||||
rule_config["prompt"] = llm_result.message.get_text_content() or ""
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate rule config"
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
prompt_value = rule_config.get("prompt", "")
|
||||
generated_output = str(prompt_value) if prompt_value else ""
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
# get rule config prompt, parameter and statement
|
||||
@@ -211,7 +240,7 @@ class LLMGenerator:
|
||||
# format the prompt_generate_prompt
|
||||
prompt_generate_prompt = prompt_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -222,84 +251,125 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
try:
|
||||
llm_result = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
try:
|
||||
# the first step to generate the task prompt
|
||||
prompt_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
llm_result = prompt_content
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output="",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = prompt_content.message.get_text_content() or ""
|
||||
|
||||
if not isinstance(prompt_content.message.content, str):
|
||||
raise NotImplementedError("prompt content is not a string")
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate prefix prompt"
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
return rule_config
|
||||
|
||||
rule_config["prompt"] = prompt_content.message.get_text_content()
|
||||
|
||||
parameter_generate_prompt = parameter_template.format(
|
||||
inputs={
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
parameter_messages = [UserPromptMessage(content=parameter_generate_prompt)]
|
||||
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": args.instruction,
|
||||
"INPUT_TEXT": prompt_content.message.get_text_content(),
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
# the second step to generate the task_parameter and task_statement
|
||||
statement_generate_prompt = statement_template.format(
|
||||
inputs={
|
||||
"TASK_DESCRIPTION": instruction,
|
||||
"INPUT_TEXT": prompt_content.message.content,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
rule_config["variables"] = re.findall(r'"\s*([^"]+)\s*"', parameter_content.message.get_text_content())
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
statement_messages = [UserPromptMessage(content=statement_generate_prompt)]
|
||||
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content()
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
try:
|
||||
parameter_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(parameter_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["variables"] = re.findall(
|
||||
r'"\s*([^"]+)\s*"', prompt_content.message.get_text_content() or ""
|
||||
)
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate variables"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", args.model_config_data.name)
|
||||
rule_config["error"] = str(e)
|
||||
try:
|
||||
statement_content: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(statement_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
rule_config["opening_statement"] = statement_content.message.get_text_content() or ""
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
error_step = "generate conversation opener"
|
||||
|
||||
except Exception as e:
|
||||
logger.exception("Failed to generate rule config, model: %s", model_config.name)
|
||||
rule_config["error"] = str(e)
|
||||
error = str(e)
|
||||
|
||||
rule_config["error"] = f"Failed to {error_step}. Error: {error}" if error else ""
|
||||
|
||||
if user_id:
|
||||
generated_output = rule_config.get("prompt", "")
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.RULE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=str(generated_output) if generated_output else "",
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error or None,
|
||||
)
|
||||
|
||||
return rule_config
|
||||
|
||||
@classmethod
|
||||
def generate_code(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
args: RuleCodeGeneratePayload,
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
code_language: str = "javascript",
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
if args.code_language == "python":
|
||||
if code_language == "python":
|
||||
prompt_template = PromptTemplateParser(PYTHON_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
else:
|
||||
prompt_template = PromptTemplateParser(JAVASCRIPT_CODE_GENERATOR_PROMPT_TEMPLATE)
|
||||
|
||||
prompt = prompt_template.format(
|
||||
inputs={
|
||||
"INSTRUCTION": args.instruction,
|
||||
"CODE_LANGUAGE": args.code_language,
|
||||
"INSTRUCTION": instruction,
|
||||
"CODE_LANGUAGE": code_language,
|
||||
},
|
||||
remove_template_variables=False,
|
||||
)
|
||||
@@ -308,28 +378,49 @@ class LLMGenerator:
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
prompt_messages = [UserPromptMessage(content=prompt)]
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
model_parameters = model_config.completion_params
|
||||
|
||||
llm_result = None
|
||||
error = None
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_code = llm_result.message.get_text_content() or ""
|
||||
result = {"code": generated_code, "language": code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", model_config.name, code_language
|
||||
)
|
||||
error = str(e)
|
||||
result = {"code": "", "language": code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.CODE_GENERATE,
|
||||
instruction=instruction,
|
||||
generated_output=result.get("code", ""),
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
generated_code = response.message.get_text_content()
|
||||
return {"code": generated_code, "language": args.code_language, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"code": "", "language": args.code_language, "error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception(
|
||||
"Failed to invoke LLM model, model: %s, language: %s", args.model_config_data.name, args.code_language
|
||||
)
|
||||
return {"code": "", "language": args.code_language, "error": f"An unexpected error occurred: {str(e)}"}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def generate_qa_document(cls, tenant_id: str, query, document_language: str):
|
||||
@@ -355,49 +446,81 @@ class LLMGenerator:
|
||||
raise TypeError("Expected LLMResult when stream=False")
|
||||
response = result
|
||||
|
||||
answer = response.message.get_text_content()
|
||||
answer = response.message.get_text_content() or ""
|
||||
return answer.strip()
|
||||
|
||||
@classmethod
|
||||
def generate_structured_output(cls, tenant_id: str, args: RuleStructuredOutputPayload):
|
||||
def generate_structured_output(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=args.model_config_data.provider,
|
||||
model=args.model_config_data.name,
|
||||
provider=model_config.provider,
|
||||
model=model_config.name,
|
||||
)
|
||||
|
||||
prompt_messages = [
|
||||
SystemPromptMessage(content=SYSTEM_STRUCTURED_OUTPUT_GENERATE),
|
||||
UserPromptMessage(content=args.instruction),
|
||||
UserPromptMessage(content=instruction),
|
||||
]
|
||||
model_parameters = args.model_config_data.completion_params
|
||||
model_parameters = model_config.completion_params
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
llm_result = None
|
||||
error = None
|
||||
result = {"output": "", "error": ""}
|
||||
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
raw_content = llm_result.message.content
|
||||
|
||||
if not isinstance(raw_content, str):
|
||||
raise ValueError(f"LLM response content must be a string, got: {type(raw_content)}")
|
||||
|
||||
try:
|
||||
parsed_content = json.loads(raw_content)
|
||||
except json.JSONDecodeError:
|
||||
parsed_content = json_repair.loads(raw_content)
|
||||
|
||||
if not isinstance(parsed_content, dict | list):
|
||||
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
|
||||
|
||||
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
|
||||
result = {"output": generated_json_schema, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", model_config.name)
|
||||
error = str(e)
|
||||
result = {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
cls._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.STRUCTURED_OUTPUT,
|
||||
instruction=instruction,
|
||||
generated_output=result.get("output", ""),
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
raw_content = response.message.get_text_content()
|
||||
|
||||
try:
|
||||
parsed_content = json.loads(raw_content)
|
||||
except json.JSONDecodeError:
|
||||
parsed_content = json_repair.loads(raw_content)
|
||||
|
||||
if not isinstance(parsed_content, dict | list):
|
||||
raise ValueError(f"Failed to parse structured output from llm: {raw_content}")
|
||||
|
||||
generated_json_schema = json.dumps(parsed_content, indent=2, ensure_ascii=False)
|
||||
return {"output": generated_json_schema, "error": ""}
|
||||
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"output": "", "error": f"Failed to generate JSON Schema. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", args.model_config_data.name)
|
||||
return {"output": "", "error": f"An unexpected error occurred: {str(e)}"}
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_legacy(
|
||||
@@ -407,12 +530,14 @@ class LLMGenerator:
|
||||
instruction: str,
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
last_run: Message | None = (
|
||||
db.session.query(Message).where(Message.app_id == flow_id).order_by(Message.created_at.desc()).first()
|
||||
)
|
||||
if not last_run:
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=None,
|
||||
@@ -421,22 +546,28 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
return LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
)
|
||||
else:
|
||||
last_run_dict = {
|
||||
"query": last_run.query,
|
||||
"answer": last_run.answer,
|
||||
"error": last_run.error,
|
||||
}
|
||||
result = LLMGenerator.__instruction_modify_common(
|
||||
tenant_id=tenant_id,
|
||||
model_config=model_config,
|
||||
last_run=last_run_dict,
|
||||
current=current,
|
||||
error_message=str(last_run.error),
|
||||
instruction=instruction,
|
||||
node_type="llm",
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def instruction_modify_workflow(
|
||||
@@ -448,6 +579,8 @@ class LLMGenerator:
|
||||
model_config: ModelConfig,
|
||||
ideal_output: str | None,
|
||||
workflow_service: WorkflowServiceInterface,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
session = db.session()
|
||||
|
||||
@@ -478,6 +611,8 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
def agent_log_of(node_execution: WorkflowNodeExecutionModel) -> Sequence:
|
||||
@@ -511,6 +646,8 @@ class LLMGenerator:
|
||||
instruction=instruction,
|
||||
node_type=last_run.node_type,
|
||||
ideal_output=ideal_output,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
@@ -523,6 +660,8 @@ class LLMGenerator:
|
||||
instruction: str,
|
||||
node_type: str,
|
||||
ideal_output: str | None,
|
||||
user_id: str | None = None,
|
||||
app_id: str | None = None,
|
||||
):
|
||||
LAST_RUN = "{{#last_run#}}"
|
||||
CURRENT = "{{#current#}}"
|
||||
@@ -562,24 +701,120 @@ class LLMGenerator:
|
||||
]
|
||||
model_parameters = {"temperature": 0.4}
|
||||
|
||||
try:
|
||||
response: LLMResult = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
llm_result = None
|
||||
error = None
|
||||
result = {}
|
||||
|
||||
with measure_time() as timer:
|
||||
try:
|
||||
llm_result = model_instance.invoke_llm(
|
||||
prompt_messages=list(prompt_messages), model_parameters=model_parameters, stream=False
|
||||
)
|
||||
|
||||
generated_raw = llm_result.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
result = data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
result = {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
|
||||
error = str(e)
|
||||
result = {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
|
||||
if user_id:
|
||||
generated_output = ""
|
||||
if isinstance(result, dict):
|
||||
for key in ["prompt", "code", "output", "modified"]:
|
||||
if result.get(key):
|
||||
generated_output = str(result[key])
|
||||
break
|
||||
|
||||
LLMGenerator._emit_prompt_generation_trace(
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=OperationType.INSTRUCTION_MODIFY,
|
||||
instruction=instruction,
|
||||
generated_output=generated_output,
|
||||
llm_result=llm_result,
|
||||
model_config=model_config,
|
||||
timer=timer,
|
||||
error=error,
|
||||
)
|
||||
|
||||
generated_raw = response.message.get_text_content()
|
||||
first_brace = generated_raw.find("{")
|
||||
last_brace = generated_raw.rfind("}")
|
||||
if first_brace == -1 or last_brace == -1 or last_brace < first_brace:
|
||||
raise ValueError(f"Could not find a valid JSON object in response: {generated_raw}")
|
||||
json_str = generated_raw[first_brace : last_brace + 1]
|
||||
data = json_repair.loads(json_str)
|
||||
if not isinstance(data, dict):
|
||||
raise TypeError(f"Expected a JSON object, but got {type(data).__name__}")
|
||||
return data
|
||||
except InvokeError as e:
|
||||
error = str(e)
|
||||
return {"error": f"Failed to generate code. Error: {error}"}
|
||||
except Exception as e:
|
||||
logger.exception("Failed to invoke LLM model, model: %s", json.dumps(model_config.name), exc_info=True)
|
||||
return {"error": f"An unexpected error occurred: {str(e)}"}
|
||||
return result
|
||||
|
||||
@classmethod
|
||||
def _emit_prompt_generation_trace(
|
||||
cls,
|
||||
tenant_id: str,
|
||||
user_id: str,
|
||||
app_id: str | None,
|
||||
operation_type: OperationType,
|
||||
instruction: str,
|
||||
generated_output: str,
|
||||
llm_result: LLMResult | None,
|
||||
model_config: ModelConfig | None = None,
|
||||
timer=None,
|
||||
error: str | None = None,
|
||||
):
|
||||
if llm_result:
|
||||
prompt_tokens = llm_result.usage.prompt_tokens
|
||||
completion_tokens = llm_result.usage.completion_tokens
|
||||
total_tokens = llm_result.usage.total_tokens
|
||||
model_name = llm_result.model
|
||||
# Extract provider from model_config if available, otherwise fall back to parsing model name
|
||||
if model_config and model_config.provider:
|
||||
model_provider = model_config.provider
|
||||
else:
|
||||
model_provider = model_name.split("/")[0] if "/" in model_name else ""
|
||||
latency = llm_result.usage.latency
|
||||
total_price = float(llm_result.usage.total_price) if llm_result.usage.total_price else None
|
||||
currency = llm_result.usage.currency
|
||||
else:
|
||||
prompt_tokens = 0
|
||||
completion_tokens = 0
|
||||
total_tokens = 0
|
||||
model_provider = model_config.provider if model_config else ""
|
||||
model_name = model_config.name if model_config else ""
|
||||
latency = 0.0
|
||||
if timer:
|
||||
start_time = timer.get("start")
|
||||
end_time = timer.get("end")
|
||||
if start_time and end_time:
|
||||
latency = (end_time - start_time).total_seconds()
|
||||
total_price = None
|
||||
currency = None
|
||||
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.PROMPT_GENERATION_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, user_id=user_id, app_id=app_id),
|
||||
payload={
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id,
|
||||
"operation_type": operation_type,
|
||||
"instruction": instruction,
|
||||
"generated_output": generated_output,
|
||||
"prompt_tokens": prompt_tokens,
|
||||
"completion_tokens": completion_tokens,
|
||||
"total_tokens": total_tokens,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
"latency": latency,
|
||||
"total_price": total_price,
|
||||
"currency": currency,
|
||||
"timer": timer,
|
||||
"error": error,
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
@@ -15,16 +15,23 @@ class TraceContextFilter(logging.Filter):
|
||||
"""
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
# Preserve explicit trace_id set by the caller (e.g. emit_metric_only_event)
|
||||
existing_trace_id = getattr(record, "trace_id", "")
|
||||
if not existing_trace_id:
|
||||
# Get trace context from OpenTelemetry
|
||||
trace_id, span_id = self._get_otel_context()
|
||||
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
# Set trace_id (fallback to ContextVar if no OTEL context)
|
||||
if trace_id:
|
||||
record.trace_id = trace_id
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
else:
|
||||
record.trace_id = get_trace_id()
|
||||
|
||||
record.span_id = span_id or ""
|
||||
# Keep existing trace_id; only fill span_id if missing
|
||||
if not getattr(record, "span_id", ""):
|
||||
record.span_id = ""
|
||||
|
||||
# For backward compatibility, also set req_id
|
||||
record.req_id = get_request_id()
|
||||
@@ -55,9 +62,12 @@ class IdentityContextFilter(logging.Filter):
|
||||
|
||||
def filter(self, record: logging.LogRecord) -> bool:
|
||||
identity = self._extract_identity()
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
record.user_id = identity.get("user_id", "")
|
||||
record.user_type = identity.get("user_type", "")
|
||||
if not getattr(record, "tenant_id", ""):
|
||||
record.tenant_id = identity.get("tenant_id", "")
|
||||
if not getattr(record, "user_id", ""):
|
||||
record.user_id = identity.get("user_id", "")
|
||||
if not getattr(record, "user_type", ""):
|
||||
record.user_type = identity.get("user_type", "")
|
||||
return True
|
||||
|
||||
def _extract_identity(self) -> dict[str, str]:
|
||||
|
||||
@@ -5,9 +5,10 @@ from typing import Any
|
||||
from core.app.app_config.entities import AppConfig
|
||||
from core.moderation.base import ModerationAction, ModerationError
|
||||
from core.moderation.factory import ModerationFactory
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.utils import measure_time
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -49,14 +50,18 @@ class InputModeration:
|
||||
moderation_result = moderation_factory.moderation_for_inputs(inputs, query)
|
||||
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.MODERATION_TRACE,
|
||||
message_id=message_id,
|
||||
moderation_result=moderation_result,
|
||||
inputs=inputs,
|
||||
timer=timer,
|
||||
)
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.MODERATION_TRACE,
|
||||
context=TelemetryContext(tenant_id=tenant_id, app_id=app_id),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"moderation_result": moderation_result,
|
||||
"inputs": inputs,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
if not moderation_result.flagged:
|
||||
|
||||
@@ -9,8 +9,8 @@ from pydantic import BaseModel, ConfigDict, field_serializer, field_validator
|
||||
class BaseTraceInfo(BaseModel):
|
||||
message_id: str | None = None
|
||||
message_data: Any | None = None
|
||||
inputs: Union[str, dict[str, Any], list] | None = None
|
||||
outputs: Union[str, dict[str, Any], list] | None = None
|
||||
inputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
outputs: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
start_time: datetime | None = None
|
||||
end_time: datetime | None = None
|
||||
metadata: dict[str, Any]
|
||||
@@ -18,7 +18,7 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
@field_validator("inputs", "outputs")
|
||||
@classmethod
|
||||
def ensure_type(cls, v):
|
||||
def ensure_type(cls, v: str | dict[str, Any] | list[Any] | None) -> str | dict[str, Any] | list[Any] | None:
|
||||
if v is None:
|
||||
return None
|
||||
if isinstance(v, str | dict | list):
|
||||
@@ -27,6 +27,48 @@ class BaseTraceInfo(BaseModel):
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def resolved_trace_id(self) -> str | None:
|
||||
"""Get trace_id with intelligent fallback.
|
||||
|
||||
Priority:
|
||||
1. External trace_id (from X-Trace-Id header)
|
||||
2. workflow_run_id (if this trace type has it)
|
||||
3. message_id (as final fallback)
|
||||
"""
|
||||
if self.trace_id:
|
||||
return self.trace_id
|
||||
|
||||
# Try workflow_run_id (only exists on workflow-related traces)
|
||||
workflow_run_id = getattr(self, "workflow_run_id", None)
|
||||
if workflow_run_id:
|
||||
return workflow_run_id
|
||||
|
||||
# Final fallback to message_id
|
||||
return str(self.message_id) if self.message_id else None
|
||||
|
||||
@property
|
||||
def resolved_parent_context(self) -> tuple[str | None, str | None]:
|
||||
"""Resolve cross-workflow parent linking from metadata.
|
||||
|
||||
Extracts typed parent IDs from the untyped ``parent_trace_context``
|
||||
metadata dict (set by tool_node when invoking nested workflows).
|
||||
|
||||
Returns:
|
||||
(trace_correlation_override, parent_span_id_source) where
|
||||
trace_correlation_override is the outer workflow_run_id and
|
||||
parent_span_id_source is the outer node_execution_id.
|
||||
"""
|
||||
parent_ctx = self.metadata.get("parent_trace_context")
|
||||
if not isinstance(parent_ctx, dict):
|
||||
return None, None
|
||||
trace_override = parent_ctx.get("parent_workflow_run_id")
|
||||
parent_span = parent_ctx.get("parent_node_execution_id")
|
||||
return (
|
||||
trace_override if isinstance(trace_override, str) else None,
|
||||
parent_span if isinstance(parent_span, str) else None,
|
||||
)
|
||||
|
||||
@field_serializer("start_time", "end_time")
|
||||
def serialize_datetime(self, dt: datetime | None) -> str | None:
|
||||
if dt is None:
|
||||
@@ -48,10 +90,14 @@ class WorkflowTraceInfo(BaseTraceInfo):
|
||||
workflow_run_version: str
|
||||
error: str | None = None
|
||||
total_tokens: int
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
file_list: list[str]
|
||||
query: str
|
||||
metadata: dict[str, Any]
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
|
||||
class MessageTraceInfo(BaseTraceInfo):
|
||||
conversation_model: str
|
||||
@@ -59,7 +105,7 @@ class MessageTraceInfo(BaseTraceInfo):
|
||||
answer_tokens: int
|
||||
total_tokens: int
|
||||
error: str | None = None
|
||||
file_list: Union[str, dict[str, Any], list] | None = None
|
||||
file_list: Union[str, dict[str, Any], list[Any]] | None = None
|
||||
message_file_data: Any | None = None
|
||||
conversation_mode: str
|
||||
gen_ai_server_time_to_first_token: float | None = None
|
||||
@@ -106,7 +152,7 @@ class ToolTraceInfo(BaseTraceInfo):
|
||||
tool_config: dict[str, Any]
|
||||
time_cost: Union[int, float]
|
||||
tool_parameters: dict[str, Any]
|
||||
file_url: Union[str, None, list] = None
|
||||
file_url: Union[str, None, list[str]] = None
|
||||
|
||||
|
||||
class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
@@ -114,6 +160,79 @@ class GenerateNameTraceInfo(BaseTraceInfo):
|
||||
tenant_id: str
|
||||
|
||||
|
||||
class PromptGenerationTraceInfo(BaseTraceInfo):
|
||||
"""Trace information for prompt generation operations (rule-generate, code-generate, etc.)."""
|
||||
|
||||
tenant_id: str
|
||||
user_id: str
|
||||
app_id: str | None = None
|
||||
|
||||
operation_type: str
|
||||
instruction: str
|
||||
|
||||
prompt_tokens: int
|
||||
completion_tokens: int
|
||||
total_tokens: int
|
||||
|
||||
model_provider: str
|
||||
model_name: str
|
||||
|
||||
latency: float
|
||||
|
||||
total_price: float | None = None
|
||||
currency: str | None = None
|
||||
|
||||
error: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class WorkflowNodeTraceInfo(BaseTraceInfo):
|
||||
workflow_id: str
|
||||
workflow_run_id: str
|
||||
tenant_id: str
|
||||
node_execution_id: str
|
||||
node_id: str
|
||||
node_type: str
|
||||
title: str
|
||||
|
||||
status: str
|
||||
error: str | None = None
|
||||
elapsed_time: float
|
||||
|
||||
index: int
|
||||
predecessor_node_id: str | None = None
|
||||
|
||||
total_tokens: int = 0
|
||||
total_price: float = 0.0
|
||||
currency: str | None = None
|
||||
|
||||
model_provider: str | None = None
|
||||
model_name: str | None = None
|
||||
prompt_tokens: int | None = None
|
||||
completion_tokens: int | None = None
|
||||
|
||||
tool_name: str | None = None
|
||||
|
||||
iteration_id: str | None = None
|
||||
iteration_index: int | None = None
|
||||
loop_id: str | None = None
|
||||
loop_index: int | None = None
|
||||
parallel_id: str | None = None
|
||||
|
||||
node_inputs: Mapping[str, Any] | None = None
|
||||
node_outputs: Mapping[str, Any] | None = None
|
||||
process_data: Mapping[str, Any] | None = None
|
||||
|
||||
invoked_by: str | None = None
|
||||
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
|
||||
class DraftNodeExecutionTrace(WorkflowNodeTraceInfo):
|
||||
pass
|
||||
|
||||
|
||||
class TaskData(BaseModel):
|
||||
app_id: str
|
||||
trace_info_type: str
|
||||
@@ -128,16 +247,38 @@ trace_info_info_map = {
|
||||
"DatasetRetrievalTraceInfo": DatasetRetrievalTraceInfo,
|
||||
"ToolTraceInfo": ToolTraceInfo,
|
||||
"GenerateNameTraceInfo": GenerateNameTraceInfo,
|
||||
"PromptGenerationTraceInfo": PromptGenerationTraceInfo,
|
||||
"WorkflowNodeTraceInfo": WorkflowNodeTraceInfo,
|
||||
"DraftNodeExecutionTrace": DraftNodeExecutionTrace,
|
||||
}
|
||||
|
||||
|
||||
class OperationType(StrEnum):
|
||||
"""Operation type for token metric labels.
|
||||
|
||||
Used as a metric attribute on ``dify.tokens.input`` / ``dify.tokens.output``
|
||||
counters so consumers can break down token usage by operation.
|
||||
"""
|
||||
|
||||
WORKFLOW = "workflow"
|
||||
NODE_EXECUTION = "node_execution"
|
||||
MESSAGE = "message"
|
||||
RULE_GENERATE = "rule_generate"
|
||||
CODE_GENERATE = "code_generate"
|
||||
STRUCTURED_OUTPUT = "structured_output"
|
||||
INSTRUCTION_MODIFY = "instruction_modify"
|
||||
|
||||
|
||||
class TraceTaskName(StrEnum):
|
||||
CONVERSATION_TRACE = "conversation"
|
||||
WORKFLOW_TRACE = "workflow"
|
||||
DRAFT_NODE_EXECUTION_TRACE = "draft_node_execution"
|
||||
MESSAGE_TRACE = "message"
|
||||
MODERATION_TRACE = "moderation"
|
||||
SUGGESTED_QUESTION_TRACE = "suggested_question"
|
||||
DATASET_RETRIEVAL_TRACE = "dataset_retrieval"
|
||||
TOOL_TRACE = "tool"
|
||||
GENERATE_NAME_TRACE = "generate_conversation_name"
|
||||
PROMPT_GENERATION_TRACE = "prompt_generation"
|
||||
DATASOURCE_TRACE = "datasource"
|
||||
NODE_EXECUTION_TRACE = "node_execution"
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from langfuse import Langfuse
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, Message, WorkflowNodeExecutionTriggeredFrom
|
||||
from models.enums import MessageStatus
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -71,7 +72,50 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
# Check for parent_trace_context to detect nested workflow
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Nested workflow: create span under outer trace
|
||||
outer_trace_id = parent_trace_context.get("trace_id")
|
||||
parent_node_execution_id = parent_trace_context.get("parent_node_execution_id")
|
||||
parent_conversation_id = parent_trace_context.get("parent_conversation_id")
|
||||
parent_workflow_run_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Resolve outer trace_id: try message_id lookup first, fallback to workflow_run_id
|
||||
if parent_conversation_id:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_conversation_id,
|
||||
Message.workflow_run_id == parent_workflow_run_id,
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
else:
|
||||
outer_trace_id = parent_workflow_run_id
|
||||
|
||||
# Create inner workflow span under outer trace
|
||||
workflow_span_data = LangfuseSpan(
|
||||
id=trace_info.workflow_run_id,
|
||||
name=TraceTaskName.WORKFLOW_TRACE,
|
||||
input=dict(trace_info.workflow_run_inputs),
|
||||
output=dict(trace_info.workflow_run_outputs),
|
||||
trace_id=outer_trace_id,
|
||||
parent_observation_id=parent_node_execution_id,
|
||||
start_time=trace_info.start_time,
|
||||
end_time=trace_info.end_time,
|
||||
metadata=metadata,
|
||||
level=LevelEnum.DEFAULT if trace_info.error == "" else LevelEnum.ERROR,
|
||||
status_message=trace_info.error or "",
|
||||
)
|
||||
self.add_span(langfuse_span_data=workflow_span_data)
|
||||
# Use outer_trace_id for all node spans/generations
|
||||
trace_id = outer_trace_id
|
||||
elif trace_info.message_id:
|
||||
trace_id = trace_info.trace_id or trace_info.message_id
|
||||
name = TraceTaskName.MESSAGE_TRACE
|
||||
trace_data = LangfuseTrace(
|
||||
@@ -174,6 +218,11 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
}
|
||||
)
|
||||
|
||||
# Determine parent_observation_id for nested workflows
|
||||
node_parent_observation_id = None
|
||||
if parent_trace_context or trace_info.message_id:
|
||||
node_parent_observation_id = trace_info.workflow_run_id
|
||||
|
||||
# add generation span
|
||||
if process_data and process_data.get("model_mode") == "chat":
|
||||
total_token = metadata.get("total_tokens", 0)
|
||||
@@ -206,7 +255,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
usage=generation_usage,
|
||||
)
|
||||
|
||||
@@ -225,7 +274,7 @@ class LangFuseDataTrace(BaseTraceInstance):
|
||||
metadata=metadata,
|
||||
level=(LevelEnum.DEFAULT if status == "succeeded" else LevelEnum.ERROR),
|
||||
status_message=trace_info.error or "",
|
||||
parent_observation_id=trace_info.workflow_run_id if trace_info.message_id else None,
|
||||
parent_observation_id=node_parent_observation_id,
|
||||
)
|
||||
|
||||
self.add_span(langfuse_span_data=span_data)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import cast
|
||||
|
||||
from langsmith import Client
|
||||
from langsmith.schemas import RunBase
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import sessionmaker
|
||||
|
||||
from core.ops.base_trace_instance import BaseTraceInstance
|
||||
@@ -30,7 +31,7 @@ from core.ops.utils import filter_none_values, generate_dotted_order
|
||||
from core.repositories import DifyCoreRepositoryFactory
|
||||
from core.workflow.enums import NodeType, WorkflowNodeExecutionMetadataKey
|
||||
from extensions.ext_database import db
|
||||
from models import EndUser, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
from models import EndUser, Message, MessageFile, WorkflowNodeExecutionTriggeredFrom
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -64,7 +65,35 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
self.generate_name_trace(trace_info)
|
||||
|
||||
def workflow_trace(self, trace_info: WorkflowTraceInfo):
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
# Check for parent_trace_context for cross-workflow linking
|
||||
parent_trace_context = trace_info.metadata.get("parent_trace_context")
|
||||
|
||||
if parent_trace_context:
|
||||
# Inner workflow: resolve outer trace_id and link to parent node
|
||||
outer_trace_id = parent_trace_context.get("parent_workflow_run_id")
|
||||
|
||||
# Try to resolve message_id from conversation_id if available
|
||||
if parent_trace_context.get("parent_conversation_id"):
|
||||
try:
|
||||
session_factory = sessionmaker(bind=db.engine)
|
||||
with session_factory() as session:
|
||||
message_data_stmt = select(Message.id).where(
|
||||
Message.conversation_id == parent_trace_context["parent_conversation_id"],
|
||||
Message.workflow_run_id == parent_trace_context["parent_workflow_run_id"],
|
||||
)
|
||||
resolved_message_id = session.scalar(message_data_stmt)
|
||||
if resolved_message_id:
|
||||
outer_trace_id = resolved_message_id
|
||||
except Exception as e:
|
||||
logger.debug("Failed to resolve message_id from conversation_id: %s", str(e))
|
||||
|
||||
trace_id = outer_trace_id
|
||||
parent_run_id = parent_trace_context.get("parent_node_execution_id")
|
||||
else:
|
||||
# Outer workflow: existing behavior
|
||||
trace_id = trace_info.trace_id or trace_info.message_id or trace_info.workflow_run_id
|
||||
parent_run_id = trace_info.message_id or None
|
||||
|
||||
if trace_info.start_time is None:
|
||||
trace_info.start_time = datetime.now()
|
||||
message_dotted_order = (
|
||||
@@ -78,7 +107,8 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
metadata = trace_info.metadata
|
||||
metadata["workflow_app_log_id"] = trace_info.workflow_app_log_id
|
||||
|
||||
if trace_info.message_id:
|
||||
# Only create message_run for outer workflows (no parent_trace_context)
|
||||
if trace_info.message_id and not parent_trace_context:
|
||||
message_run = LangSmithRunModel(
|
||||
id=trace_info.message_id,
|
||||
name=TraceTaskName.MESSAGE_TRACE,
|
||||
@@ -121,9 +151,9 @@ class LangSmithDataTrace(BaseTraceInstance):
|
||||
},
|
||||
error=trace_info.error,
|
||||
tags=["workflow"],
|
||||
parent_run_id=trace_info.message_id or None,
|
||||
parent_run_id=parent_run_id,
|
||||
trace_id=trace_id,
|
||||
dotted_order=workflow_dotted_order,
|
||||
dotted_order=None if parent_trace_context else workflow_dotted_order,
|
||||
serialized=None,
|
||||
events=[],
|
||||
session_id=None,
|
||||
|
||||
@@ -15,22 +15,32 @@ from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session, sessionmaker
|
||||
|
||||
from core.helper.encrypter import batch_decrypt_token, encrypt_token, obfuscated_token
|
||||
from core.ops.entities.config_entity import OPS_FILE_PATH, TracingProviderEnum
|
||||
from core.ops.entities.config_entity import (
|
||||
OPS_FILE_PATH,
|
||||
TracingProviderEnum,
|
||||
)
|
||||
from core.ops.entities.trace_entity import (
|
||||
DatasetRetrievalTraceInfo,
|
||||
DraftNodeExecutionTrace,
|
||||
GenerateNameTraceInfo,
|
||||
MessageTraceInfo,
|
||||
ModerationTraceInfo,
|
||||
PromptGenerationTraceInfo,
|
||||
SuggestedQuestionTraceInfo,
|
||||
TaskData,
|
||||
ToolTraceInfo,
|
||||
TraceTaskName,
|
||||
WorkflowNodeTraceInfo,
|
||||
WorkflowTraceInfo,
|
||||
)
|
||||
from core.ops.utils import get_message_data
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_storage import storage
|
||||
from models.engine import db
|
||||
from models.account import Tenant
|
||||
from models.dataset import Dataset
|
||||
from models.model import App, AppModelConfig, Conversation, Message, MessageFile, TraceAppConfig
|
||||
from models.provider import Provider, ProviderCredential, ProviderModel, ProviderModelCredential, ProviderType
|
||||
from models.tools import ApiToolProvider, BuiltinToolProvider, MCPToolProvider, WorkflowToolProvider
|
||||
from models.workflow import WorkflowAppLog
|
||||
from tasks.ops_trace_task import process_trace_tasks
|
||||
|
||||
@@ -40,6 +50,139 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _lookup_app_and_workspace_names(app_id: str | None, tenant_id: str | None) -> tuple[str, str]:
|
||||
"""Return (app_name, workspace_name) for the given IDs. Falls back to empty strings."""
|
||||
app_name = ""
|
||||
workspace_name = ""
|
||||
if not app_id and not tenant_id:
|
||||
return app_name, workspace_name
|
||||
with Session(db.engine) as session:
|
||||
if app_id:
|
||||
name = session.scalar(select(App.name).where(App.id == app_id))
|
||||
if name:
|
||||
app_name = name
|
||||
if tenant_id:
|
||||
name = session.scalar(select(Tenant.name).where(Tenant.id == tenant_id))
|
||||
if name:
|
||||
workspace_name = name
|
||||
return app_name, workspace_name
|
||||
|
||||
|
||||
_PROVIDER_TYPE_TO_MODEL: dict[str, type] = {
|
||||
"builtin": BuiltinToolProvider,
|
||||
"plugin": BuiltinToolProvider,
|
||||
"api": ApiToolProvider,
|
||||
"workflow": WorkflowToolProvider,
|
||||
"mcp": MCPToolProvider,
|
||||
}
|
||||
|
||||
|
||||
def _lookup_credential_name(credential_id: str | None, provider_type: str | None) -> str:
|
||||
if not credential_id:
|
||||
return ""
|
||||
model_cls = _PROVIDER_TYPE_TO_MODEL.get(provider_type or "")
|
||||
if not model_cls:
|
||||
return ""
|
||||
with Session(db.engine) as session:
|
||||
name = session.scalar(select(model_cls.name).where(model_cls.id == credential_id))
|
||||
return str(name) if name else ""
|
||||
|
||||
|
||||
def _lookup_llm_credential_info(
|
||||
tenant_id: str | None, provider: str | None, model: str | None, model_type: str | None = "llm"
|
||||
) -> tuple[str | None, str]:
|
||||
"""
|
||||
Lookup LLM credential ID and name for the given provider and model.
|
||||
Returns (credential_id, credential_name).
|
||||
|
||||
Handles async timing issues gracefully - if credential is deleted between lookups,
|
||||
returns the ID but empty name rather than failing.
|
||||
"""
|
||||
if not tenant_id or not provider:
|
||||
return None, ""
|
||||
|
||||
try:
|
||||
with Session(db.engine) as session:
|
||||
# Try to find provider-level or model-level configuration
|
||||
provider_record = session.scalar(
|
||||
select(Provider).where(
|
||||
Provider.tenant_id == tenant_id,
|
||||
Provider.provider_name == provider,
|
||||
Provider.provider_type == ProviderType.CUSTOM,
|
||||
)
|
||||
)
|
||||
|
||||
if not provider_record:
|
||||
return None, ""
|
||||
|
||||
# Check if there's a model-specific config
|
||||
credential_id = None
|
||||
credential_name = ""
|
||||
is_model_level = False
|
||||
|
||||
if model and provider_record.credential_id:
|
||||
# Try model-level first
|
||||
model_record = session.scalar(
|
||||
select(ProviderModel).where(
|
||||
ProviderModel.tenant_id == tenant_id,
|
||||
ProviderModel.provider_name == provider,
|
||||
ProviderModel.model_name == model,
|
||||
ProviderModel.model_type == model_type,
|
||||
)
|
||||
)
|
||||
|
||||
if model_record and model_record.credential_id:
|
||||
credential_id = model_record.credential_id
|
||||
is_model_level = True
|
||||
|
||||
if not credential_id and provider_record.credential_id:
|
||||
# Fall back to provider-level credential
|
||||
credential_id = provider_record.credential_id
|
||||
is_model_level = False
|
||||
|
||||
# Lookup credential_name if we have credential_id
|
||||
if credential_id:
|
||||
try:
|
||||
if is_model_level:
|
||||
# Query ProviderModelCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderModelCredential.credential_name).where(
|
||||
ProviderModelCredential.id == credential_id
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Query ProviderCredential
|
||||
cred_name = session.scalar(
|
||||
select(ProviderCredential.credential_name).where(ProviderCredential.id == credential_id)
|
||||
)
|
||||
|
||||
if cred_name:
|
||||
credential_name = str(cred_name)
|
||||
except Exception as e:
|
||||
# Credential might have been deleted between lookups (async timing)
|
||||
# Return ID but empty name rather than failing
|
||||
logger.warning(
|
||||
"Failed to lookup credential name for credential_id=%s (provider=%s, model=%s): %s",
|
||||
credential_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
|
||||
return credential_id, credential_name
|
||||
except Exception as e:
|
||||
# Database query failed or other unexpected error
|
||||
# Return empty rather than propagating error to telemetry emission
|
||||
logger.warning(
|
||||
"Failed to lookup LLM credential info for tenant_id=%s, provider=%s, model=%s: %s",
|
||||
tenant_id,
|
||||
provider,
|
||||
model,
|
||||
str(e),
|
||||
)
|
||||
return None, ""
|
||||
|
||||
|
||||
class OpsTraceProviderConfigMap(collections.UserDict[str, dict[str, Any]]):
|
||||
def __getitem__(self, provider: str) -> dict[str, Any]:
|
||||
match provider:
|
||||
@@ -314,6 +457,10 @@ class OpsTraceManager:
|
||||
if app_id is None:
|
||||
return None
|
||||
|
||||
# Handle storage_id format (tenant-{uuid}) - not a real app_id
|
||||
if isinstance(app_id, str) and app_id.startswith("tenant-"):
|
||||
return None
|
||||
|
||||
app: App | None = db.session.query(App).where(App.id == app_id).first()
|
||||
|
||||
if app is None:
|
||||
@@ -466,8 +613,6 @@ class TraceTask:
|
||||
|
||||
@classmethod
|
||||
def _get_workflow_run_repo(cls):
|
||||
from repositories.factory import DifyAPIRepositoryFactory
|
||||
|
||||
if cls._workflow_run_repo is None:
|
||||
with cls._repo_lock:
|
||||
if cls._workflow_run_repo is None:
|
||||
@@ -478,6 +623,56 @@ class TraceTask:
|
||||
cls._workflow_run_repo = DifyAPIRepositoryFactory.create_api_workflow_run_repository(session_maker)
|
||||
return cls._workflow_run_repo
|
||||
|
||||
@classmethod
|
||||
def _get_user_id_from_metadata(cls, metadata: dict[str, Any]) -> str:
|
||||
"""Extract user ID from metadata, prioritizing end_user over account.
|
||||
|
||||
Returns the actual user ID (end_user or account) who invoked the workflow,
|
||||
regardless of invoke_from context.
|
||||
"""
|
||||
# Priority 1: End user (external users via API/WebApp)
|
||||
if user_id := metadata.get("from_end_user_id"):
|
||||
return f"end_user:{user_id}"
|
||||
|
||||
# Priority 2: Account user (internal users via console/debugger)
|
||||
if user_id := metadata.get("from_account_id"):
|
||||
return f"account:{user_id}"
|
||||
|
||||
# Priority 3: User (internal users via console/debugger)
|
||||
if user_id := metadata.get("user_id"):
|
||||
return f"user:{user_id}"
|
||||
|
||||
return "anonymous"
|
||||
|
||||
@classmethod
|
||||
def _calculate_workflow_token_split(cls, workflow_run_id: str, tenant_id: str) -> tuple[int, int]:
|
||||
from core.workflow.enums import WorkflowNodeExecutionMetadataKey
|
||||
from models.workflow import WorkflowNodeExecutionModel
|
||||
|
||||
with Session(db.engine) as session:
|
||||
node_executions = session.scalars(
|
||||
select(WorkflowNodeExecutionModel).where(
|
||||
WorkflowNodeExecutionModel.tenant_id == tenant_id,
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
)
|
||||
).all()
|
||||
|
||||
total_prompt = 0
|
||||
total_completion = 0
|
||||
|
||||
for node_exec in node_executions:
|
||||
metadata = node_exec.execution_metadata_dict
|
||||
|
||||
prompt = metadata.get(WorkflowNodeExecutionMetadataKey.PROMPT_TOKENS)
|
||||
if prompt is not None:
|
||||
total_prompt += prompt
|
||||
|
||||
completion = metadata.get(WorkflowNodeExecutionMetadataKey.COMPLETION_TOKENS)
|
||||
if completion is not None:
|
||||
total_completion += completion
|
||||
|
||||
return (total_prompt, total_completion)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
trace_type: Any,
|
||||
@@ -498,6 +693,8 @@ class TraceTask:
|
||||
self.app_id = None
|
||||
self.trace_id = None
|
||||
self.kwargs = kwargs
|
||||
if user_id is not None and "user_id" not in self.kwargs:
|
||||
self.kwargs["user_id"] = user_id
|
||||
external_trace_id = kwargs.get("external_trace_id")
|
||||
if external_trace_id:
|
||||
self.trace_id = external_trace_id
|
||||
@@ -511,7 +708,7 @@ class TraceTask:
|
||||
TraceTaskName.WORKFLOW_TRACE: lambda: self.workflow_trace(
|
||||
workflow_run_id=self.workflow_run_id, conversation_id=self.conversation_id, user_id=self.user_id
|
||||
),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id),
|
||||
TraceTaskName.MESSAGE_TRACE: lambda: self.message_trace(message_id=self.message_id, **self.kwargs),
|
||||
TraceTaskName.MODERATION_TRACE: lambda: self.moderation_trace(
|
||||
message_id=self.message_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
@@ -527,6 +724,9 @@ class TraceTask:
|
||||
TraceTaskName.GENERATE_NAME_TRACE: lambda: self.generate_name_trace(
|
||||
conversation_id=self.conversation_id, timer=self.timer, **self.kwargs
|
||||
),
|
||||
TraceTaskName.PROMPT_GENERATION_TRACE: lambda: self.prompt_generation_trace(**self.kwargs),
|
||||
TraceTaskName.NODE_EXECUTION_TRACE: lambda: self.node_execution_trace(**self.kwargs),
|
||||
TraceTaskName.DRAFT_NODE_EXECUTION_TRACE: lambda: self.draft_node_execution_trace(**self.kwargs),
|
||||
}
|
||||
|
||||
return preprocess_map.get(self.trace_type, lambda: None)()
|
||||
@@ -562,6 +762,10 @@ class TraceTask:
|
||||
|
||||
total_tokens = workflow_run.total_tokens
|
||||
|
||||
prompt_tokens, completion_tokens = self._calculate_workflow_token_split(
|
||||
workflow_run_id=workflow_run_id, tenant_id=tenant_id
|
||||
)
|
||||
|
||||
file_list = workflow_run_inputs.get("sys.file") or []
|
||||
query = workflow_run_inputs.get("query") or workflow_run_inputs.get("sys.query") or ""
|
||||
|
||||
@@ -582,7 +786,14 @@ class TraceTask:
|
||||
)
|
||||
message_id = session.scalar(message_data_stmt)
|
||||
|
||||
metadata = {
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(workflow_run.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata: dict[str, Any] = {
|
||||
"workflow_id": workflow_id,
|
||||
"conversation_id": conversation_id,
|
||||
"workflow_run_id": workflow_run_id,
|
||||
@@ -595,8 +806,14 @@ class TraceTask:
|
||||
"triggered_from": workflow_run.triggered_from,
|
||||
"user_id": user_id,
|
||||
"app_id": workflow_run.app_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
|
||||
parent_trace_context = self.kwargs.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
workflow_trace_info = WorkflowTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
workflow_data=workflow_run.to_dict(),
|
||||
@@ -611,6 +828,8 @@ class TraceTask:
|
||||
workflow_run_version=workflow_run_version,
|
||||
error=error,
|
||||
total_tokens=total_tokens,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
file_list=file_list,
|
||||
query=query,
|
||||
metadata=metadata,
|
||||
@@ -618,10 +837,11 @@ class TraceTask:
|
||||
message_id=message_id,
|
||||
start_time=workflow_run.created_at,
|
||||
end_time=workflow_run.finished_at,
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
return workflow_trace_info
|
||||
|
||||
def message_trace(self, message_id: str | None):
|
||||
def message_trace(self, message_id: str | None, **kwargs):
|
||||
if not message_id:
|
||||
return {}
|
||||
message_data = get_message_data(message_id)
|
||||
@@ -644,6 +864,19 @@ class TraceTask:
|
||||
|
||||
streaming_metrics = self._extract_streaming_metrics(message_data)
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
metadata = {
|
||||
"conversation_id": message_data.conversation_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -655,7 +888,14 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"message_id": message_id,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
message_tokens = message_data.message_tokens
|
||||
|
||||
@@ -672,7 +912,9 @@ class TraceTask:
|
||||
outputs=message_data.answer,
|
||||
file_list=file_list,
|
||||
start_time=created_at,
|
||||
end_time=created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
end_time=message_data.updated_at
|
||||
if message_data.updated_at and message_data.updated_at > created_at
|
||||
else created_at + timedelta(seconds=message_data.provider_response_latency),
|
||||
metadata=metadata,
|
||||
message_file_data=message_file_data,
|
||||
conversation_mode=conversation_mode,
|
||||
@@ -697,6 +939,8 @@ class TraceTask:
|
||||
"preset_response": moderation_result.preset_response,
|
||||
"query": moderation_result.query,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -738,6 +982,8 @@ class TraceTask:
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
# get workflow_app_log_id
|
||||
workflow_app_log_id = None
|
||||
@@ -777,6 +1023,52 @@ class TraceTask:
|
||||
if not message_data:
|
||||
return {}
|
||||
|
||||
tenant_id = ""
|
||||
with Session(db.engine) as session:
|
||||
tid = session.scalar(select(App.tenant_id).where(App.id == message_data.app_id))
|
||||
if tid:
|
||||
tenant_id = str(tid)
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(message_data.app_id, tenant_id)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
doc_list = [doc.model_dump() for doc in documents] if documents else []
|
||||
dataset_ids: set[str] = set()
|
||||
for doc in doc_list:
|
||||
doc_meta = doc.get("metadata") or {}
|
||||
did = doc_meta.get("dataset_id")
|
||||
if did:
|
||||
dataset_ids.add(did)
|
||||
|
||||
embedding_models: dict[str, dict[str, str]] = {}
|
||||
if dataset_ids:
|
||||
with Session(db.engine) as session:
|
||||
rows = session.execute(
|
||||
select(Dataset.id, Dataset.embedding_model, Dataset.embedding_model_provider).where(
|
||||
Dataset.id.in_(list(dataset_ids))
|
||||
)
|
||||
).all()
|
||||
for row in rows:
|
||||
embedding_models[str(row[0])] = {
|
||||
"embedding_model": row[1] or "",
|
||||
"embedding_model_provider": row[2] or "",
|
||||
}
|
||||
|
||||
# Extract rerank model info from retrieval_model kwargs
|
||||
rerank_model_provider = ""
|
||||
rerank_model_name = ""
|
||||
if "retrieval_model" in kwargs:
|
||||
retrieval_model = kwargs["retrieval_model"]
|
||||
if isinstance(retrieval_model, dict):
|
||||
reranking_model = retrieval_model.get("reranking_model")
|
||||
if isinstance(reranking_model, dict):
|
||||
rerank_model_provider = reranking_model.get("reranking_provider_name", "")
|
||||
rerank_model_name = reranking_model.get("reranking_model_name", "")
|
||||
|
||||
metadata = {
|
||||
"message_id": message_id,
|
||||
"ls_provider": message_data.model_provider,
|
||||
@@ -787,13 +1079,23 @@ class TraceTask:
|
||||
"agent_based": message_data.agent_based,
|
||||
"workflow_run_id": message_data.workflow_run_id,
|
||||
"from_source": message_data.from_source,
|
||||
"tenant_id": tenant_id,
|
||||
"app_id": message_data.app_id,
|
||||
"user_id": message_data.from_end_user_id or message_data.from_account_id,
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"embedding_models": embedding_models,
|
||||
"rerank_model_provider": rerank_model_provider,
|
||||
"rerank_model_name": rerank_model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
dataset_retrieval_trace_info = DatasetRetrievalTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
inputs=message_data.query or message_data.inputs,
|
||||
documents=[doc.model_dump() for doc in documents] if documents else [],
|
||||
documents=doc_list,
|
||||
start_time=timer.get("start"),
|
||||
end_time=timer.get("end"),
|
||||
metadata=metadata,
|
||||
@@ -836,6 +1138,10 @@ class TraceTask:
|
||||
"error": error,
|
||||
"tool_parameters": tool_parameters,
|
||||
}
|
||||
if message_data.workflow_run_id:
|
||||
metadata["workflow_run_id"] = message_data.workflow_run_id
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
file_url = ""
|
||||
message_file_data = db.session.query(MessageFile).filter_by(message_id=message_id).first()
|
||||
@@ -890,6 +1196,8 @@ class TraceTask:
|
||||
"conversation_id": conversation_id,
|
||||
"tenant_id": tenant_id,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
generate_name_trace_info = GenerateNameTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
@@ -904,6 +1212,182 @@ class TraceTask:
|
||||
|
||||
return generate_name_trace_info
|
||||
|
||||
def prompt_generation_trace(self, **kwargs) -> PromptGenerationTraceInfo | dict:
|
||||
tenant_id = kwargs.get("tenant_id", "")
|
||||
user_id = kwargs.get("user_id", "")
|
||||
app_id = kwargs.get("app_id")
|
||||
operation_type = kwargs.get("operation_type", "")
|
||||
instruction = kwargs.get("instruction", "")
|
||||
generated_output = kwargs.get("generated_output", "")
|
||||
|
||||
prompt_tokens = kwargs.get("prompt_tokens", 0)
|
||||
completion_tokens = kwargs.get("completion_tokens", 0)
|
||||
total_tokens = kwargs.get("total_tokens", 0)
|
||||
|
||||
model_provider = kwargs.get("model_provider", "")
|
||||
model_name = kwargs.get("model_name", "")
|
||||
|
||||
latency = kwargs.get("latency", 0.0)
|
||||
|
||||
timer = kwargs.get("timer")
|
||||
start_time = timer.get("start") if timer else None
|
||||
end_time = timer.get("end") if timer else None
|
||||
|
||||
total_price = kwargs.get("total_price")
|
||||
currency = kwargs.get("currency")
|
||||
|
||||
error = kwargs.get("error")
|
||||
|
||||
app_name = None
|
||||
workspace_name = None
|
||||
if app_id:
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(app_id, tenant_id)
|
||||
|
||||
metadata = {
|
||||
"tenant_id": tenant_id,
|
||||
"user_id": user_id,
|
||||
"app_id": app_id or "",
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"operation_type": operation_type,
|
||||
"model_provider": model_provider,
|
||||
"model_name": model_name,
|
||||
}
|
||||
if node_execution_id := kwargs.get("node_execution_id"):
|
||||
metadata["node_execution_id"] = node_execution_id
|
||||
|
||||
return PromptGenerationTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
inputs=instruction,
|
||||
outputs=generated_output,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
metadata=metadata,
|
||||
tenant_id=tenant_id,
|
||||
user_id=user_id,
|
||||
app_id=app_id,
|
||||
operation_type=operation_type,
|
||||
instruction=instruction,
|
||||
prompt_tokens=prompt_tokens,
|
||||
completion_tokens=completion_tokens,
|
||||
total_tokens=total_tokens,
|
||||
model_provider=model_provider,
|
||||
model_name=model_name,
|
||||
latency=latency,
|
||||
total_price=total_price,
|
||||
currency=currency,
|
||||
error=error,
|
||||
)
|
||||
|
||||
def node_execution_trace(self, **kwargs) -> WorkflowNodeTraceInfo | dict:
|
||||
node_data: dict = kwargs.get("node_execution_data", {})
|
||||
if not node_data:
|
||||
return {}
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
if is_enterprise_telemetry_enabled():
|
||||
app_name, workspace_name = _lookup_app_and_workspace_names(
|
||||
node_data.get("app_id"), node_data.get("tenant_id")
|
||||
)
|
||||
else:
|
||||
app_name, workspace_name = "", ""
|
||||
|
||||
# Try tool credential lookup first
|
||||
credential_id = node_data.get("credential_id")
|
||||
if is_enterprise_telemetry_enabled():
|
||||
credential_name = _lookup_credential_name(credential_id, node_data.get("credential_provider_type"))
|
||||
# If no credential_id found (e.g., LLM nodes), try LLM credential lookup
|
||||
if not credential_id:
|
||||
llm_cred_id, llm_cred_name = _lookup_llm_credential_info(
|
||||
tenant_id=node_data.get("tenant_id"),
|
||||
provider=node_data.get("model_provider"),
|
||||
model=node_data.get("model_name"),
|
||||
model_type="llm",
|
||||
)
|
||||
if llm_cred_id:
|
||||
credential_id = llm_cred_id
|
||||
credential_name = llm_cred_name
|
||||
else:
|
||||
credential_name = ""
|
||||
metadata: dict[str, Any] = {
|
||||
"tenant_id": node_data.get("tenant_id"),
|
||||
"app_id": node_data.get("app_id"),
|
||||
"app_name": app_name,
|
||||
"workspace_name": workspace_name,
|
||||
"user_id": node_data.get("user_id"),
|
||||
"invoke_from": node_data.get("invoke_from"),
|
||||
"credential_id": node_data.get("credential_id"),
|
||||
"credential_name": credential_name,
|
||||
"dataset_ids": node_data.get("dataset_ids"),
|
||||
"dataset_names": node_data.get("dataset_names"),
|
||||
"plugin_name": node_data.get("plugin_name"),
|
||||
}
|
||||
|
||||
parent_trace_context = node_data.get("parent_trace_context")
|
||||
if parent_trace_context:
|
||||
metadata["parent_trace_context"] = parent_trace_context
|
||||
|
||||
message_id: str | None = None
|
||||
conversation_id = node_data.get("conversation_id")
|
||||
workflow_execution_id = node_data.get("workflow_execution_id")
|
||||
if conversation_id and workflow_execution_id and not parent_trace_context:
|
||||
with Session(db.engine) as session:
|
||||
msg_id = session.scalar(
|
||||
select(Message.id).where(
|
||||
Message.conversation_id == conversation_id,
|
||||
Message.workflow_run_id == workflow_execution_id,
|
||||
)
|
||||
)
|
||||
if msg_id:
|
||||
message_id = str(msg_id)
|
||||
metadata["message_id"] = message_id
|
||||
if conversation_id:
|
||||
metadata["conversation_id"] = conversation_id
|
||||
|
||||
return WorkflowNodeTraceInfo(
|
||||
trace_id=self.trace_id,
|
||||
message_id=message_id,
|
||||
start_time=node_data.get("created_at"),
|
||||
end_time=node_data.get("finished_at"),
|
||||
metadata=metadata,
|
||||
workflow_id=node_data.get("workflow_id", ""),
|
||||
workflow_run_id=node_data.get("workflow_execution_id", ""),
|
||||
tenant_id=node_data.get("tenant_id", ""),
|
||||
node_execution_id=node_data.get("node_execution_id", ""),
|
||||
node_id=node_data.get("node_id", ""),
|
||||
node_type=node_data.get("node_type", ""),
|
||||
title=node_data.get("title", ""),
|
||||
status=node_data.get("status", ""),
|
||||
error=node_data.get("error"),
|
||||
elapsed_time=node_data.get("elapsed_time", 0.0),
|
||||
index=node_data.get("index", 0),
|
||||
predecessor_node_id=node_data.get("predecessor_node_id"),
|
||||
total_tokens=node_data.get("total_tokens", 0),
|
||||
total_price=node_data.get("total_price", 0.0),
|
||||
currency=node_data.get("currency"),
|
||||
model_provider=node_data.get("model_provider"),
|
||||
model_name=node_data.get("model_name"),
|
||||
prompt_tokens=node_data.get("prompt_tokens"),
|
||||
completion_tokens=node_data.get("completion_tokens"),
|
||||
tool_name=node_data.get("tool_name"),
|
||||
iteration_id=node_data.get("iteration_id"),
|
||||
iteration_index=node_data.get("iteration_index"),
|
||||
loop_id=node_data.get("loop_id"),
|
||||
loop_index=node_data.get("loop_index"),
|
||||
parallel_id=node_data.get("parallel_id"),
|
||||
node_inputs=node_data.get("node_inputs"),
|
||||
node_outputs=node_data.get("node_outputs"),
|
||||
process_data=node_data.get("process_data"),
|
||||
invoked_by=self._get_user_id_from_metadata(metadata),
|
||||
)
|
||||
|
||||
def draft_node_execution_trace(self, **kwargs) -> DraftNodeExecutionTrace | dict:
|
||||
node_trace = self.node_execution_trace(**kwargs)
|
||||
if not node_trace or not isinstance(node_trace, WorkflowNodeTraceInfo):
|
||||
return node_trace
|
||||
return DraftNodeExecutionTrace(**node_trace.model_dump())
|
||||
|
||||
def _extract_streaming_metrics(self, message_data) -> dict:
|
||||
if not message_data.message_metadata:
|
||||
return {}
|
||||
@@ -937,13 +1421,17 @@ class TraceQueueManager:
|
||||
self.user_id = user_id
|
||||
self.trace_instance = OpsTraceManager.get_ops_trace_instance(app_id)
|
||||
self.flask_app = current_app._get_current_object() # type: ignore
|
||||
|
||||
from core.telemetry.gateway import is_enterprise_telemetry_enabled
|
||||
|
||||
self._enterprise_telemetry_enabled = is_enterprise_telemetry_enabled()
|
||||
if trace_manager_timer is None:
|
||||
self.start_timer()
|
||||
|
||||
def add_trace_task(self, trace_task: TraceTask):
|
||||
global trace_manager_timer, trace_manager_queue
|
||||
try:
|
||||
if self.trace_instance:
|
||||
if self._enterprise_telemetry_enabled or self.trace_instance:
|
||||
trace_task.app_id = self.app_id
|
||||
trace_manager_queue.put(trace_task)
|
||||
except Exception:
|
||||
@@ -979,20 +1467,27 @@ class TraceQueueManager:
|
||||
def send_to_celery(self, tasks: list[TraceTask]):
|
||||
with self.flask_app.app_context():
|
||||
for task in tasks:
|
||||
if task.app_id is None:
|
||||
continue
|
||||
storage_id = task.app_id
|
||||
if storage_id is None:
|
||||
tenant_id = task.kwargs.get("tenant_id")
|
||||
if tenant_id:
|
||||
storage_id = f"tenant-{tenant_id}"
|
||||
else:
|
||||
logger.warning("Skipping trace without app_id or tenant_id, trace_type: %s", task.trace_type)
|
||||
continue
|
||||
|
||||
file_id = uuid4().hex
|
||||
trace_info = task.execute()
|
||||
|
||||
task_data = TaskData(
|
||||
app_id=task.app_id,
|
||||
app_id=storage_id,
|
||||
trace_info_type=type(trace_info).__name__,
|
||||
trace_info=trace_info.model_dump() if trace_info else None,
|
||||
)
|
||||
file_path = f"{OPS_FILE_PATH}{task.app_id}/{file_id}.json"
|
||||
file_path = f"{OPS_FILE_PATH}{storage_id}/{file_id}.json"
|
||||
storage.save(file_path, task_data.model_dump_json().encode("utf-8"))
|
||||
file_info = {
|
||||
"file_id": file_id,
|
||||
"app_id": task.app_id,
|
||||
"app_id": storage_id,
|
||||
}
|
||||
process_trace_tasks.delay(file_info) # type: ignore
|
||||
|
||||
@@ -5,7 +5,7 @@ from urllib.parse import urlparse
|
||||
|
||||
from sqlalchemy import select
|
||||
|
||||
from models.engine import db
|
||||
from extensions.ext_database import db
|
||||
from models.model import Message
|
||||
|
||||
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import uuid
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Union
|
||||
|
||||
@@ -12,7 +11,6 @@ from core.app.apps.chat.app_generator import ChatAppGenerator
|
||||
from core.app.apps.completion.app_generator import CompletionAppGenerator
|
||||
from core.app.apps.workflow.app_generator import WorkflowAppGenerator
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom
|
||||
from core.app.layers.pause_state_persist_layer import PauseStateLayerConfig
|
||||
from core.plugin.backwards_invocation.base import BaseBackwardsInvocation
|
||||
from extensions.ext_database import db
|
||||
from models import Account
|
||||
@@ -103,11 +101,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return AdvancedChatAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -119,9 +112,7 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
"conversation_id": conversation_id,
|
||||
},
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
workflow_run_id=str(uuid.uuid4()),
|
||||
streaming=stream,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
elif app.mode == AppMode.AGENT_CHAT:
|
||||
return AgentChatAppGenerator().generate(
|
||||
@@ -168,11 +159,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
if not workflow:
|
||||
raise ValueError("unexpected app type")
|
||||
|
||||
pause_config = PauseStateLayerConfig(
|
||||
session_factory=db.engine,
|
||||
state_owner_user_id=workflow.created_by,
|
||||
)
|
||||
|
||||
return WorkflowAppGenerator().generate(
|
||||
app_model=app,
|
||||
workflow=workflow,
|
||||
@@ -181,7 +167,6 @@ class PluginAppBackwardsInvocation(BaseBackwardsInvocation):
|
||||
invoke_from=InvokeFrom.SERVICE_API,
|
||||
streaming=stream,
|
||||
call_depth=1,
|
||||
pause_state_config=pause_config,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
|
||||
@@ -1,4 +1,4 @@
|
||||
from pydantic import BaseModel, Field, computed_field, model_validator
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
|
||||
from core.model_runtime.entities.provider_entities import ProviderEntity
|
||||
from core.plugin.entities.endpoint import EndpointProviderDeclaration
|
||||
@@ -48,15 +48,3 @@ class MarketplacePluginDeclaration(BaseModel):
|
||||
if "tool" in data and not data["tool"]:
|
||||
del data["tool"]
|
||||
return data
|
||||
|
||||
|
||||
class MarketplacePluginSnapshot(BaseModel):
|
||||
org: str
|
||||
name: str
|
||||
latest_version: str
|
||||
latest_package_identifier: str
|
||||
latest_package_url: str
|
||||
|
||||
@computed_field
|
||||
def plugin_id(self) -> str:
|
||||
return f"{self.org}/{self.name}"
|
||||
|
||||
@@ -1,15 +1,13 @@
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
from collections import Counter, defaultdict
|
||||
from collections.abc import Generator, Mapping
|
||||
from typing import Any, Union, cast
|
||||
|
||||
from flask import Flask, current_app
|
||||
from sqlalchemy import and_, func, literal, or_, select
|
||||
from sqlalchemy import and_, literal, or_, select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from core.app.app_config.entities import (
|
||||
@@ -20,7 +18,6 @@ from core.app.app_config.entities import (
|
||||
)
|
||||
from core.app.entities.app_invoke_entities import InvokeFrom, ModelConfigWithCredentialsEntity
|
||||
from core.callback_handler.index_tool_callback_handler import DatasetIndexToolCallbackHandler
|
||||
from core.db.session_factory import session_factory
|
||||
from core.entities.agent_entities import PlanningStrategy
|
||||
from core.entities.model_entities import ModelStatus
|
||||
from core.file import File, FileTransferMethod, FileType
|
||||
@@ -30,8 +27,7 @@ from core.model_runtime.entities.llm_entities import LLMResult, LLMUsage
|
||||
from core.model_runtime.entities.message_entities import PromptMessage, PromptMessageRole, PromptMessageTool
|
||||
from core.model_runtime.entities.model_entities import ModelFeature, ModelType
|
||||
from core.model_runtime.model_providers.__base.large_language_model import LargeLanguageModel
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.ops.ops_trace_manager import TraceQueueManager, TraceTask
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
from core.ops.utils import measure_time
|
||||
from core.prompt.advanced_prompt_transform import AdvancedPromptTransform
|
||||
from core.prompt.entities.advanced_prompt_entities import ChatModelMessage, CompletionModelPromptTemplate
|
||||
@@ -59,32 +55,16 @@ from core.rag.retrieval.template_prompts import (
|
||||
METADATA_FILTER_USER_PROMPT_2,
|
||||
METADATA_FILTER_USER_PROMPT_3,
|
||||
)
|
||||
from core.telemetry import TelemetryContext, TelemetryEvent, TraceTaskName
|
||||
from core.telemetry import emit as telemetry_emit
|
||||
from core.tools.signature import sign_upload_file
|
||||
from core.tools.utils.dataset_retriever.dataset_retriever_base_tool import DatasetRetrieverBaseTool
|
||||
from core.workflow.nodes.knowledge_retrieval import exc
|
||||
from core.workflow.repositories.rag_retrieval_protocol import (
|
||||
KnowledgeRetrievalRequest,
|
||||
Source,
|
||||
SourceChildChunk,
|
||||
SourceMetadata,
|
||||
)
|
||||
from extensions.ext_database import db
|
||||
from extensions.ext_redis import redis_client
|
||||
from libs.json_in_md_parser import parse_and_check_json_markdown
|
||||
from models import UploadFile
|
||||
from models.dataset import (
|
||||
ChildChunk,
|
||||
Dataset,
|
||||
DatasetMetadata,
|
||||
DatasetQuery,
|
||||
DocumentSegment,
|
||||
RateLimitLog,
|
||||
SegmentAttachmentBinding,
|
||||
)
|
||||
from models.dataset import ChildChunk, Dataset, DatasetMetadata, DatasetQuery, DocumentSegment, SegmentAttachmentBinding
|
||||
from models.dataset import Document as DatasetDocument
|
||||
from models.dataset import Document as DocumentModel
|
||||
from services.external_knowledge_service import ExternalDatasetService
|
||||
from services.feature_service import FeatureService
|
||||
|
||||
default_retrieval_model: dict[str, Any] = {
|
||||
"search_method": RetrievalMethod.SEMANTIC_SEARCH,
|
||||
@@ -94,8 +74,6 @@ default_retrieval_model: dict[str, Any] = {
|
||||
"score_threshold_enabled": False,
|
||||
}
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class DatasetRetrieval:
|
||||
def __init__(self, application_generate_entity=None):
|
||||
@@ -114,233 +92,6 @@ class DatasetRetrieval:
|
||||
else:
|
||||
self._llm_usage = self._llm_usage.plus(usage)
|
||||
|
||||
def knowledge_retrieval(self, request: KnowledgeRetrievalRequest) -> list[Source]:
|
||||
self._check_knowledge_rate_limit(request.tenant_id)
|
||||
available_datasets = self._get_available_datasets(request.tenant_id, request.dataset_ids)
|
||||
available_datasets_ids = [i.id for i in available_datasets]
|
||||
if not available_datasets_ids:
|
||||
return []
|
||||
|
||||
if not request.query:
|
||||
return []
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = None, None
|
||||
|
||||
if request.metadata_filtering_mode != "disabled":
|
||||
# Convert workflow layer types to app_config layer types
|
||||
if not request.metadata_model_config:
|
||||
raise ValueError("metadata_model_config is required for this method")
|
||||
|
||||
app_metadata_model_config = ModelConfig.model_validate(request.metadata_model_config.model_dump())
|
||||
|
||||
app_metadata_filtering_conditions = None
|
||||
if request.metadata_filtering_conditions is not None:
|
||||
app_metadata_filtering_conditions = MetadataFilteringCondition.model_validate(
|
||||
request.metadata_filtering_conditions.model_dump()
|
||||
)
|
||||
|
||||
query = request.query if request.query is not None else ""
|
||||
|
||||
metadata_filter_document_ids, metadata_condition = self.get_metadata_filter_condition(
|
||||
dataset_ids=available_datasets_ids,
|
||||
query=query,
|
||||
tenant_id=request.tenant_id,
|
||||
user_id=request.user_id,
|
||||
metadata_filtering_mode=request.metadata_filtering_mode,
|
||||
metadata_model_config=app_metadata_model_config,
|
||||
metadata_filtering_conditions=app_metadata_filtering_conditions,
|
||||
inputs={},
|
||||
)
|
||||
|
||||
if request.retrieval_mode == DatasetRetrieveConfigEntity.RetrieveStrategy.SINGLE:
|
||||
planning_strategy = PlanningStrategy.REACT_ROUTER
|
||||
# Ensure required fields are not None for single retrieval mode
|
||||
if request.model_provider is None or request.model_name is None or request.query is None:
|
||||
raise ValueError("model_provider, model_name, and query are required for single retrieval mode")
|
||||
|
||||
model_manager = ModelManager()
|
||||
model_instance = model_manager.get_model_instance(
|
||||
tenant_id=request.tenant_id,
|
||||
model_type=ModelType.LLM,
|
||||
provider=request.model_provider,
|
||||
model=request.model_name,
|
||||
)
|
||||
|
||||
provider_model_bundle = model_instance.provider_model_bundle
|
||||
model_type_instance = model_instance.model_type_instance
|
||||
model_type_instance = cast(LargeLanguageModel, model_type_instance)
|
||||
|
||||
model_credentials = model_instance.credentials
|
||||
|
||||
# check model
|
||||
provider_model = provider_model_bundle.configuration.get_provider_model(
|
||||
model=request.model_name, model_type=ModelType.LLM
|
||||
)
|
||||
|
||||
if provider_model is None:
|
||||
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
|
||||
|
||||
if provider_model.status == ModelStatus.NO_CONFIGURE:
|
||||
raise exc.ModelCredentialsNotInitializedError(
|
||||
f"Model {request.model_name} credentials is not initialized."
|
||||
)
|
||||
elif provider_model.status == ModelStatus.NO_PERMISSION:
|
||||
raise exc.ModelNotSupportedError(f"Dify Hosted OpenAI {request.model_name} currently not support.")
|
||||
elif provider_model.status == ModelStatus.QUOTA_EXCEEDED:
|
||||
raise exc.ModelQuotaExceededError(f"Model provider {request.model_provider} quota exceeded.")
|
||||
|
||||
stop = []
|
||||
completion_params = (request.completion_params or {}).copy()
|
||||
if "stop" in completion_params:
|
||||
stop = completion_params["stop"]
|
||||
del completion_params["stop"]
|
||||
|
||||
model_schema = model_type_instance.get_model_schema(request.model_name, model_credentials)
|
||||
|
||||
if not model_schema:
|
||||
raise exc.ModelNotExistError(f"Model {request.model_name} not exist.")
|
||||
|
||||
model_config = ModelConfigWithCredentialsEntity(
|
||||
provider=request.model_provider,
|
||||
model=request.model_name,
|
||||
model_schema=model_schema,
|
||||
mode=request.model_mode or "chat",
|
||||
provider_model_bundle=provider_model_bundle,
|
||||
credentials=model_credentials,
|
||||
parameters=completion_params,
|
||||
stop=stop,
|
||||
)
|
||||
all_documents = self.single_retrieve(
|
||||
request.app_id,
|
||||
request.tenant_id,
|
||||
request.user_id,
|
||||
request.user_from,
|
||||
request.query,
|
||||
available_datasets,
|
||||
model_instance,
|
||||
model_config,
|
||||
planning_strategy,
|
||||
None, # message_id
|
||||
metadata_filter_document_ids,
|
||||
metadata_condition,
|
||||
)
|
||||
else:
|
||||
all_documents = self.multiple_retrieve(
|
||||
app_id=request.app_id,
|
||||
tenant_id=request.tenant_id,
|
||||
user_id=request.user_id,
|
||||
user_from=request.user_from,
|
||||
available_datasets=available_datasets,
|
||||
query=request.query,
|
||||
top_k=request.top_k,
|
||||
score_threshold=request.score_threshold,
|
||||
reranking_mode=request.reranking_mode,
|
||||
reranking_model=request.reranking_model,
|
||||
weights=request.weights,
|
||||
reranking_enable=request.reranking_enable,
|
||||
metadata_filter_document_ids=metadata_filter_document_ids,
|
||||
metadata_condition=metadata_condition,
|
||||
attachment_ids=request.attachment_ids,
|
||||
)
|
||||
|
||||
dify_documents = [item for item in all_documents if item.provider == "dify"]
|
||||
external_documents = [item for item in all_documents if item.provider == "external"]
|
||||
retrieval_resource_list = []
|
||||
# deal with external documents
|
||||
for item in external_documents:
|
||||
source = Source(
|
||||
metadata=SourceMetadata(
|
||||
source="knowledge",
|
||||
dataset_id=item.metadata.get("dataset_id"),
|
||||
dataset_name=item.metadata.get("dataset_name"),
|
||||
document_id=item.metadata.get("document_id"),
|
||||
document_name=item.metadata.get("title"),
|
||||
data_source_type="external",
|
||||
retriever_from="workflow",
|
||||
score=item.metadata.get("score"),
|
||||
doc_metadata=item.metadata,
|
||||
),
|
||||
title=item.metadata.get("title"),
|
||||
content=item.page_content,
|
||||
)
|
||||
retrieval_resource_list.append(source)
|
||||
# deal with dify documents
|
||||
if dify_documents:
|
||||
records = RetrievalService.format_retrieval_documents(dify_documents)
|
||||
dataset_ids = [i.segment.dataset_id for i in records]
|
||||
document_ids = [i.segment.document_id for i in records]
|
||||
|
||||
with session_factory.create_session() as session:
|
||||
datasets = session.query(Dataset).where(Dataset.id.in_(dataset_ids)).all()
|
||||
documents = session.query(DatasetDocument).where(DatasetDocument.id.in_(document_ids)).all()
|
||||
|
||||
dataset_map = {i.id: i for i in datasets}
|
||||
document_map = {i.id: i for i in documents}
|
||||
|
||||
if records:
|
||||
for record in records:
|
||||
segment = record.segment
|
||||
dataset = dataset_map.get(segment.dataset_id)
|
||||
document = document_map.get(segment.document_id)
|
||||
|
||||
if dataset and document:
|
||||
source = Source(
|
||||
metadata=SourceMetadata(
|
||||
source="knowledge",
|
||||
dataset_id=dataset.id,
|
||||
dataset_name=dataset.name,
|
||||
document_id=document.id,
|
||||
document_name=document.name,
|
||||
data_source_type=document.data_source_type,
|
||||
segment_id=segment.id,
|
||||
retriever_from="workflow",
|
||||
score=record.score or 0.0,
|
||||
segment_hit_count=segment.hit_count,
|
||||
segment_word_count=segment.word_count,
|
||||
segment_position=segment.position,
|
||||
segment_index_node_hash=segment.index_node_hash,
|
||||
doc_metadata=document.doc_metadata,
|
||||
child_chunks=[
|
||||
SourceChildChunk(
|
||||
id=str(getattr(chunk, "id", "")),
|
||||
content=str(getattr(chunk, "content", "")),
|
||||
position=int(getattr(chunk, "position", 0)),
|
||||
score=float(getattr(chunk, "score", 0.0)),
|
||||
)
|
||||
for chunk in (record.child_chunks or [])
|
||||
],
|
||||
position=None,
|
||||
),
|
||||
title=document.name,
|
||||
files=list(record.files) if record.files else None,
|
||||
content=segment.get_sign_content(),
|
||||
)
|
||||
if segment.answer:
|
||||
source.content = f"question:{segment.get_sign_content()} \nanswer:{segment.answer}"
|
||||
|
||||
if record.summary:
|
||||
source.summary = record.summary
|
||||
|
||||
retrieval_resource_list.append(source)
|
||||
|
||||
if retrieval_resource_list:
|
||||
|
||||
def _score(item: Source) -> float:
|
||||
meta = item.metadata
|
||||
score = meta.score
|
||||
if isinstance(score, (int, float)):
|
||||
return float(score)
|
||||
return 0.0
|
||||
|
||||
retrieval_resource_list = sorted(
|
||||
retrieval_resource_list,
|
||||
key=_score, # type: ignore[arg-type, return-value]
|
||||
reverse=True,
|
||||
)
|
||||
for position, item in enumerate(retrieval_resource_list, start=1):
|
||||
item.metadata.position = position # type: ignore[index]
|
||||
return retrieval_resource_list
|
||||
|
||||
def retrieve(
|
||||
self,
|
||||
app_id: str,
|
||||
@@ -400,7 +151,14 @@ class DatasetRetrieval:
|
||||
if features:
|
||||
if ModelFeature.TOOL_CALL in features or ModelFeature.MULTI_TOOL_CALL in features:
|
||||
planning_strategy = PlanningStrategy.ROUTER
|
||||
available_datasets = self._get_available_datasets(tenant_id, dataset_ids)
|
||||
available_datasets = []
|
||||
|
||||
dataset_stmt = select(Dataset).where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
datasets: list[Dataset] = db.session.execute(dataset_stmt).scalars().all() # type: ignore
|
||||
for dataset in datasets:
|
||||
if dataset.available_document_count == 0 and dataset.provider != "external":
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
|
||||
if inputs:
|
||||
inputs = {key: str(value) for key, value in inputs.items()}
|
||||
@@ -971,10 +729,21 @@ class DatasetRetrieval:
|
||||
self.application_generate_entity.trace_manager if self.application_generate_entity else None
|
||||
)
|
||||
if trace_manager:
|
||||
trace_manager.add_trace_task(
|
||||
TraceTask(
|
||||
TraceTaskName.DATASET_RETRIEVAL_TRACE, message_id=message_id, documents=documents, timer=timer
|
||||
)
|
||||
app_config = self.application_generate_entity.app_config if self.application_generate_entity else None
|
||||
telemetry_emit(
|
||||
TelemetryEvent(
|
||||
name=TraceTaskName.DATASET_RETRIEVAL_TRACE,
|
||||
context=TelemetryContext(
|
||||
tenant_id=app_config.tenant_id if app_config else None,
|
||||
app_id=app_config.app_id if app_config else None,
|
||||
),
|
||||
payload={
|
||||
"message_id": message_id,
|
||||
"documents": documents,
|
||||
"timer": timer,
|
||||
},
|
||||
),
|
||||
trace_manager=trace_manager,
|
||||
)
|
||||
|
||||
def _on_query(
|
||||
@@ -1404,6 +1173,7 @@ class DatasetRetrieval:
|
||||
query=query or "",
|
||||
)
|
||||
|
||||
result_text = ""
|
||||
try:
|
||||
# handle invoke result
|
||||
invoke_result = cast(
|
||||
@@ -1434,8 +1204,7 @@ class DatasetRetrieval:
|
||||
"condition": item.get("comparison_operator"),
|
||||
}
|
||||
)
|
||||
except Exception as e:
|
||||
logger.warning(e, exc_info=True)
|
||||
except Exception:
|
||||
return None
|
||||
return automatic_metadata_filters
|
||||
|
||||
@@ -1649,12 +1418,7 @@ class DatasetRetrieval:
|
||||
usage = None
|
||||
for result in invoke_result:
|
||||
text = result.delta.message.content
|
||||
if isinstance(text, str):
|
||||
full_text += text
|
||||
elif isinstance(text, list):
|
||||
for i in text:
|
||||
if i.data:
|
||||
full_text += i.data
|
||||
full_text += text
|
||||
|
||||
if not model:
|
||||
model = result.model
|
||||
@@ -1772,53 +1536,3 @@ class DatasetRetrieval:
|
||||
cancel_event.set()
|
||||
if thread_exceptions is not None:
|
||||
thread_exceptions.append(e)
|
||||
|
||||
def _get_available_datasets(self, tenant_id: str, dataset_ids: list[str]) -> list[Dataset]:
|
||||
with session_factory.create_session() as session:
|
||||
subquery = (
|
||||
session.query(DocumentModel.dataset_id, func.count(DocumentModel.id).label("available_document_count"))
|
||||
.where(
|
||||
DocumentModel.indexing_status == "completed",
|
||||
DocumentModel.enabled == True,
|
||||
DocumentModel.archived == False,
|
||||
DocumentModel.dataset_id.in_(dataset_ids),
|
||||
)
|
||||
.group_by(DocumentModel.dataset_id)
|
||||
.having(func.count(DocumentModel.id) > 0)
|
||||
.subquery()
|
||||
)
|
||||
|
||||
results = (
|
||||
session.query(Dataset)
|
||||
.outerjoin(subquery, Dataset.id == subquery.c.dataset_id)
|
||||
.where(Dataset.tenant_id == tenant_id, Dataset.id.in_(dataset_ids))
|
||||
.where((subquery.c.available_document_count > 0) | (Dataset.provider == "external"))
|
||||
.all()
|
||||
)
|
||||
|
||||
available_datasets = []
|
||||
for dataset in results:
|
||||
if not dataset:
|
||||
continue
|
||||
available_datasets.append(dataset)
|
||||
return available_datasets
|
||||
|
||||
def _check_knowledge_rate_limit(self, tenant_id: str):
|
||||
knowledge_rate_limit = FeatureService.get_knowledge_rate_limit(tenant_id)
|
||||
if knowledge_rate_limit.enabled:
|
||||
current_time = int(time.time() * 1000)
|
||||
key = f"rate_limit_{tenant_id}"
|
||||
redis_client.zadd(key, {current_time: current_time})
|
||||
redis_client.zremrangebyscore(key, 0, current_time - 60000)
|
||||
request_count = redis_client.zcard(key)
|
||||
if request_count > knowledge_rate_limit.limit:
|
||||
with session_factory.create_session() as session:
|
||||
rate_limit_log = RateLimitLog(
|
||||
tenant_id=tenant_id,
|
||||
subscription_plan=knowledge_rate_limit.subscription_plan,
|
||||
operation="knowledge",
|
||||
)
|
||||
session.add(rate_limit_log)
|
||||
raise exc.RateLimitExceededError(
|
||||
"you have reached the knowledge base request rate limit of your subscription."
|
||||
)
|
||||
|
||||
@@ -1,18 +1,19 @@
|
||||
"""Repository implementations for data access."""
|
||||
"""
|
||||
Repository implementations for data access.
|
||||
|
||||
from __future__ import annotations
|
||||
This package contains concrete implementations of the repository interfaces
|
||||
defined in the core.workflow.repository package.
|
||||
"""
|
||||
|
||||
from .celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from .celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from .factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from .sqlalchemy_workflow_execution_repository import SQLAlchemyWorkflowExecutionRepository
|
||||
from .sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
from core.repositories.celery_workflow_execution_repository import CeleryWorkflowExecutionRepository
|
||||
from core.repositories.celery_workflow_node_execution_repository import CeleryWorkflowNodeExecutionRepository
|
||||
from core.repositories.factory import DifyCoreRepositoryFactory, RepositoryImportError
|
||||
from core.repositories.sqlalchemy_workflow_node_execution_repository import SQLAlchemyWorkflowNodeExecutionRepository
|
||||
|
||||
__all__ = [
|
||||
"CeleryWorkflowExecutionRepository",
|
||||
"CeleryWorkflowNodeExecutionRepository",
|
||||
"DifyCoreRepositoryFactory",
|
||||
"RepositoryImportError",
|
||||
"SQLAlchemyWorkflowExecutionRepository",
|
||||
"SQLAlchemyWorkflowNodeExecutionRepository",
|
||||
]
|
||||
|
||||
@@ -1,553 +0,0 @@
|
||||
import dataclasses
|
||||
import json
|
||||
from collections.abc import Mapping, Sequence
|
||||
from datetime import datetime
|
||||
from typing import Any
|
||||
|
||||
from sqlalchemy import Engine, select
|
||||
from sqlalchemy.orm import Session, selectinload, sessionmaker
|
||||
|
||||
from core.workflow.nodes.human_input.entities import (
|
||||
DeliveryChannelConfig,
|
||||
EmailDeliveryMethod,
|
||||
EmailRecipients,
|
||||
ExternalRecipient,
|
||||
FormDefinition,
|
||||
HumanInputNodeData,
|
||||
MemberRecipient,
|
||||
WebAppDeliveryMethod,
|
||||
)
|
||||
from core.workflow.nodes.human_input.enums import (
|
||||
DeliveryMethodType,
|
||||
HumanInputFormKind,
|
||||
HumanInputFormStatus,
|
||||
)
|
||||
from core.workflow.repositories.human_input_form_repository import (
|
||||
FormCreateParams,
|
||||
FormNotFoundError,
|
||||
HumanInputFormEntity,
|
||||
HumanInputFormRecipientEntity,
|
||||
)
|
||||
from libs.datetime_utils import naive_utc_now
|
||||
from libs.uuid_utils import uuidv7
|
||||
from models.account import Account, TenantAccountJoin
|
||||
from models.human_input import (
|
||||
BackstageRecipientPayload,
|
||||
ConsoleDeliveryPayload,
|
||||
ConsoleRecipientPayload,
|
||||
EmailExternalRecipientPayload,
|
||||
EmailMemberRecipientPayload,
|
||||
HumanInputDelivery,
|
||||
HumanInputForm,
|
||||
HumanInputFormRecipient,
|
||||
RecipientType,
|
||||
StandaloneWebAppRecipientPayload,
|
||||
)
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _DeliveryAndRecipients:
|
||||
delivery: HumanInputDelivery
|
||||
recipients: Sequence[HumanInputFormRecipient]
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class _WorkspaceMemberInfo:
|
||||
user_id: str
|
||||
email: str
|
||||
|
||||
|
||||
class _HumanInputFormRecipientEntityImpl(HumanInputFormRecipientEntity):
|
||||
def __init__(self, recipient_model: HumanInputFormRecipient):
|
||||
self._recipient_model = recipient_model
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._recipient_model.id
|
||||
|
||||
@property
|
||||
def token(self) -> str:
|
||||
if self._recipient_model.access_token is None:
|
||||
raise AssertionError(f"access_token should not be None for recipient {self._recipient_model.id}")
|
||||
return self._recipient_model.access_token
|
||||
|
||||
|
||||
class _HumanInputFormEntityImpl(HumanInputFormEntity):
|
||||
def __init__(self, form_model: HumanInputForm, recipient_models: Sequence[HumanInputFormRecipient]):
|
||||
self._form_model = form_model
|
||||
self._recipients = [_HumanInputFormRecipientEntityImpl(recipient) for recipient in recipient_models]
|
||||
self._web_app_recipient = next(
|
||||
(
|
||||
recipient
|
||||
for recipient in recipient_models
|
||||
if recipient.recipient_type == RecipientType.STANDALONE_WEB_APP
|
||||
),
|
||||
None,
|
||||
)
|
||||
self._console_recipient = next(
|
||||
(recipient for recipient in recipient_models if recipient.recipient_type == RecipientType.CONSOLE),
|
||||
None,
|
||||
)
|
||||
self._submitted_data: Mapping[str, Any] | None = (
|
||||
json.loads(form_model.submitted_data) if form_model.submitted_data is not None else None
|
||||
)
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._form_model.id
|
||||
|
||||
@property
|
||||
def web_app_token(self):
|
||||
if self._console_recipient is not None:
|
||||
return self._console_recipient.access_token
|
||||
if self._web_app_recipient is None:
|
||||
return None
|
||||
return self._web_app_recipient.access_token
|
||||
|
||||
@property
|
||||
def recipients(self) -> list[HumanInputFormRecipientEntity]:
|
||||
return list(self._recipients)
|
||||
|
||||
@property
|
||||
def rendered_content(self) -> str:
|
||||
return self._form_model.rendered_content
|
||||
|
||||
@property
|
||||
def selected_action_id(self) -> str | None:
|
||||
return self._form_model.selected_action_id
|
||||
|
||||
@property
|
||||
def submitted_data(self) -> Mapping[str, Any] | None:
|
||||
return self._submitted_data
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self._form_model.submitted_at is not None
|
||||
|
||||
@property
|
||||
def status(self) -> HumanInputFormStatus:
|
||||
return self._form_model.status
|
||||
|
||||
@property
|
||||
def expiration_time(self) -> datetime:
|
||||
return self._form_model.expiration_time
|
||||
|
||||
|
||||
@dataclasses.dataclass(frozen=True)
|
||||
class HumanInputFormRecord:
|
||||
form_id: str
|
||||
workflow_run_id: str | None
|
||||
node_id: str
|
||||
tenant_id: str
|
||||
app_id: str
|
||||
form_kind: HumanInputFormKind
|
||||
definition: FormDefinition
|
||||
rendered_content: str
|
||||
created_at: datetime
|
||||
expiration_time: datetime
|
||||
status: HumanInputFormStatus
|
||||
selected_action_id: str | None
|
||||
submitted_data: Mapping[str, Any] | None
|
||||
submitted_at: datetime | None
|
||||
submission_user_id: str | None
|
||||
submission_end_user_id: str | None
|
||||
completed_by_recipient_id: str | None
|
||||
recipient_id: str | None
|
||||
recipient_type: RecipientType | None
|
||||
access_token: str | None
|
||||
|
||||
@property
|
||||
def submitted(self) -> bool:
|
||||
return self.submitted_at is not None
|
||||
|
||||
@classmethod
|
||||
def from_models(
|
||||
cls, form_model: HumanInputForm, recipient_model: HumanInputFormRecipient | None
|
||||
) -> "HumanInputFormRecord":
|
||||
definition_payload = json.loads(form_model.form_definition)
|
||||
if "expiration_time" not in definition_payload:
|
||||
definition_payload["expiration_time"] = form_model.expiration_time
|
||||
return cls(
|
||||
form_id=form_model.id,
|
||||
workflow_run_id=form_model.workflow_run_id,
|
||||
node_id=form_model.node_id,
|
||||
tenant_id=form_model.tenant_id,
|
||||
app_id=form_model.app_id,
|
||||
form_kind=form_model.form_kind,
|
||||
definition=FormDefinition.model_validate(definition_payload),
|
||||
rendered_content=form_model.rendered_content,
|
||||
created_at=form_model.created_at,
|
||||
expiration_time=form_model.expiration_time,
|
||||
status=form_model.status,
|
||||
selected_action_id=form_model.selected_action_id,
|
||||
submitted_data=json.loads(form_model.submitted_data) if form_model.submitted_data else None,
|
||||
submitted_at=form_model.submitted_at,
|
||||
submission_user_id=form_model.submission_user_id,
|
||||
submission_end_user_id=form_model.submission_end_user_id,
|
||||
completed_by_recipient_id=form_model.completed_by_recipient_id,
|
||||
recipient_id=recipient_model.id if recipient_model else None,
|
||||
recipient_type=recipient_model.recipient_type if recipient_model else None,
|
||||
access_token=recipient_model.access_token if recipient_model else None,
|
||||
)
|
||||
|
||||
|
||||
class _InvalidTimeoutStatusError(ValueError):
|
||||
pass
|
||||
|
||||
|
||||
class HumanInputFormRepositoryImpl:
|
||||
def __init__(
|
||||
self,
|
||||
session_factory: sessionmaker | Engine,
|
||||
tenant_id: str,
|
||||
):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
self._tenant_id = tenant_id
|
||||
|
||||
def _delivery_method_to_model(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_method: DeliveryChannelConfig,
|
||||
) -> _DeliveryAndRecipients:
|
||||
delivery_id = str(uuidv7())
|
||||
delivery_model = HumanInputDelivery(
|
||||
id=delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=delivery_method.type,
|
||||
delivery_config_id=delivery_method.id,
|
||||
channel_payload=delivery_method.model_dump_json(),
|
||||
)
|
||||
recipients: list[HumanInputFormRecipient] = []
|
||||
if isinstance(delivery_method, WebAppDeliveryMethod):
|
||||
recipient_model = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipient_type=RecipientType.STANDALONE_WEB_APP,
|
||||
recipient_payload=StandaloneWebAppRecipientPayload().model_dump_json(),
|
||||
)
|
||||
recipients.append(recipient_model)
|
||||
elif isinstance(delivery_method, EmailDeliveryMethod):
|
||||
email_recipients_config = delivery_method.config.recipients
|
||||
recipients.extend(
|
||||
self._build_email_recipients(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
recipients_config=email_recipients_config,
|
||||
)
|
||||
)
|
||||
|
||||
return _DeliveryAndRecipients(delivery=delivery_model, recipients=recipients)
|
||||
|
||||
def _build_email_recipients(
|
||||
self,
|
||||
session: Session,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
recipients_config: EmailRecipients,
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
member_user_ids = [
|
||||
recipient.user_id for recipient in recipients_config.items if isinstance(recipient, MemberRecipient)
|
||||
]
|
||||
external_emails = [
|
||||
recipient.email for recipient in recipients_config.items if isinstance(recipient, ExternalRecipient)
|
||||
]
|
||||
if recipients_config.whole_workspace:
|
||||
members = self._query_all_workspace_members(session=session)
|
||||
else:
|
||||
members = self._query_workspace_members_by_ids(session=session, restrict_to_user_ids=member_user_ids)
|
||||
|
||||
return self._create_email_recipients_from_resolved(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
members=members,
|
||||
external_emails=external_emails,
|
||||
)
|
||||
|
||||
@staticmethod
|
||||
def _create_email_recipients_from_resolved(
|
||||
*,
|
||||
form_id: str,
|
||||
delivery_id: str,
|
||||
members: Sequence[_WorkspaceMemberInfo],
|
||||
external_emails: Sequence[str],
|
||||
) -> list[HumanInputFormRecipient]:
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
seen_emails: set[str] = set()
|
||||
|
||||
for member in members:
|
||||
if not member.email:
|
||||
continue
|
||||
if member.email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(member.email)
|
||||
payload = EmailMemberRecipientPayload(user_id=member.user_id, email=member.email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=payload,
|
||||
)
|
||||
)
|
||||
|
||||
for email in external_emails:
|
||||
if not email:
|
||||
continue
|
||||
if email in seen_emails:
|
||||
continue
|
||||
seen_emails.add(email)
|
||||
recipient_models.append(
|
||||
HumanInputFormRecipient.new(
|
||||
form_id=form_id,
|
||||
delivery_id=delivery_id,
|
||||
payload=EmailExternalRecipientPayload(email=email),
|
||||
)
|
||||
)
|
||||
|
||||
return recipient_models
|
||||
|
||||
def _query_all_workspace_members(
|
||||
self,
|
||||
session: Session,
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def _query_workspace_members_by_ids(
|
||||
self,
|
||||
session: Session,
|
||||
restrict_to_user_ids: Sequence[str],
|
||||
) -> list[_WorkspaceMemberInfo]:
|
||||
unique_ids = {user_id for user_id in restrict_to_user_ids if user_id}
|
||||
if not unique_ids:
|
||||
return []
|
||||
|
||||
stmt = (
|
||||
select(Account.id, Account.email)
|
||||
.join(TenantAccountJoin, TenantAccountJoin.account_id == Account.id)
|
||||
.where(TenantAccountJoin.tenant_id == self._tenant_id)
|
||||
)
|
||||
stmt = stmt.where(Account.id.in_(unique_ids))
|
||||
|
||||
rows = session.execute(stmt).all()
|
||||
return [_WorkspaceMemberInfo(user_id=account_id, email=email) for account_id, email in rows]
|
||||
|
||||
def create_form(self, params: FormCreateParams) -> HumanInputFormEntity:
|
||||
form_config: HumanInputNodeData = params.form_config
|
||||
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
# Generate unique form ID
|
||||
form_id = str(uuidv7())
|
||||
start_time = naive_utc_now()
|
||||
node_expiration = form_config.expiration_time(start_time)
|
||||
form_definition = FormDefinition(
|
||||
form_content=form_config.form_content,
|
||||
inputs=form_config.inputs,
|
||||
user_actions=form_config.user_actions,
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
default_values=dict(params.resolved_default_values),
|
||||
display_in_ui=params.display_in_ui,
|
||||
node_title=form_config.title,
|
||||
)
|
||||
form_model = HumanInputForm(
|
||||
id=form_id,
|
||||
tenant_id=self._tenant_id,
|
||||
app_id=params.app_id,
|
||||
workflow_run_id=params.workflow_execution_id,
|
||||
form_kind=params.form_kind,
|
||||
node_id=params.node_id,
|
||||
form_definition=form_definition.model_dump_json(),
|
||||
rendered_content=params.rendered_content,
|
||||
expiration_time=node_expiration,
|
||||
created_at=start_time,
|
||||
)
|
||||
session.add(form_model)
|
||||
recipient_models: list[HumanInputFormRecipient] = []
|
||||
for delivery in params.delivery_methods:
|
||||
delivery_and_recipients = self._delivery_method_to_model(
|
||||
session=session,
|
||||
form_id=form_id,
|
||||
delivery_method=delivery,
|
||||
)
|
||||
session.add(delivery_and_recipients.delivery)
|
||||
session.add_all(delivery_and_recipients.recipients)
|
||||
recipient_models.extend(delivery_and_recipients.recipients)
|
||||
if params.console_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.CONSOLE for recipient in recipient_models
|
||||
):
|
||||
console_delivery_id = str(uuidv7())
|
||||
console_delivery = HumanInputDelivery(
|
||||
id=console_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
console_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=console_delivery_id,
|
||||
recipient_type=RecipientType.CONSOLE,
|
||||
recipient_payload=ConsoleRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(console_delivery)
|
||||
session.add(console_recipient)
|
||||
recipient_models.append(console_recipient)
|
||||
if params.backstage_recipient_required and not any(
|
||||
recipient.recipient_type == RecipientType.BACKSTAGE for recipient in recipient_models
|
||||
):
|
||||
backstage_delivery_id = str(uuidv7())
|
||||
backstage_delivery = HumanInputDelivery(
|
||||
id=backstage_delivery_id,
|
||||
form_id=form_id,
|
||||
delivery_method_type=DeliveryMethodType.WEBAPP,
|
||||
delivery_config_id=None,
|
||||
channel_payload=ConsoleDeliveryPayload().model_dump_json(),
|
||||
)
|
||||
backstage_recipient = HumanInputFormRecipient(
|
||||
form_id=form_id,
|
||||
delivery_id=backstage_delivery_id,
|
||||
recipient_type=RecipientType.BACKSTAGE,
|
||||
recipient_payload=BackstageRecipientPayload(
|
||||
account_id=params.console_creator_account_id,
|
||||
).model_dump_json(),
|
||||
)
|
||||
session.add(backstage_delivery)
|
||||
session.add(backstage_recipient)
|
||||
recipient_models.append(backstage_recipient)
|
||||
session.flush()
|
||||
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
def get_form(self, workflow_execution_id: str, node_id: str) -> HumanInputFormEntity | None:
|
||||
form_query = select(HumanInputForm).where(
|
||||
HumanInputForm.workflow_run_id == workflow_execution_id,
|
||||
HumanInputForm.node_id == node_id,
|
||||
HumanInputForm.tenant_id == self._tenant_id,
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
form_model: HumanInputForm | None = session.scalars(form_query).first()
|
||||
if form_model is None:
|
||||
return None
|
||||
|
||||
recipient_query = select(HumanInputFormRecipient).where(HumanInputFormRecipient.form_id == form_model.id)
|
||||
recipient_models = session.scalars(recipient_query).all()
|
||||
return _HumanInputFormEntityImpl(form_model=form_model, recipient_models=recipient_models)
|
||||
|
||||
|
||||
class HumanInputFormSubmissionRepository:
|
||||
"""Repository for fetching and submitting human input forms."""
|
||||
|
||||
def __init__(self, session_factory: sessionmaker | Engine):
|
||||
if isinstance(session_factory, Engine):
|
||||
session_factory = sessionmaker(bind=session_factory)
|
||||
self._session_factory = session_factory
|
||||
|
||||
def get_by_token(self, form_token: str) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(HumanInputFormRecipient.access_token == form_token)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def get_by_form_id_and_recipient_type(
|
||||
self,
|
||||
form_id: str,
|
||||
recipient_type: RecipientType,
|
||||
) -> HumanInputFormRecord | None:
|
||||
query = (
|
||||
select(HumanInputFormRecipient)
|
||||
.options(selectinload(HumanInputFormRecipient.form))
|
||||
.where(
|
||||
HumanInputFormRecipient.form_id == form_id,
|
||||
HumanInputFormRecipient.recipient_type == recipient_type,
|
||||
)
|
||||
)
|
||||
with self._session_factory(expire_on_commit=False) as session:
|
||||
recipient_model = session.scalars(query).first()
|
||||
if recipient_model is None or recipient_model.form is None:
|
||||
return None
|
||||
return HumanInputFormRecord.from_models(recipient_model.form, recipient_model)
|
||||
|
||||
def mark_submitted(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
recipient_id: str | None,
|
||||
selected_action_id: str,
|
||||
form_data: Mapping[str, Any],
|
||||
submission_user_id: str | None,
|
||||
submission_end_user_id: str | None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
recipient_model = session.get(HumanInputFormRecipient, recipient_id) if recipient_id else None
|
||||
|
||||
form_model.selected_action_id = selected_action_id
|
||||
form_model.submitted_data = json.dumps(form_data)
|
||||
form_model.submitted_at = naive_utc_now()
|
||||
form_model.status = HumanInputFormStatus.SUBMITTED
|
||||
form_model.submission_user_id = submission_user_id
|
||||
form_model.submission_end_user_id = submission_end_user_id
|
||||
form_model.completed_by_recipient_id = recipient_id
|
||||
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
if recipient_model is not None:
|
||||
session.refresh(recipient_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, recipient_model)
|
||||
|
||||
def mark_timeout(
|
||||
self,
|
||||
*,
|
||||
form_id: str,
|
||||
timeout_status: HumanInputFormStatus,
|
||||
reason: str | None = None,
|
||||
) -> HumanInputFormRecord:
|
||||
with self._session_factory(expire_on_commit=False) as session, session.begin():
|
||||
form_model = session.get(HumanInputForm, form_id)
|
||||
if form_model is None:
|
||||
raise FormNotFoundError(f"form not found, id={form_id}")
|
||||
|
||||
if timeout_status not in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
raise _InvalidTimeoutStatusError(f"invalid timeout status: {timeout_status}")
|
||||
|
||||
# already handled or submitted
|
||||
if form_model.status in {HumanInputFormStatus.TIMEOUT, HumanInputFormStatus.EXPIRED}:
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
|
||||
if form_model.submitted_at is not None or form_model.status == HumanInputFormStatus.SUBMITTED:
|
||||
raise FormNotFoundError(f"form already submitted, id={form_id}")
|
||||
|
||||
form_model.status = timeout_status
|
||||
form_model.selected_action_id = None
|
||||
form_model.submitted_data = None
|
||||
form_model.submission_user_id = None
|
||||
form_model.submission_end_user_id = None
|
||||
form_model.completed_by_recipient_id = None
|
||||
# Reason is recorded in status/error downstream; not stored on form.
|
||||
session.add(form_model)
|
||||
session.flush()
|
||||
session.refresh(form_model)
|
||||
|
||||
return HumanInputFormRecord.from_models(form_model, None)
|
||||
@@ -488,7 +488,6 @@ class SQLAlchemyWorkflowNodeExecutionRepository(WorkflowNodeExecutionRepository)
|
||||
WorkflowNodeExecutionModel.workflow_run_id == workflow_run_id,
|
||||
WorkflowNodeExecutionModel.tenant_id == self._tenant_id,
|
||||
WorkflowNodeExecutionModel.triggered_from == triggered_from,
|
||||
WorkflowNodeExecutionModel.status != WorkflowNodeExecutionStatus.PAUSED,
|
||||
)
|
||||
|
||||
if self._app_id:
|
||||
|
||||
43
api/core/telemetry/__init__.py
Normal file
43
api/core/telemetry/__init__.py
Normal file
@@ -0,0 +1,43 @@
|
||||
"""Telemetry facade.
|
||||
|
||||
Thin public API for emitting telemetry events. All routing logic
|
||||
lives in ``core.telemetry.gateway`` which is shared by both CE and EE.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
from core.telemetry.events import TelemetryContext, TelemetryEvent
|
||||
from core.telemetry.gateway import emit as gateway_emit
|
||||
from core.telemetry.gateway import get_trace_task_to_case
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.ops_trace_manager import TraceQueueManager
|
||||
|
||||
|
||||
def emit(event: TelemetryEvent, trace_manager: TraceQueueManager | None = None) -> None:
|
||||
"""Emit a telemetry event.
|
||||
|
||||
Translates the ``TelemetryEvent`` (keyed by ``TraceTaskName``) into a
|
||||
``TelemetryCase`` and delegates to ``core.telemetry.gateway.emit()``.
|
||||
"""
|
||||
case = get_trace_task_to_case().get(event.name)
|
||||
if case is None:
|
||||
return
|
||||
|
||||
context: dict[str, object] = {
|
||||
"tenant_id": event.context.tenant_id,
|
||||
"user_id": event.context.user_id,
|
||||
"app_id": event.context.app_id,
|
||||
}
|
||||
gateway_emit(case, context, event.payload, trace_manager)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"TelemetryContext",
|
||||
"TelemetryEvent",
|
||||
"TraceTaskName",
|
||||
"emit",
|
||||
]
|
||||
21
api/core/telemetry/events.py
Normal file
21
api/core/telemetry/events.py
Normal file
@@ -0,0 +1,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from core.ops.entities.trace_entity import TraceTaskName
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelemetryContext:
|
||||
tenant_id: str | None = None
|
||||
user_id: str | None = None
|
||||
app_id: str | None = None
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
class TelemetryEvent:
|
||||
name: TraceTaskName
|
||||
context: TelemetryContext
|
||||
payload: dict[str, Any]
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user